我有一个像这样的keras RNN模型,它使用了预先训练的Word2Vec权重
model = Sequential()
model.add(L.Embedding(input_dim=vocab_size, output_dim=embedding_size,
input_length=max_phrase_length,
weights=[pretrained_weights],trainable=False))
model.add((L.LSTM(units=rnn_units)))
model.add((L.Dense(vocab_size,activation='sigmoid')))
adam=Adam(lr)
model.compile(optimizer=adam, loss='cosine_proximity',
metrics=['cosine_proximity'])
在培训期间,我想创建一个自定义损失函数来比较与预测整数和真实整数索引关联的预测单词和真实单词向量。
def custom_loss(y_true,y_pred):
A=extract_the_word_vectors_for_the_indices(y_true)
B=extract_the_word_vectors_for_the_indices(y_pred)
return some keras backend function of A and B
例如,假设我的批处理大小为4。然后从model.fit中,我可以将y_pred
通过argmax
传递,这样K.argmax(y_pred)=[i1,i2,i3,4]
就是与单词向量{{ 1}}。我想对预测向量进行一些数学运算,并将其与地面真实向量进行比较,以监视进度(而不是作为损失函数)。因此,我需要一种“充满Keras”的方式来做到这一点。
如果vectors[i1], vectors[i2], vectors[i3], vectors[i4]
是一个索引的数字数组,而y_true
是我的word2vec模型,那么我只要做word_model
就可以得到向量数组。但是,将word_model.wv.vectors[y_true]
从张量转换为numpy,然后再返回到张量似乎非常浪费。所以我似乎无法在本地keras中使用任何东西,当我尝试将张量提取到numpy数组并使用它们时,我也会遇到错误。 r ...
我想必须有一种方法可以从y_pred和y_true的嵌入层中提取单词向量,但是我不知道如何。有人吗?
答案 0 :(得分:1)
一个简单的解决方案是使用功能性api,并且您随时可以调用自定义损失函数。
from keras.models import Model
from keras.layers import Input, Embedding, LSTM, Dense
from keras.optimizers import Adam
model_input = Input((max_phrase_length, vocab_size))
embedding_layer = Embedding(input_dim=vocab_size, output_dim=embedding_size,
input_length=max_phrase_length,
weights=[pretrained_weights],trainable=False)
x = embedding_layer(model_input)
x = LSTM(units=rnn_units)(x)
x = Dense(units=vocab_size, activation='sigmoid')(x)
orignal_model = Model(inputs=model_input, outputs=x)
orignal_model.compile(optimizer=Adam(lr),
loss='cosine_proximity',
metrics=['cosine_proximity'])
embedding_model = Model(inputs=model_input, outputs=embedding_layer(model_input))
现在,您可以使用embedding_model执行所需的操作:
def custom_loss(y_true,y_pred, embedding_model):
A = embedding_model.predict(np.argmax(y_true))
B = embedding_model.predict(np.argmax(y_pred))
return some keras backend function of A and B
我没有检查代码,因此可能需要一些调整。