调用predict
函数需要10GB的内存,这在我的GPU中是不可用的:
estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=model_dir)
probs = estimator.predict(input_fn=lambda: my_input_fn(valid_records))
predict
函数的batch_size
参数在使用input_fn
时不可用。看来我有两个选择(让我知道是否有另一个):
input_fn
替换为x
参数,然后使用batch_size
参数。目前,我不知道该怎么做!答案 0 :(得分:0)
利用tf.train.batch函数批量传递数据。将input_fn函数传递给predict方法时,不要忘记设置as_iterable=True
。