在微调 BERT 的下游任务中添加额外特征时出错

时间:2020-10-28 12:34:31

标签: tensorflow keras bert-language-model transfer-learning

我们基于本教程来实现Bert。 https://www.kaggle.com/anasofiauzsoy/tutorial-notebook

我们在教程笔记本的基础上进行了扩展,向 BERT 的下游微调任务中加入余弦相似度等额外特征,但遇到了一个我们无法理解的错误。我们要做的是自然语言推理(NLI),并尝试通过计算假设与前提之间的相似度来增强 BERT。然后,我们将这些额外信息作为张量加入微调过程:在 merge_layer 中将 BERT 嵌入与余弦相似度拼接起来,再继续前向传播。

the shape of our input data

# Length used as the shape of each of the three integer inputs declared in build_model.
max_len = 50

def build_model():
    """Build and compile a BERT-based classifier with an extra similarity input.

    The model has four inputs, in this order: ``input_word_ids``,
    ``input_mask`` and ``input_type_ids`` (each int32 of shape ``(max_len,)``)
    and ``input_similarities`` (float32 of shape ``(1,)``). The first position
    of the encoder's first output and the similarity value are each projected
    to 1024 units, concatenated, passed through two more 1024-unit ReLU layers
    and finally through a 3-way softmax.

    Returns:
        A ``tf.keras.Model`` compiled with Adam (learning rate 1e-5),
        sparse categorical cross-entropy loss and an accuracy metric.
    """
    bert_encoder = TFBertModel.from_pretrained(model_name)
    input_word_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    input_mask = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
    input_type_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_type_ids")

    # NOTE(review): the encoder is fed word ids, mask, type ids in that order;
    # the data passed to fit() presumably has to follow the same input order
    # (or be keyed by the input names) -- verify against the caller.
    embedding = bert_encoder([input_word_ids, input_mask, input_type_ids])[0]
    # Keep only the first sequence position (presumably the [CLS] token).
    splice_layer1 = tf.keras.layers.Dense(1024, activation='relu')(embedding[:, 0, :])

    # From here, splice the similarity feature into the model.
    input_similarities = tf.keras.Input(shape=(1,), dtype=tf.float32, name="input_similarities")
    splice_layer2 = tf.keras.layers.Dense(1024, activation='relu')(input_similarities)

    # concatenate() takes ONE list of tensors; passing the tensors as separate
    # positional arguments would hand the second tensor in as `axis`.
    merge_layer = tf.keras.layers.concatenate([splice_layer1, splice_layer2])

    fine_tune_layer1 = tf.keras.layers.Dense(1024, activation="relu")(merge_layer)
    fine_tune_layer2 = tf.keras.layers.Dense(1024, activation="relu")(fine_tune_layer1)

    output = tf.keras.layers.Dense(3, activation='softmax')(fine_tune_layer2)

    model = tf.keras.Model(
        inputs=[input_word_ids, input_mask, input_type_ids, input_similarities],
        outputs=output,
    )
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    model.compile(
        tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    return model

# Create the model inside the distribution strategy's scope and print its summary.
with strategy.scope():
    model = build_model()
    model.summary()

# Train for 10 epochs in batches of 64, holding out 20% of the data for validation.
model.fit(
    train_input,
    train_labels,
    epochs=10,
    verbose=1,
    batch_size=64,
    validation_split=0.2,
)

使用模型,我们会收到以下错误:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-36-adbd973191e4> in <module>
----> 1 model.fit(train_input, train_labels, epochs = 10, verbose = 1, batch_size = 64, validation_split = 0.2)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096                 batch_size=batch_size):
   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)
   1099               if data_handler.should_sync:
   1100                 context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call_(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    838         # Lifting succeeded, so variables are initialized and we can run the
    839         # stateless function.
--> 840         return self._stateless_fn(*args, **kwds)
    841     else:
    842       canon_args, canon_kwds = \

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_(self, *args, **kwargs)
   2827     with self._lock:
   2828       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2830 
   2831   @property

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1846                            resource_variable_ops.BaseResourceVariable))],
   1847         captured_inputs=self.captured_inputs,
-> 1848         cancellation_manager=cancellation_manager)
   1849 
   1850   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1922       # No tape is watching; skip to running the function.
   1923       return self._build_call_outputs(self._inference_function.call(
-> 1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(
   1926         args,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    548               inputs=args,
    549               attrs=attrs,
--> 550               ctx=ctx)
    551         else:
    552           outputs = execute.execute_with_cancellation(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  indices[60,0] = 101 is not in [0, 2)
     [[node functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup (defined at /opt/conda/lib/python3.7/site-packages/transformers/modeling_tf_bert.py:190) ]] [Op:__inference_train_function_204495]

Errors may have originated from an input operation.
Input Source operations connected to node functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup:
 functional_9/tf_bert_model_4/bert/embeddings/token_type_embeddings/embedding_lookup/195396 (defined at /opt/conda/lib/python3.7/contextlib.py:112) 
 IteratorGetNext (defined at <ipython-input-36-adbd973191e4>:1)

Function call stack:
train_function

预先感谢:)

0 个答案:

没有答案