WideNDeep教程代码

时间:2016-09-07 21:54:16

标签: tensorflow deep-learning

关于WideNDeep tutorial中的这行代码:

m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)

用于训练深层模型的batch_size是什么? 目前,在我看来,该模型不是批量训练的?有没有默认的batch_size?

由于

2 个答案:

答案 0 :(得分:0)

您可以将batch_size作为参数传递给fit。 See the documentation on BaseEstimator.fit

答案 1 :(得分:0)

我更改了本教程以进行批处理,如下所示:

  1. 将CSV数据转换为TensorFlow格式(TFRecord文件中的示例);然后
  2. 从tf.contrib.learn(队列)中提供的read_batch函数创建input_fn
  3. 这是我使用的代码:

    https://gist.github.com/cirocavani/7d9e827102093139acd400b02d2e7afb

    input_fn是这样的:

    def input_fn(mode, data_file, batch_size):
        input_features = create_feature_columns()
        features = tf.contrib.layers.create_feature_spec_for_parsing(input_features)
    
        feature_map = tf.contrib.learn.io.read_batch_record_features(
            file_pattern=[data_file],
            batch_size=batch_size,
            features=features,
            name="read_batch_features_{}".format(mode))
    
        target = feature_map.pop("label")
    
        return feature_map, target
    

    我认为它会有一个更简单的解决方案,但我不知道TensorFlow还提供了一个:)