I'm using the latest version of Keras / TensorFlow, but I get this error:
from keras import backend as K
predict_fcn = K.function(model.inputs, model.outputs)
predict_fcn(input_values_test)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-17-8d6b6ae9c940> in <module>()
----> 1 predict_fcn(input_values_test)
~/Python_Libraries/keras/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2664 return self._legacy_call(inputs)
2665
-> 2666 return self._call(inputs)
2667 else:
2668 if py_any(is_tensor(x) for x in inputs):
~/Python_Libraries/keras/keras/backend/tensorflow_backend.py in _call(self, inputs)
2634 symbol_vals,
2635 session)
-> 2636 fetched = self._callable_fn(*array_vals)
2637 return fetched[:len(self.outputs)]
2638
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args)
1452 else:
1453 return tf_session.TF_DeprecatedSessionRunCallable(
-> 1454 self._session._session, self._handle, args, status, None)
1455
1456 def __del__(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
517 None, None,
518 compat.as_text(c_api.TF_Message(self.status.status)),
--> 519 c_api.TF_GetCode(self.status.status))
520 # Delete the underlying status object from memory otherwise it stays alive
521 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [-2, 2), but got 2
[[Node: Merge_Embeddings/concat = ConcatV2[N=8, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](0_Embedding/embedding_lookup, 1_Embedding/embedding_lookup, 2_Embedding/embedding_lookup, 3_Embedding/embedding_lookup, 4_Embedding/embedding_lookup, 5_Embedding/embedding_lookup, 6_Embedding/embedding_lookup, Embedding_Average/Mean, Merge_Embeddings/concat/axis)]]
[[Node: Outputs/BiasAdd/_989 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1274_Outputs/BiasAdd", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
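However, if I simply run model.predict(input_values_test), it works fine. Reading the error message closely, it looks as if the concatenation behaves differently when the model is executed as a backend function rather than as a model. Is that the case?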
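One plausible explanation, which the post cannot confirm because the model definition isn't shown: the message "Expected concatenating dimensions in the range [-2, 2), but got 2" means the tensors reaching Merge_Embeddings/concat have rank 2, while the layer concatenates on axis 2 and so expects rank 3. In Keras 2.x, model.predict runs its inputs through input standardization that, as far as I know, expands 1-D arrays to shape (N, 1); a K.function feeds arrays exactly as given. If each array in input_values_test is 1-D, each Embedding then returns (N, dim) instead of (N, 1, dim), and concatenating on axis 2 fails. The NumPy sketch below reproduces only that shape mismatch. emb_table, embedding_lookup, merge_embeddings and standardize are hypothetical stand-ins, not Keras APIs, and the two 1-D inputs are assumed for illustration.

```python
import numpy as np

# Hypothetical embedding table: 10 ids, 4-dim vectors (illustrative values only).
emb_table = np.arange(40, dtype=np.float32).reshape(10, 4)

def embedding_lookup(ids):
    # Stand-in for an Embedding layer: one row per id, trailing embedding axis appended.
    return emb_table[ids]

def merge_embeddings(inputs):
    # Stand-in for a Concatenate(axis=2) layer over per-feature embeddings.
    return np.concatenate([embedding_lookup(x) for x in inputs], axis=2)

def standardize(inputs):
    # Assumed approximation of what model.predict does: give 1-D inputs a trailing axis.
    return [np.expand_dims(x, 1) if x.ndim == 1 else x for x in inputs]

# Two features, batch of 3, each array 1-D with shape (3,).
raw = [np.array([1, 2, 3]), np.array([4, 5, 6])]

# Like model.predict: ids become (3, 1), embeddings (3, 1, 4), concat gives (3, 1, 8).
ok = merge_embeddings(standardize(raw))
print(ok.shape)

# Like K.function fed raw arrays: embeddings are (3, 4), so axis 2 is out of range.
try:
    merge_embeddings(raw)
    raw_error = None
except ValueError as exc:  # NumPy's AxisError subclasses ValueError
    raw_error = exc
print(raw_error)
```

If this is the cause, reshaping the inputs before calling the backend function, for example predict_fcn([x.reshape(-1, 1) if x.ndim == 1 else x for x in input_values_test]), should make it behave like model.predict. This assumes input_values_test is a list of NumPy arrays.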