Tensorflow中的循环切片分配

时间:2018-04-17 19:46:45

标签: python tensorflow

我有一个5x3的零矩阵,我想在while_loop中用一个更新。我想使用循环变量作为scatter_nd_update函数的indices参数。我的代码是这样的:

# Zeros matrix
num = tf.get_variable('num', shape=[5, 3], initializer=tf.zeros_initializer(), dtype=tf.float32)
# Looping variable
i = tf.constant(0, dtype=tf.int32)
# Conditional
c = lambda i, num: tf.less(i, 2)
def body(i, num):
    # Update values
    updates = tf.ones([1, 3], dtype=tf.float32)
    num = tf.scatter_nd_update(num, [[i]], updates)
    return tf.add(i, 1), num
i, num = tf.while_loop(c, body, [i, num])
# Session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    num_out = sess.run(num)
    print(num_out.shape)
    print(num_out)

这会引发错误:AttributeError: 'Tensor' object has no attribute 'handle'并指向num = tf.scatter_nd_update(num, [[i]], updates)

当我通过使用不同的i值运行num = tf.scatter_nd_update(num, [[i]], updates)两次运行此代码而没有循环时,它会工作,我得到一个包含2行1的矩阵,但是当我在while_loop中尝试相同的事情时会发生此错误

1 个答案:

答案 0 :(得分:1)

问题围绕以下事实:tf.scatter_nd_update()需要 变量 进行更改,而tf.while_loop()使用 张量< / em> 作为循环变量。从根本上说,tf.while_loop()会在设置图表时运行循环,而tf.scatter_nd_update()是在网络运行时运行的操作

换句话说,您创建的网络将有三个num张量:一个带有原始零,然后在第一行被替换后跟随另一个,然后跟随另一个,前两行更换。为了实现这一点,您可以使用此代码(已测试),下面有更多解释:

import tensorflow as tf
num = tf.zeros( shape = ( 5, 3 ), dtype = tf.float32 )
# Looping variable
i = tf.zeros( shape=(), dtype=tf.int32)
# Conditional
c = lambda i, num: tf.less(i, 2)
def body(i, num):
    # Update values
    updates = tf.ones([1, 3], dtype=tf.float32)
    num_shape = num.get_shape()
    num = tf.concat( [ num[ : i ], updates, num[ i + 1 : ] ], axis = 0 )
    num.set_shape( num_shape )
    return tf.add(i, tf.ones( shape=(), dtype = tf.int32 ) ), num
i, num = tf.while_loop( c, body, [ i, num ] )
# Session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    num_out = sess.run( [ num ] )
    print(num_out)

输出:

  

[array([[1。,1.,1。],
         [1.,1.,1。],
         [0.,0.,0。],
         [0.,0.,0。],
         [0.,0.,0。]],dtype = float32)]

首先,请注意我已将num更改为变量的张量。这将允许它在tf.while_loop()中用作循环变量。其次,分散操作在张量上没有很好的方法,所以我基本上将num分开(在i之前和之后 - i,并在其间插入update)。我们还必须设置num的形状,否则tf.while_loop()会抱怨形状不确定(因为tf.concat();有一种方法可以使用{tf.while_loop()在{{3}}中的1}}参数但对我们的情况来说这更容易。)