如何创建利用batch_size的输入函数?

时间:2016-09-17 12:11:14

标签: tensorflow

调用predict函数需要10GB的内存,这在我的GPU中是不可用的:

estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=model_dir)
probs = estimator.predict(input_fn=lambda: my_input_fn(valid_records))

predict函数的batch_size参数在使用input_fn时不可用。看来我有两个选择(让我知道是否有另一个):

  1. input_fn替换为x参数,然后使用batch_size参数。目前,我不知道该怎么做!
  2. 修改我的输入函数以返回不同批次的数据。我不知道怎么做!

1 个答案:

答案 0 :(得分:0)

利用tf.train.batch函数批量传递数据。将input_fn函数传递给predict方法时,不要忘记设置as_iterable=True