不使用@ tf.function,脚本可以完美运行
我想用它来加快训练速度,但是在重新使用嵌入层的权重矩阵时却出现了错误。
我认为该错误是由get_weights()引起的,因为它将张量转换回numpy
我尝试使用tf.keras.layers.Dense而不是重新使用嵌入的权重,并且效果很好。
static
在我的火车脚本中。 我做到了
class Example(tf.keras.Model):
def __init__(self,):
super(Example, self).__init__()
self.embed_dim = embed_dim
self.vocab_size = vocab_size
self.embed = tf.keras.layers.Embedding(self.vocab_size, self.embed_dim)
...
def call(self, inputs, trianing):
...
embed_matrix = self.embed.get_weights()
# a dense layer
Vhid = tf.matmul(self.kernel, tf.transpose(embed_matrix[0]))
pred_w = tf.matmul(pred, Vhid) + self.bias
答案 0 :(得分:0)
找到最简单的解决方案,将训练速度提高了50%(122个小时至〜65个小时)
只是改变
embed_matrix = self.embed.get_weights()
到
embed_matrix = self.embed.weights
可以解决问题。