使用@ tf.function重用嵌入层的权重矩阵

时间:2019-09-10 21:01:14

标签: tensorflow keras tensorflow2.0

不使用@ tf.function,脚本可以完美运行

我想用它来加快训练速度,但是在重新使用嵌入层的权重矩阵时却出现了错误。

我认为该错误是由get_weights()引起的,因为它将张量转换回numpy

我尝试使用tf.keras.layers.Dense而不是重新使用嵌入的权重,并且效果很好。

static

在我的火车脚本中。 我做到了

class Example(tf.keras.Model):
    def __init__(self,):
        super(Example, self).__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.embed = tf.keras.layers.Embedding(self.vocab_size, self.embed_dim)
        ...

    def call(self, inputs, trianing):
        ...
        embed_matrix = self.embed.get_weights()

        # a dense layer
        Vhid = tf.matmul(self.kernel, tf.transpose(embed_matrix[0]))
        pred_w = tf.matmul(pred, Vhid) + self.bias

1 个答案:

答案 0 :(得分:0)

找到最简单的解决方案,将训练速度提高了50%(122个小时至〜65个小时)

只是改变

embed_matrix = self.embed.get_weights()

embed_matrix = self.embed.weights

可以解决问题。