我是 TensorFlow 的新手,正在尝试基于以下教程的示例代码运行预测(predict)功能。
https://www.tensorflow.org/tutorials/sequences/recurrent_quickdraw#loss_predictions_and_optimizer
我使用的是 TensorFlow 1.12。
我在 model_fn(features, labels, mode, params) 中加入了以下几行:
# Set None value for predict mode
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = tf.argmax(logits, axis=1)
return tf.estimator.EstimatorSpec(
mode=mode,
predictions={"logits": logits, "predictions": predictions})
然后我调用以下函数来运行预测:
def run_predict(self, estimator):
    """Start prediction on ``estimator`` and return the prediction iterator.

    The input function is obtained from ``get_input_fn`` in PREDICT mode,
    using the batch size stored in ``g_config['batch_size']``.

    NOTE: the first prediction is printed here with ``next``, which
    consumes it, so the returned iterator resumes at the second prediction.
    """
    predict_input_fn = get_input_fn(
        mode=tf.estimator.ModeKeys.PREDICT,
        batch_size=g_config['batch_size'])
    result = estimator.predict(input_fn=predict_input_fn)
    # Pulling the first item is what triggers the input_fn call.
    print(next(result))
    return result
def predict(self):
    """Create an estimator rooted at ``self.model_path`` and print a prediction.

    A ``RunConfig`` whose ``model_dir`` is ``self.model_path`` is handed to
    ``self.create_estimator``; the resulting estimator is passed to
    ``self.run_predict`` and the next item of the iterator it returns is
    printed.
    """
    config = tf.estimator.RunConfig(model_dir=self.model_path)
    estimator = self.create_estimator(run_config=config)
    result = self.run_predict(estimator)
    print(next(result))
但是运行时出现了以下错误:
Traceback (most recent call last):
File "learn_dat.py", line 116, in <module>
main(sys.argv)
File "learn_dat.py", line 105, in main
rnn.predict_data(league, season, date_str)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 438, in predict_data
model.predict()
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 141, in predict
result = self.run_predict(estimator)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 105, in run_predict
print(next(result))
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 549, in predict
input_fn, model_fn_lib.ModeKeys.PREDICT)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1024, in _get_features_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1136, in _call_input_fn
return input_fn(**kwargs)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 257, in _input_fn
num_parallel_calls=10)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1007, in map
return ParallelMapDataset(self, map_func, num_parallel_calls)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2248, in __init__
super(ParallelMapDataset, self).__init__(input_dataset, map_func)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2216, in __init__
map_func, "Dataset.map()", input_dataset)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1473, in __init__
self._function.add_to_graph(ops.get_default_graph())
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 479, in add_to_graph
self._create_definition_if_needed()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 335, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 344, in _create_definition_if_needed_impl
self._capture_by_value, self._caller_device)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 865, in func_graph_from_py_func
outputs = func(*func_graph.inputs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1456, in tf_data_structured_function_wrapper
"%s: %s." % (transformation_name, t))
TypeError: Unsupported return value from function passed to Dataset.map(): None.