OperatorNotAllowedInGraphError:在图形执行中不允许将tf.Tensor作为Python bool使用

时间:2020-01-25 22:53:14

标签: python tensorflow break

我正在尝试执行这些功能

def evaluate(sentence):
  sentence = preprocess_sentence(sentence)

  sentence = tf.expand_dims(
      START_TOKEN + tokenizer.encode(sentence) + END_TOKEN, axis=0)

  output = tf.expand_dims(START_TOKEN, 0)

  for i in range(MAX_LENGTH):

    predictions = model(inputs=[sentence, output], training=False)

    # select the last word from the seq_len dimension
    predictions = predictions[:, -1:, :]
    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

    # return the result if the predicted_id is equal to the end token

    if tf.equal(predicted_id, END_TOKEN[0]):
      break
    #check()
    #tf.cond(tf.equal(predicted_id, END_TOKEN[0]),true_fn=break,false_fn=lambda: tf.no_op())


    # concatenated the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0)


def predict(sentence):
  prediction = evaluate(sentence)

  predicted_sentence = tokenizer.decode(
      [i for i in prediction if i < tokenizer.vocab_size])

  print('Input: {}'.format(sentence))
  print('Output: {}'.format(predicted_sentence))

  return predicted_sentence

但是,我遇到以下错误: OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function. 我确实知道我必须以tf.cond()的形式重写if条件。但是,我不知道如何在张量流中编写break,也不确定哪个条件导致了问题,因为此笔记本中的相同功能是否正常工作? https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/transformer_chatbot.ipynb#scrollTo=_NURhwYz5AXa 有帮助吗?

2 个答案:

答案 0 :(得分:0)

笔记本中的代码有效,因为它使用的是TF 2.0,默认情况下它启用了急切的执行功能。您可以使用tf.enable_eager_execution在旧版本中将其打开。

或者,如果您使用tf.function or tf.autograph,则可以在图表模式下使用break而不编写tf.cond,但是它们对您可以运行的代码有一些限制。

答案 1 :(得分:0)

break 语句没有任何问题。问题出在别处。

if tf.equal(predicted_id, END_TOKEN[0]):
   break

在张量操作中使用 Python bool 会出错。由于您已经使用了 tf.equal 条件,这可能会令人困惑。解决方法很简单。正在为

抛出错误

if (boolean):python 语法。

您必须注意这个(布尔)Python 语法并根据您计划实现的目标转换为张量样式。请记住,条件返回布尔值的张量。阅读这个张量,然后继续做你想做的事情......所以例如无论条件的值如何,以下都将无条件地工作:

if tf.equal(predicted_id, END_TOKEN[0]) is not None:
   break