我们基于以下教程实现了BERT:https://www.kaggle.com/anasofiauzsoy/tutorial-notebook
我们在该教程笔记本的基础上做了扩展,把余弦相似度等额外特征加入到BERT的下游微调任务中,但遇到了一个我们无法理解的错误。我们要做的是自然语言推理(NLI),并希望通过计算假设与前提之间的相似度来增强BERT。然后,我们把这些额外信息作为张量加入微调过程:在merge_layer中将BERT嵌入与余弦相似度拼接起来,再继续前向传播。
# Fixed sequence length (in tokens) of every BERT input tensor.
max_len = 50


def build_model():
    """Build and compile a BERT classifier with an extra similarity feature.

    The model has four inputs, in this order:

    * ``input_word_ids``     -- int32, shape ``(max_len,)``
    * ``input_mask``         -- int32, shape ``(max_len,)``
    * ``input_type_ids``     -- int32, shape ``(max_len,)``
    * ``input_similarities`` -- float32, shape ``(1,)``

    The first three are fed to the pretrained BERT encoder; the fourth is a
    scalar feature that is projected and concatenated with the projected BERT
    representation before the classification head.

    Returns:
        A compiled ``tf.keras.Model`` with a 3-way softmax output, trained
        with Adam (learning rate 1e-5) and sparse categorical cross-entropy.
    """
    bert_encoder = TFBertModel.from_pretrained(model_name)

    input_word_ids = tf.keras.Input(
        shape=(max_len,), dtype=tf.int32, name="input_word_ids"
    )
    input_mask = tf.keras.Input(
        shape=(max_len,), dtype=tf.int32, name="input_mask"
    )
    input_type_ids = tf.keras.Input(
        shape=(max_len,), dtype=tf.int32, name="input_type_ids"
    )

    # The list is positional: token ids, attention mask, token-type ids.
    # Index [0] selects the encoder's first output.
    embedding = bert_encoder([input_word_ids, input_mask, input_type_ids])[0]

    # Keep only position 0 along the sequence axis (presumably the [CLS]
    # token -- confirm against the tokenizer used to build the inputs).
    splice_layer1 = tf.keras.layers.Dense(1024, activation="relu")(
        embedding[:, 0, :]
    )

    # From here, splice the similarity feature into the model.
    input_similarities = tf.keras.Input(
        shape=(1,), dtype=tf.float32, name="input_similarities"
    )
    splice_layer2 = tf.keras.layers.Dense(1024, activation="relu")(
        input_similarities
    )

    # `concatenate` takes ONE list of tensors; passing the tensors as two
    # positional arguments would bind the second tensor to `axis`.
    merge_layer = tf.keras.layers.concatenate([splice_layer1, splice_layer2])

    fine_tune_layer1 = tf.keras.layers.Dense(1024, activation="relu")(
        merge_layer
    )
    fine_tune_layer2 = tf.keras.layers.Dense(1024, activation="relu")(
        fine_tune_layer1
    )
    output = tf.keras.layers.Dense(3, activation="softmax")(fine_tune_layer2)

    model = tf.keras.Model(
        inputs=[input_word_ids, input_mask, input_type_ids, input_similarities],
        outputs=output,
    )
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    model.compile(
        tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
# Create the model (and therefore its variables) under the distribution
# strategy's scope.
with strategy.scope():
    model = build_model()
    model.summary()

# NOTE(review): the InvalidArgumentError reported for this call comes from the
# token_type_embeddings lookup receiving the value 101, i.e. something other
# than 0/1 is reaching the token-type input. `train_input` is built outside
# this snippet -- verify that its entries line up with the model's inputs
# (input_word_ids, input_mask, input_type_ids, input_similarities).
model.fit(
    train_input,
    train_labels,
    epochs=10,
    verbose=1,
    batch_size=64,
    validation_split=0.2,
)
训练该模型(调用model.fit)时,我们会收到以下错误:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-36-adbd973191e4> in <module>
----> 1 model.fit(train_input, train_labels, epochs = 10, verbose = 1, batch_size = 64, validation_split = 0.2)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
838 # Lifting succeeded, so variables are initialized and we can run the
839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
841 else:
842 canon_args, canon_kwds = \
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2830
2831 @property
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1846 resource_variable_ops.BaseResourceVariable))],
1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
1849
1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1922 # No tape is watching; skip to running the function.
1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
1926 args,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
548 inputs=args,
549 attrs=attrs,
--> 550 ctx=ctx)
551 else:
552 outputs = execute.execute_with_cancellation(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: indices[60,0] = 101 is not in [0, 2)
[[node functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup (defined at /opt/conda/lib/python3.7/site-packages/transformers/modeling_tf_bert.py:190) ]] [Op:__inference_train_function_204495]
Errors may have originated from an input operation.
Input Source operations connected to node functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup:
functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup/195396 (defined at /opt/conda/lib/python3.7/contextlib.py:112)
IteratorGetNext (defined at <ipython-input-36-adbd973191e4>:1)
Function call stack:
train_function
提前感谢 :)