在张量流图执行期间访问训练数据

时间:2021-03-17 20:52:28

标签: python tensorflow keras

我想在我的张量流图执行模型中使用预先训练好的句子嵌入。嵌入可从函数调用中动态获得,该函数调用接收一个句子数组并输出一个句子嵌入数组。此函数使用预训练的 pytorch 模型,因此必须与我正在训练的 tensorflow 模型保持分离:

def get_pretrained_embeddings(sentences):
  return pretrained_pytorch_model.encode(sentences)

我的 tensorflow 模型如下所示:

class SentenceModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    
  def call(self, sentences):
    embedding_layer = tf.keras.layers.Embedding(
    10_000,
    256,
    embeddings_initializer=tf.keras.initializers.Constant(get_pretrained_embeddings(sentences)),
    trainable=False,
  )
  sentence_text_embedding = tf.keras.Sequential([
    embedding_layer,
    tf.keras.layers.GlobalAveragePooling1D(),
  ])
  return sentence_text_embedding,

但是当我尝试使用

训练这个模型时
cached_train = train.shuffle(100_000).batch(1024)
model.fit(cached_train)

我的 embeddings_initializer 调用出现错误:

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

我认为这是因为 tensorflow 试图使用符号数据编译图形。如何让依赖于当前训练数据批次的外部函数与 tensorflow 的图训练一起使用?

0 个答案:

没有答案