我为Google和Tensorflow发布的新模型(即wide_n_deep学习)感到兴奋。所以我试图通过运行the tutorial example来玩它。
作为机器学习的一个常见技巧,当整个训练数据集很大时,批量学习很重要。所以我尝试修改提供的wide_n_deep学习教程示例以获得批量学习,如下所示:
index_in_epoch = 0
num_examples = df_train.shape[0]
for i in xrange(FLAGS.train_steps):
startTime = datetime.now()
print("start step %i" %i)
start = index_in_epoch
index_in_epoch += batch_size
if index_in_epoch > num_examples:
if start < num_examples:
m.fit(input_fn=lambda: input_fn(df_train[start:num_examples], steps=1)
df_train.reindex(np.random.permutation(df_train.index)
start = 0
index_in_epoch = batch_size
if i%5 == 1:
results = m.evaluate(input_fn=lambda: input_fn(df_test), steps = 1)
for key in sorted(results):
print("%s: %s %(key, results[key]))
end = index_in_epoch
m.fit(input_fn=lambda: input_fn(df_train[start:end], steps=1)
简单地说,我逐批迭代整个训练数据集,对于每一批,我称之为#34; fit&#34;重新训练模型的功能。
这种天真的策略的问题是处理时间太慢(例如,我们希望迭代400万个记录数据集100次,批量大小为100k,培训和评估时间会大约1周)。所以我真的怀疑我是以正确的方式使用批量学习。
如果有任何人才可以分享您在使用wide_n_deep学习模型时处理批量学习的经验,我将不胜感激。
答案 0 :(得分:0)
每个fit / evaluate调用都会创建一个图形和一个会话,然后执行操作。如果你在循环中这样做,它会很慢。
为了加快速度,您需要提供一个input_fn
,它将逐批称为张量。
如果您从数据框中读取数据,则可以使用to_feature_columns_and_input_fn
如果您从包含tf.Example
的文件中读取数据,则可以在read_batch_examples
中使用input_fn
之类的内容。