TensorFlow 2.1 BERT embeddings - Variable is unhashable if Tensor equality is enabled. Instead, use tensor.experimental_ref() as the key

Date: 2020-02-09 22:58:18

Tags: python tensorflow

I'm trying to load the following pre-trained BERT model from TensorFlow Hub. Following the instructions on that page, this is the code I'm running:

import tensorflow as tf
import tensorflow_hub as hub

max_seq_length = 128  # Your choice here.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1",
                            trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
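The three `Input` layers above expect int32 arrays of shape `(batch, max_seq_length)`. As a rough sketch of how one example could be laid out for them, here is a hypothetical helper `pad_bert_inputs`. It assumes the text has already been converted to WordPiece ids (including `[CLS]`/`[SEP]`), that the padding id is 0 (the usual BERT convention, but check your vocabulary), and that the whole example is a single segment. The ids in the demo are arbitrary.

```python
from typing import List, Tuple


def pad_bert_inputs(token_ids: List[int], max_seq_length: int = 128,
                    pad_id: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: build input_word_ids, input_mask, segment_ids.

    Assumes token_ids are WordPiece ids that already include [CLS]/[SEP],
    and that the example consists of a single segment.
    """
    if len(token_ids) > max_seq_length:
        # Refuse rather than silently truncate (which could drop [SEP]).
        raise ValueError("sequence longer than max_seq_length")
    n_pad = max_seq_length - len(token_ids)
    # Right-pad the ids up to the fixed length expected by the Input layers.
    input_word_ids = token_ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding the model should ignore.
    input_mask = [1] * len(token_ids) + [0] * n_pad
    # Single-segment input: every position belongs to segment 0.
    segment_ids = [0] * max_seq_length
    return input_word_ids, input_mask, segment_ids


ids, mask, segs = pad_bert_inputs([101, 7592, 102], max_seq_length=6)
print(ids)   # [101, 7592, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 0, 0, 0]
print(segs)  # [0, 0, 0, 0, 0, 0]
```

Stacking one such triple per example (e.g. with `np.array(..., dtype=np.int32)`) gives the batch arrays the model would be called on.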

However, as soon as I run hub.KerasLayer, I get the following exception:

In [10]: bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1", 
    ...:                             trainable=True)                                                                                                                                                        

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-f1fd8e265590> in <module>
      1 bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1",
----> 2                             trainable=True)

~/anaconda3/lib/python3.7/site-packages/tensorflow_hub/keras_layer.py in __init__(self, handle, trainable, arguments, **kwargs)
    111       for v in self._func.trainable_variables:
    112         self._add_existing_weight(v, trainable=True)
--> 113       trainable_variables = set(self._func.trainable_variables)
    114     else:
    115       trainable_variables = set()

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py in __hash__(self)
   1087   def __hash__(self):
   1088     if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
-> 1089       raise TypeError("Variable is unhashable if Tensor equality is enabled. "
   1090                       "Instead, use tensor.experimental_ref() as the key.")
   1091     else:

TypeError: Variable is unhashable if Tensor equality is enabled. Instead, use tensor.experimental_ref() as the key.
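The traceback shows `tensorflow_hub` calling `set(self._func.trainable_variables)`, which needs every variable to be hashable. The following is a toy model in plain Python (not TensorFlow code; `ToyVariable` and `Ref` are hypothetical names) of why that fails: once `==` compares values instead of identity, the object can no longer back a hash, so hashing is disabled and a separate identity-based wrapper is offered instead, which is the idea the error message points to with `experimental_ref()`.

```python
class Ref:
    """Hashable wrapper keyed on object identity (toy version of the idea)."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def __hash__(self):
        # Hash by identity, so value-equal objects stay distinct keys.
        return id(self._wrapped)

    def __eq__(self, other):
        return isinstance(other, Ref) and self._wrapped is other._wrapped

    def deref(self):
        # Recover the original object from the key.
        return self._wrapped


class ToyVariable:
    """Stand-in for a variable when Tensor equality is enabled."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Value comparison (element-wise for real tensors), not identity.
        other_value = other.value if isinstance(other, ToyVariable) else other
        return self.value == other_value

    def __hash__(self):
        # Same behaviour as in the traceback: hashing is refused.
        raise TypeError("Variable is unhashable if Tensor equality is enabled. "
                        "Instead, use tensor.experimental_ref() as the key.")

    def experimental_ref(self):
        return Ref(self)


variables = [ToyVariable(1.0), ToyVariable(1.0)]

try:
    set(variables)  # the same pattern as keras_layer.py line 113
except TypeError as err:
    print("set(variables) failed:", err)

refs = {v.experimental_ref() for v in variables}
print(len(refs))                                # 2: equal values, distinct objects
print(variables[0].experimental_ref() in refs)  # True
```

Keying on the wrapper keeps set membership tied to object identity, which is what a collection of trainable variables needs.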

What am I missing? I'm running all of the above on macOS 10.15 with TF 2.1.

0 answers