Tensorflow Wide& Deep教程示例批处理

时间:2016-08-05 06:31:45

标签: tensorflow deep-learning

我为Google和Tensorflow发布的新模型(即wide_n_deep学习)感到兴奋。所以我试图通过运行the tutorial example来玩它。

作为机器学习的一个常见技巧,当整个训练数据集很大时,批量学习很重要。所以我尝试修改提供的wide_n_deep学习教程示例以获得批量学习,如下所示:

index_in_epoch = 0
num_examples = df_train.shape[0]
for i in xrange(FLAGS.train_steps):
    startTime = datetime.now()
    print("start step %i" %i)
    start = index_in_epoch
    index_in_epoch += batch_size
    if index_in_epoch > num_examples:
        if start < num_examples:
          m.fit(input_fn=lambda: input_fn(df_train[start:num_examples], steps=1)
        df_train.reindex(np.random.permutation(df_train.index)
        start = 0
        index_in_epoch = batch_size
    if i%5 == 1:
        results = m.evaluate(input_fn=lambda: input_fn(df_test), steps = 1)
        for key in sorted(results):
          print("%s: %s %(key, results[key]))
    end = index_in_epoch
    m.fit(input_fn=lambda: input_fn(df_train[start:end], steps=1)

简单地说,我逐批迭代整个训练数据集,对于每一批,我称之为#34; fit&#34;重新训练模型的功能。

这种天真的策略的问题是处理时间太慢(例如,我们希望迭代400万个记录数据集100次,批量大小为100k,培训和评估时间会大约1周)。所以我真的怀疑我是以正确的方式使用批量学习。

如果有任何人才可以分享您在使用wide_n_deep学习模型时处理批量学习的经验,我将不胜感激。

1 个答案:

答案 0 :(得分:0)

每个fit / evaluate调用都会创建一个图形和一个会话,然后执行操作。如果你在循环中这样做,它会很慢。 为了加快速度,您需要提供一个input_fn,它将逐批称为张量。 如果您从数据框中读取数据,则可以使用to_feature_columns_and_input_fn 如果您从包含tf.Example的文件中读取数据,则可以在read_batch_examples中使用input_fn之类的内容。