pytorch在哪里实现嵌入“ max_norm”?

时间:2018-09-03 05:27:35

标签: pytorch

“嵌入”类文档https://pytorch.org/docs/stable/nn.html

max_norm (float, optional) – If given, will renormalize the embedding vectors to have a norm lesser than this before extracting.

1)在我的模型中,我将此嵌入类用作参数,而不仅仅是输入(模型学习嵌入)。在这种情况下,我假设每次发生更新时,嵌入都会重新规范化,不仅初始化时。我的理解正确吗?

2)我想通过查看源代码来确认1),但是在pytorch嵌入类中找不到实现。 https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html 有人可以指出我的max_norm实现吗?

1 个答案:

答案 0 :(得分:1)

如果您在嵌入类here中看到forward函数,则对torch.nn.functional.embedding的引用使用了cpp文档here中的embedding_renorm_,这意味着它是一个cpp实施。在pytorch回购上的一些github搜索指向此文件(12)。

回答1是。以上是2的答案。