I'm making predictions with a saved TensorFlow model:
predict_fn = tf.contrib.predictor.from_saved_model('export/1570330026')
SavedModelPredictor with feed tensors {'input': <tf.Tensor 'serialized_example:0' shape=(?,) dtype=string>} and fetch_tensors {'outputs': <tf.Tensor 'transformer/while/Exit_3:0' shape=(?, ?) dtype=int64>, 'scores': <tf.Tensor 'transformer/while/Exit_36:0' shape=(?,) dtype=float32>, 'batch_prediction_key': <tf.Tensor 'Identity:0' shape=(?, 1) dtype=int32>}
So the input should be a <tf.Tensor 'serialized_example:0' shape=(?,) dtype=string>.
Why does the serialized input below raise an error?
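A shape of `(?,)` means a rank-1 tensor whose length (the batch size) is unknown, so a value fed to it must be one-dimensional, with one string per example. The sketch below illustrates that rule with NumPy. The `is_compatible` helper is hypothetical and only approximates TensorFlow's own shape-compatibility check, and the placeholder bytes stand in for a real serialized `tf.train.Example`.

```python
import numpy as np

def is_compatible(expected, actual):
    """Hypothetical helper: True if the concrete shape `actual` matches
    `expected`, where None in `expected` means "any size" (the ? in (?,))."""
    if len(expected) != len(actual):
        return False  # the rank must match exactly
    return all(e is None or e == a for e, a in zip(expected, actual))

# 'serialized_example:0' has shape (?,): rank 1, any batch size.
EXPECTED = (None,)

# A batch holding one serialized Example (placeholder bytes) is rank 1.
batch = np.asarray([b"<serialized Example bytes>"], dtype=object)
print(batch.shape, is_compatible(EXPECTED, batch.shape))    # (1,) True

# A single bare value is rank 0 and does not match a rank-1 tensor.
scalar = np.asarray(b"<serialized Example bytes>", dtype=object)
print(scalar.shape, is_compatible(EXPECTED, scalar.shape))  # () False
```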
data = "This is a test"
inputs = encoders["inputs"].encode(data) + [1] # add EOS id
example = tf.train.Example(features=tf.train.Features(feature={'inputs': tf.train.Feature(int64_list=tf.train.Int64List(value=inputs))}))
example = example.SerializeToString()
preprocessed_inputs = {'input': {"b64": base64.b64encode(example).decode()}}
predict_fn(preprocessed_inputs)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/contrib/predictor/predictor.py", line 77, in __call__
    return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'serialized_example:0', which has shape '(?,)'
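The traceback shows that `tf.contrib.predictor` passes the input dict straight into `Session.run` as a `feed_dict`. The `{"b64": ...}` wrapper is the JSON convention used by TensorFlow Serving's REST API for binary values; it is not understood by a direct session feed. A rough way to see the problem is to convert the feed value with `np.asarray(..., dtype=object)`, which approximately mirrors how a feed for a string tensor is converted before the shape check. The names `to_feed_value` and `serialized` below are hypothetical, and the bytes are a placeholder for `example.SerializeToString()`.

```python
import base64
import numpy as np

def to_feed_value(value):
    """Hypothetical helper approximating how a feed_dict value for a
    tf.string tensor becomes an array (its numpy dtype is object)."""
    return np.asarray(value, dtype=object)

serialized = b"<serialized Example bytes>"  # stand-in for example.SerializeToString()

# REST-style wrapper: a dict is not a sequence, so it becomes a 0-d array,
# which is the shape () reported in the ValueError.
rest_style = {"b64": base64.b64encode(serialized).decode()}
print(to_feed_value(rest_style).shape)  # ()

# Direct feed: a list of raw (not base64-encoded) bytes, one per example.
direct = [serialized]
print(to_feed_value(direct).shape)  # (1,)
```

With the question's variables, this suggests calling `predict_fn({'input': [example]})`, where `example` is the raw bytes returned by `SerializeToString()`. That should pass the shape check; whether the model then decodes the input correctly still depends on its serving input function expecting an int64 `inputs` feature, as the question's code assumes.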