更新张量流中的张量

时间:2018-03-07 09:29:28

标签: python tensorflow variable-assignment tensorflow-gradient

我在tensorflow中定义了一个无监督的问题,我需要在每次迭代时更新我的​​B和我的tfZ,但我不知道如何使用tensorflow会话更新我的tfZ

tfY = tf.placeholder(shape=(15, 15), dtype=tf.float32)

with tf.variable_scope('test'):
    B = tf.Variable(tf.zeros([]))
    tfZ = tf.convert_to_tensor(Z, dtype=tf.float32)

def loss(tfY):
    r = tf.reduce_sum(tfZ*tfZ, 1)
    r = tf.reshape(r, [-1, 1])
    D = tf.sqrt(r - 2*tf.matmul(tfZ, tf.transpose(tfZ)) + tf.transpose(r) + 1e-9)
    return tf.reduce_sum(tfY*tf.log(tf.sigmoid(D+B))+(1-tfY)*tf.log(1-tf.sigmoid(D+B)))

LOSS = loss(Y)
GRADIENT = tf.gradients(LOSS, [B, tfZ])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

tot_loss = sess.run(LOSS, feed_dict={tfY: Y})

loss_grad = sess.run(GRADIENT, feed_dict={tfY: Y})

learning_rate = 1e-4
for i in range(1000):
    sess.run(B.assign(B - learning_rate * loss_grad[0]))
    print(tfZ)
    sess.run(tfZ.assign(tfZ - learning_rate * loss_grad[1]))

    tot_loss = sess.run(LOSS, feed_dict={tfY: Y})
    if i%10==0:
        print(tot_loss)

此代码打印以下内容:

Tensor("test_18/Const:0", shape=(15, 2), dtype=float32)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-35-74ddafc0bf3a> in <module>()
     25     sess.run(B.assign(B - learning_rate * loss_grad[0]))
     26     print(tfZ)
---> 27     sess.run(tfZ.assign(tfZ - learning_rate * loss_grad[1]))
     28 
     29     tot_loss = sess.run(LOSS, feed_dict={tfY: Y})

AttributeError: 'Tensor' object has no attribute 'assign'

张量对象正确地没有赋值属性,但我找不到附加到该对象的任何其他函数。如何正确更新张量?

1 个答案:

答案 0 :(得分:6)

tf.Variable不同,tf.Tensor不提供assign方法;如果张量是 mutable ,则必须明确调用tf.assign函数:

tf.assign(tfZ, tfZ - learning_rate * loss_grad[1])

更新:并非所有张量都是可变的,例如你的tfZ不是。截至目前,可变张量只是那些与this answer中解释的变量相对应的张量(至少在tensorflow 1.x中,这可以在将来扩展)。普通张量是op结果的句柄,即它们与该操作和它的输入绑定。要更改不可变张量值,必须更改源张量(占位符或变量)。在您的特定情况下,也可以更轻松地创建tfZ变量。

顺便说一下,tf.Variable.assign()只是tf.assign的包装器,并且必须在会话中运行结果op才能实际执行赋值。

请注意,在这两种情况下,都会创建图表中的新节点。如果你在一个循环中调用它(就像在你的代码片段中那样),图形将被一千个节点膨胀。在实际生产代码中这样做是一种不好的做法,因为它很容易导致OOM。