我正在用 Python 编写一个基于 TensorFlow(Keras)的聊天机器人,遇到了一个错误。模型训练完成后,我尝试用它预测输出来测试模型,但下面这一行抛出了一条我看不懂的错误消息:
output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
下面是所有相关的代码:
# Testing models
def prediction(raw_input, max_input_len=20):
    """Greedily decode a model response for a single raw input sentence.

    The input is cleaned with ``clean_text``, tokenized with
    ``nltk.word_tokenize``, reversed, and encoded with
    ``transform(encoding, ..., max_input_len)``. The decoder input starts
    with ``WORD_CODE_START`` in column 0. It is then filled one column per
    ``model.predict`` call, feeding the model's own predictions back in.

    Args:
        raw_input: The raw question string.
        max_input_len: Length passed to ``transform`` for the encoder input.
            The original code hard-coded 20, and the default keeps that.

    Returns:
        The token ids from the last ``model.predict`` call, taken as the
        argmax over axis 2 of the model output.
    """
    clean_input = clean_text(raw_input)
    # The input sequence is fed to the encoder in reversed token order.
    reversed_tokens = nltk.word_tokenize(clean_input)[::-1]
    encoder_input = transform(encoding, [reversed_tokens], max_input_len)

    decoder_input = np.zeros(shape=(len(encoder_input), OUTPUT_LENGTH))
    decoder_input[:, 0] = WORD_CODE_START

    # NOTE(review): the reported NotFoundError ("PruneForTargets: Some target
    # nodes not found") is raised while the TF session builds the predict
    # callable (see traceback). This usually means `model` belongs to a
    # different graph/session than the current Keras one: the model cell was
    # re-run, K.clear_session() was called, or `keras` and `tf.keras` were
    # mixed. Rebuilding/reloading the model in a fresh kernel is the usual
    # fix. TODO confirm.
    output = None
    for i in range(1, OUTPUT_LENGTH):
        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
        # NOTE(review): this feeds back output[:, i]. If the model was trained
        # with decoder targets shifted one step ahead of decoder inputs, then
        # output[:, i - 1] is the prediction for column i. Confirm against
        # the training code.
        decoder_input[:, i] = output[:, i]

    if output is None:
        # OUTPUT_LENGTH <= 1: the loop never ran, and returning `output`
        # would raise UnboundLocalError. Run the model once so callers always
        # get a prediction.
        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
    return output
def decode(decoding, vector):
    """Convert a sequence of token ids back into text.

    Ids are read in order until the first id equal to 0, which stops
    decoding. 0 is presumably the padding/end id; confirm against the
    vocabulary. Each decoded word is prefixed with a single space, so a
    non-empty result starts with ``' '``.

    Args:
        decoding: Indexable by token id (``decoding[token_id]``), yielding
            the word string for that id.
        vector: Iterable of token ids.

    Returns:
        The decoded string, e.g. ``' hello world'``. Returns ``''`` when
        ``vector`` is empty or its first id is 0.
    """
    words = []
    for token_id in vector:
        if token_id == 0:
            break
        words.append(decoding[token_id])
    # Join once instead of repeated string `+=` inside the loop.
    return ''.join(' ' + word for word in words)
# Sample 20 random questions and print the model's answer to each one.
for _ in range(20):
    # np.random.randint's upper bound is exclusive, so randint(n) picks from
    # 0..n-1. The original low=1 never picked the first question and raised
    # ValueError when only one question existed.
    seq_index = np.random.randint(len(short_questions))
    question = short_questions[seq_index]
    output = prediction(question)
    print("Q: ", question)
    print("A: ", decode(decoding, output[0]))
完整的错误信息如下:
---------------------------------------------------------------------------
NotFoundError Traceback (most recent call last)
<ipython-input-59-a18a053992fe> in <module>
2 seq_index = np.random.randint(1,len(short_questions))
3 print(short_questions[seq_index])
----> 4 output = prediction(short_questions[seq_index])
5 print("Q: ", short_questions[seq_index])
6 print("A: ", decode(decoding, output[0]))
<ipython-input-57-d2074d1c87eb> in prediction(raw_input)
8 decoder_input[:,0] = WORD_CODE_START
9 for i in range(1, OUTPUT_LENGTH):
---> 10 output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
11 decoder_input[:,i] = output[:,i]
12 return output
/usr/local/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
1167 batch_size=batch_size,
1168 verbose=verbose,
-> 1169 steps=steps)
1170
1171 def train_on_batch(self, x, y,
/usr/local/lib/python3.6/site-packages/keras/engine/training_arrays.py in predict_loop(model, f, ins, batch_size, verbose, steps)
292 ins_batch[i] = ins_batch[i].toarray()
293
--> 294 batch_outs = f(ins_batch)
295 batch_outs = to_list(batch_outs)
296 if batch_index == 0:
/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2713 return self._legacy_call(inputs)
2714
-> 2715 return self._call(inputs)
2716 else:
2717 if py_any(is_tensor(x) for x in inputs):
/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2669 feed_symbols,
2670 symbol_vals,
-> 2671 session)
2672 if self.run_metadata:
2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session)
2621 callable_opts.run_options.CopyFrom(self.run_options)
2622 # Create callable.
-> 2623 callable_fn = session._make_callable_from_options(callable_opts)
2624 # Cache parameters corresponding to the generated callable, so that
2625 # we can detect future mismatches and refresh the callable.
/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _make_callable_from_options(self, callable_options)
1469 """
1470 self._extend_graph()
-> 1471 return BaseSession._Callable(self, callable_options)
1472
1473
/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, session, callable_options)
1423 with errors.raise_exception_on_not_ok_status() as status:
1424 self._handle = tf_session.TF_SessionMakeCallable(
-> 1425 session._session, options_ptr, status)
1426 finally:
1427 tf_session.TF_DeleteBuffer(options_ptr)
/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
526 None, None,
527 compat.as_text(c_api.TF_Message(self.status.status)),
--> 528 c_api.TF_GetCode(self.status.status))
529 # Delete the underlying status object from memory otherwise it stays alive
530 # as there is a reference to status from this from the traceback due to
NotFoundError: PruneForTargets: Some target nodes not found: group_deps_1