如何使用while循环迭代器修改tf变量

时间:2018-07-18 07:19:02

标签: python tensorflow

我想切片张量并将其存储在变量中。固定编号的切片效果很好,例如:t[0:2]。但是用另一个张量切一个变量不起作用。例如t[t1:t2]也可以将切片存储在张量中很好,但是当我尝试将其存储在tf中时,变量会出错。

import tensorflow as tf
import numpy

i=tf.zeros([2,1],tf.int32)
i2=tf.get_variable('i2_variable',initializer=i) #putting a multidimensional tensor in a variable
i4=tf.ones([10,1],tf.int32)
sess=tf.Session()
sess.run(tf.global_variables_initializer()) #initializing variables
itr=tf.constant(0,tf.int32)

def w_c(i2,itr):
    return tf.less(itr,2)

def w_b(i2,itr):
    i2=i4[(itr*0):((itr*0)+2)] #doesnt work
    #i2=i4[0:2] #works
    #i=i4[(itr*0):((itr*0)+2)] #works with tensor i 
    itr=tf.add(itr,1)
    return[i2,itr]
OP=tf.while_loop(w_c,w_b,[i2,itr])
print(sess.run(OP))

我收到以下错误:

ValueError: Input tensor 'i2_variable/read:0' enters the 
loop with shape (2, 1), but has shape (?, 1) after one iteration. 
To allow the shape to vary across iterations, 
use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.

1 个答案:

答案 0 :(得分:1)

如果您指定shape_invariants,则代码不会引发错误。

   OP=tf.while_loop(w_c,w_b,[i2,itr],shape_invariants=
                                        [ tf.TensorShape([None, None]), 
                                          itr.get_shape()])

它返回这个。

  [array([[1],
   [1]]), 2]