如何在Tensorflow 2中实现稀疏嵌入,例如Pytorch嵌入(sparse = True)?

时间:2020-07-24 08:31:37

标签: tensorflow pytorch

我的网络中有很多需要嵌入的物品。

但是,在每个培训批次中,实际上只会使用很小一部分的项目。

如果我使用常规的tf.keras.layers.Embedding层,它将把所有项目添加到网络参数中,从而在分布式训练中会消耗大量内存并显着降低速度,因为在每个步骤中,所有未使用的grads仍会同步

我想要的是,在每个训练步骤中,仅将实际使用的嵌入权重添加到图形中,并进行计算和同步。

Pytorch已通过torch.nn.Embedding(sparse=True)提供了此功能。

如何在Tensorflow 2中实现这一点?

1 个答案:

答案 0 :(得分:1)

我的问题...检查tf.GradientTape()告诉我tf.gather的梯度已经是稀疏张量,因此无需理会。