Tensorflow:如何通过tf.gather传播渐变?

时间:2016-07-24 04:51:38

标签: tensorflow

我有一些问题试图传播我的损失函数的梯度相对于表示聚集索引的变量,类似于空间变换器网络(https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py)中所做的。我觉得我可能会遗漏一些非常简单的事情。这是我想要做的简化玩具示例:

import tensorflow as tf
import numpy as np
lf = np.array([1.0,2.0,3.0])
lf_b = 2.0
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=(3))
pt = tf.Variable(0, name='point')
y_ = tf.placeholder(tf.float32, shape=())
sess.run(tf.initialize_all_variables())
y = tf.gather(x, pt)
data_loss = tf.reduce_mean(tf.squared_difference(y,y_))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(data_loss)

目前,这会返回错误: ValueError:没有为任何变量提供渐变

1 个答案:

答案 0 :(得分:0)

问题在于你无法区分,因为pt是一个整数。它在x占位符中选择一个索引,因此它没有派生。通常,当您执行此操作时,您将输入一个整数并使用它来选择浮点值。你反过来也是这样做的。