如何修复 TensorFlow/Keras 中 model.predict() 抛出的 NotFoundError（Python 聊天机器人）

时间:2019-01-22 05:45:11

标签: python tensorflow keras jupyter-notebook chatbot

我正在用 TensorFlow（Keras）编写一个 Python 聊天机器人，遇到了一个错误。模型训练完成后，我尝试用它对测试输入进行预测，但在下面这一行收到了一条难以理解的错误消息（NotFoundError）：

        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)

下面是相关的全部代码：

#Testing models
def prediction(raw_input, input_length=20):
    """Greedily decode a model response for one raw input string.

    The input is cleaned with ``clean_text``, tokenized with
    ``nltk.word_tokenize``, reversed, and encoded with
    ``transform(encoding, ..., input_length)``.  The decoder input starts as
    all zeros with ``WORD_CODE_START`` in position 0.  The full model is then
    re-run once per remaining output position, and each argmax token is
    written back into ``decoder_input``.

    Args:
        raw_input: The user's text.
        input_length: Maximum encoder sequence length passed to
            ``transform``.  The default of 20 is the value that was
            previously hard-coded.

    Returns:
        The argmax over axis 2 of the last ``model.predict`` output.  This is
        presumably an integer array of shape (batch, OUTPUT_LENGTH); confirm
        against the model definition.  Callers decode ``output[0]``.
    """
    # NOTE(review): a "NotFoundError: PruneForTargets: Some target nodes not
    # found" raised from model.predict usually means the model's graph/session
    # no longer matches the current default one.  Typical triggers are
    # re-running notebook cells that rebuild or reload the model, or mixing
    # ``keras`` with ``tf.keras``.  Restarting the kernel, or calling
    # ``keras.backend.clear_session()`` before building the model once, is the
    # usual fix.  TODO confirm for this notebook.
    clean_input = clean_text(raw_input)
    input_tok = [nltk.word_tokenize(clean_input)]
    # Reverse the token order.  Presumably this mirrors how encoder inputs
    # were reversed during training; confirm.
    input_tok = [input_tok[0][::-1]]
    encoder_input = transform(encoding, input_tok, input_length)
    decoder_input = np.zeros(shape=(len(encoder_input), OUTPUT_LENGTH))
    decoder_input[:, 0] = WORD_CODE_START

    output = None
    for i in range(1, OUTPUT_LENGTH):
        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
        # NOTE(review): if training fed decoder inputs shifted right by one
        # (START at index 0), the token for slot i is predicted at position
        # i - 1, i.e. ``output[:, i - 1]``.  Confirm against the training code
        # before changing this.
        decoder_input[:, i] = output[:, i]

    if output is None:
        # OUTPUT_LENGTH == 1: the loop never ran.  Predict once instead of
        # failing with UnboundLocalError on ``return output``.
        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
    return output

def decode(decoding, vector):
    """Turn a sequence of token ids back into text.

    Tokens are read in order until the first id equal to 0, which is not
    included.  Each preceding id is looked up in ``decoding`` and emitted with
    a single leading space.  The result is therefore either '' or a string
    that starts with ' '.
    """
    pieces = []
    for token in vector:
        if token == 0:
            break
        pieces.append(' ' + decoding[token])
    return ''.join(pieces)

# Run the model on 20 randomly chosen questions and print each Q/A pair.
for i in range(20):
    # np.random.randint's upper bound is exclusive.  Sampling from [0, len)
    # lets the first question be chosen.  The old lower bound of 1 excluded
    # index 0 and raised ValueError for a single-question list.
    seq_index = np.random.randint(len(short_questions))
    question = short_questions[seq_index]
    # Printed before prediction(), so the input is visible even if it raises.
    print(question)
    output = prediction(question)
    print("Q: ", question)
    print("A: ", decode(decoding, output[0]))

完整的错误信息如下：

---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
<ipython-input-59-a18a053992fe> in <module>
      2     seq_index = np.random.randint(1,len(short_questions))
      3     print(short_questions[seq_index])
----> 4     output = prediction(short_questions[seq_index])
      5     print("Q: ", short_questions[seq_index])
      6     print("A: ", decode(decoding, output[0]))

<ipython-input-57-d2074d1c87eb> in prediction(raw_input)
      8     decoder_input[:,0] = WORD_CODE_START
      9     for i in range(1, OUTPUT_LENGTH):
---> 10         output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
     11         decoder_input[:,i] = output[:,i]
     12     return output

/usr/local/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
   1167                                             batch_size=batch_size,
   1168                                             verbose=verbose,
-> 1169                                             steps=steps)
   1170 
   1171     def train_on_batch(self, x, y,

/usr/local/lib/python3.6/site-packages/keras/engine/training_arrays.py in predict_loop(model, f, ins, batch_size, verbose, steps)
    292                 ins_batch[i] = ins_batch[i].toarray()
    293 
--> 294             batch_outs = f(ins_batch)
    295             batch_outs = to_list(batch_outs)
    296             if batch_index == 0:

/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2669                                 feed_symbols,
   2670                                 symbol_vals,
-> 2671                                 session)
   2672         if self.run_metadata:
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)

/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session)
   2621             callable_opts.run_options.CopyFrom(self.run_options)
   2622         # Create callable.
-> 2623         callable_fn = session._make_callable_from_options(callable_opts)
   2624         # Cache parameters corresponding to the generated callable, so that
   2625         # we can detect future mismatches and refresh the callable.

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _make_callable_from_options(self, callable_options)
   1469     """
   1470     self._extend_graph()
-> 1471     return BaseSession._Callable(self, callable_options)
   1472 
   1473 

/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, session, callable_options)
   1423         with errors.raise_exception_on_not_ok_status() as status:
   1424           self._handle = tf_session.TF_SessionMakeCallable(
-> 1425               session._session, options_ptr, status)
   1426       finally:
   1427         tf_session.TF_DeleteBuffer(options_ptr)

/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

NotFoundError: PruneForTargets: Some target nodes not found: group_deps_1 

0 个答案:

没有答案