How do I train a Huggingface TFT5ForConditionalGeneration model?

Time: 2020-08-26 09:50:43

Tags: tensorflow huggingface-transformers

My code is as follows:

batch_size=8
sequence_length=25
vocab_size=100
import tensorflow as tf
from transformers import T5Config, TFT5ForConditionalGeneration
configT5 = T5Config(
    vocab_size=vocab_size,
    d_ff =512, 
)  
model = TFT5ForConditionalGeneration(configT5)

model.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
)
input = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
labels = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
input = {'inputs': input, 'decoder_input_ids': input}
model.fit(input, labels)

It produces an error:

logits and labels must have the same first dimension, got logits shape [1600,64] and labels shape [200] [[node sparse_categorical_crossentropy_3/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at C:\Users\FA.PROJECTOR-MSK\Google Диск\Colab Notebooks\PoetryTransformer\experiments\TFT5.py:30) ]] [Op:__inference_train_function_25173] Function call stack: train_function

I don't understand why the model returns a tensor of shape [1600, 64]. According to the documentation, the model returns [batch_size, sequence_len, vocab_size].

1 Answer:

Answer 0 (score: 0)

Because the call() method of TFT5ForConditionalGeneration has a non-standard signature, fit() cannot be called on the model directly. I had to override call() to get TFT5 to work. See here: https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb#scrollTo=cgxRVn34Z0wb
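On the [1600, 64] shape from the question: one plausible reading, not confirmed by the source, is that the model returned more than the language-model logits (for example cached key/value tensors) and Keras applied the loss to each output in turn. The arithmetic below is only a sanity check of that reading. It assumes the T5Config defaults num_heads=8 and d_kv=64, and it assumes a cache tensor laid out as [batch_size, num_heads, sequence_length, d_kv]; the helper flattened_logits_shape is a hypothetical name introduced for illustration.

```python
from math import prod

# Values taken from the question.
batch_size, sequence_length, vocab_size = 8, 25, 100

# Assumption: T5Config defaults, since the question only sets vocab_size and d_ff.
num_heads, d_kv = 8, 64


def flattened_logits_shape(shape):
    """Collapse every axis except the last, as the sparse cross-entropy op does."""
    return [prod(shape[:-1]), shape[-1]]


labels_shape = [batch_size, sequence_length]
lm_logits_shape = [batch_size, sequence_length, vocab_size]
# Assumed layout of a cached key/value tensor (hypothetical, see lead-in).
cache_shape = [batch_size, num_heads, sequence_length, d_kv]

flat_labels = [prod(labels_shape)]
flat_lm = flattened_logits_shape(lm_logits_shape)
flat_cache = flattened_logits_shape(cache_shape)

print(flat_labels)  # [200]
print(flat_lm)      # [200, 100]
print(flat_cache)   # [1600, 64]
```

Under these assumptions the language-model logits would match the labels (200 rows each), while the extra output reproduces the [1600, 64] from the error message. That is consistent with the answer's point that the loss has to be computed on the logits alone, which a plain fit() on the unmodified model does not do. Separately, because the model emits raw logits, the loss would likely also need from_logits=True.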