如何进行张量变换并保留渐变?

时间:2017-10-22 19:46:28

标签: python tensorflow

在Tensorflow中,我有一个浮点张量T,其形状为[batch_size,3]。例如,T[0] = [4, 4, 3]

我希望将其转换为大小为5的热点,以便从嵌入字典中产生条目。在上面的例子中,这看起来像

T[0] = [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 0]].

如果我可以使用这种格式,那么我可以将它乘以嵌入字典。然而,这是在图的中间,我需要渐变流过它。是否有一种聪明的方法来使用stop_gradient la How Can I Define Only the Gradient for a Tensorflow Subgraph?来完成这项工作?我差不多了。

1 个答案:

答案 0 :(得分:1)

我能够通过以下方式解决这个问题:

expanded = tf.expand_dims(inputs, 2)
embedding_input = tf.cast(tf.one_hot(tf.to_int32(inputs), 5), inputs.dtype)
embedding_input = tf.stop_gradient(embedding_input - expanded) + expanded