我的网络中有很多需要嵌入的物品。
但是,在每个培训批次中,实际上只会使用很小一部分的项目。
如果我使用常规的tf.keras.layers.Embedding
层,它将把所有项目添加到网络参数中,从而在分布式训练中会消耗大量内存并显着降低速度,因为在每个步骤中,所有未使用的grads仍会同步
我想要的是,在每个训练步骤中,仅将实际使用的嵌入权重添加到图形中,并进行计算和同步。
Pytorch
已通过torch.nn.Embedding(sparse=True)
提供了此功能。
如何在Tensorflow 2中实现这一点?
答案 0 :(得分:1)
我的问题...检查tf.GradientTape()告诉我tf.gather的梯度已经是稀疏张量,因此无需理会。 p>