使用TensorArray Tensorflow的渐变错误

时间:2017-02-23 12:38:24

标签: python while-loop tensorflow lstm

我正在尝试在tensorflow中实现multidimentional lstm,我使用TensorArray来记住以前的状态,我使用一种复杂的方式来获得两个neigbours状态(上面和左边)。 tf.cond希望存在既有条件并且具有相同数量的输入。这就是为什么我再添加一个cell.zero_state到状态的(最后一个索引+1)。然后我使用函数来获取状态的正确索引。当我试图使用优化器以最小化成本时,我得到了这个错误:

  

InvalidArgumentError(参见上面的回溯):TensorArray   MultiDimentionalLSTMCell-l1-multi-l1 / state_ta_262 @ gradient:无法   从TensorArray索引809读取,因为它尚未编写   到。

有人能说出如何修复它吗?

Ps:没有优化器就可以了!

class MultiDimentionalLSTMCell(tf.nn.rnn_cell.RNNCell):
    """
    Note that state_is_tuple is always True.
    """

    def __init__(self, num_units, forget_bias=1.0, activation=tf.nn.tanh):
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._activation = activation

    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM).
        @param: imputs (batch,n)
        @param state: the states and hidden unit of the two cells
        """
        with tf.variable_scope(scope or type(self).__name__):
            c1,c2,h1,h2 = state

            # change bias argument to False since LN will add bias via shift
            concat = tf.nn.rnn_cell._linear([inputs, h1, h2], 5 * self._num_units, False)

            i, j, f1, f2, o = tf.split(1, 5, concat)

            new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) + 
                     c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
                   self._activation(j))

            new_h = self._activation(new_c) * tf.nn.sigmoid(o)
            new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
            return new_h, new_state


def multiDimentionalRNN_whileLoop(rnn_size,input_data,sh,dims=None,scopeN="layer1"):
        """Implements naive multidimentional recurent neural networks

        @param rnn_size: the hidden units
        @param input_data: the data to process of shape [batch,h,w,chanels]
        @param sh: [heigth,width] of the windows 
        @param dims: dimentions to reverse the input data,eg.
            dims=[False,True,True,False] => true means reverse dimention
        @param scopeN : the scope

        returns [batch,h/sh[0],w/sh[1],chanels*sh[0]*sh[1]] the output of the lstm
        """
        with tf.variable_scope("MultiDimentionalLSTMCell-"+scopeN):
            cell = MultiDimentionalLSTMCell(rnn_size)

            shape = input_data.get_shape().as_list()

            if shape[1]%sh[0] != 0:
                offset = tf.zeros([shape[0], sh[0]-(shape[1]%sh[0]), shape[2], shape[3]])
                input_data = tf.concat(1,[input_data,offset])
                shape = input_data.get_shape().as_list()
            if shape[2]%sh[1] != 0:
                offset = tf.zeros([shape[0], shape[1], sh[1]-(shape[2]%sh[1]), shape[3]])
                input_data = tf.concat(2,[input_data,offset])
                shape = input_data.get_shape().as_list()

            h,w = int(shape[1]/sh[0]),int(shape[2]/sh[1])
            features = sh[1]*sh[0]*shape[3]
            batch_size = shape[0]

            x =  tf.reshape(input_data, [batch_size,h,w, features])
            if dims is not None:
                x = tf.reverse(x, dims)
            x = tf.transpose(x, [1,2,0,3])
            x =  tf.reshape(x, [-1, features])
            x = tf.split(0, h*w, x)     

            sequence_length = tf.ones(shape=(batch_size,), dtype=tf.int32)*shape[0]
            inputs_ta = tf.TensorArray(dtype=tf.float32, size=h*w,name='input_ta')
            inputs_ta = inputs_ta.unpack(x)
            states_ta = tf.TensorArray(dtype=tf.float32, size=h*w+1,name='state_ta',clear_after_read=False)
            outputs_ta = tf.TensorArray(dtype=tf.float32, size=h*w,name='output_ta')

            states_ta = states_ta.write(h*w,  tf.nn.rnn_cell.LSTMStateTuple(tf.zeros([batch_size,rnn_size], tf.float32),
                                                         tf.zeros([batch_size,rnn_size], tf.float32)))
            def getindex1(t,w):
                return tf.cond(tf.less_equal(tf.constant(w),t),
                               lambda:t-tf.constant(w),
                               lambda:tf.constant(h*w))
            def getindex2(t,w):
                return tf.cond(tf.less(tf.constant(0),tf.mod(t,tf.constant(w))),
                               lambda:t-tf.constant(1),
                               lambda:tf.constant(h*w))

            time = tf.constant(0)

            def body(time, outputs_ta, states_ta):
                constant_val = tf.constant(0)
                stateUp = tf.cond(tf.less_equal(tf.constant(w),time),
                                  lambda: states_ta.read(getindex1(time,w)),
                                  lambda: states_ta.read(h*w))
                stateLast = tf.cond(tf.less(constant_val,tf.mod(time,tf.constant(w))),
                                    lambda: states_ta.read(getindex2(time,w)),
                                    lambda: states_ta.read(h*w)) 

                currentState = stateUp[0],stateLast[0],stateUp[1],stateLast[1]
                out , state = cell(inputs_ta.read(time),currentState)  
                outputs_ta = outputs_ta.write(time,out)
                states_ta = states_ta.write(time,state)
                return time + 1, outputs_ta, states_ta

            def condition(time,outputs_ta,states_ta):
                return tf.less(time ,  tf.constant(h*w)) 

            result , outputs_ta, states_ta = tf.while_loop(condition, body, [time,outputs_ta,states_ta])


            outputs = outputs_ta.pack()
            states  = states_ta.pack()

            y =  tf.reshape(outputs, [h,w,batch_size,rnn_size])
            y = tf.transpose(y, [2,0,1,3])
            if dims is not None:
                y = tf.reverse(y, dims)

            return y


def tanAndSum(rnn_size,input_data,scope):
        outs = []
        for i in range(2):
            for j in range(2):
                dims = [False]*4
                if i!=0:
                    dims[1] = True
                if j!=0:
                    dims[2] = True                 
                outputs  = multiDimentionalRNN_whileLoop(rnn_size,input_data,[2,2],
                                                       dims,scope+"-multi-l{0}".format(i*2+j))
                outs.append(outputs)
        outs = tf.pack(outs, axis=0)
        mean = tf.reduce_mean(outs, 0)
        return tf.nn.tanh(mean)

graph = tf.Graph()
with graph.as_default():

    input_data =  tf.placeholder(tf.float32, [20,36,90,1])
    #input_data = tf.ones([20,36,90,1],dtype=tf.float32)
    sh = [2,2]
    out1 = tanAndSum(20,input_data,'l1')
    out = tanAndSum(25,out1,'l2')
    cost = tf.reduce_mean(out)
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
    #out = multiDimentionalRNN_raw_rnn(2,input_data,sh,dims=[False,True,True,False],scopeN="layer1")
    #cell = MultiDimentionalLSTMCell(10)
    #out = cell.zero_state(2, tf.float32).c
with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    ou,k,_ = session.run([out,cost,optimizer],{input_data:np.ones([20,36,90,1],dtype=np.float32)})
    print(ou.shape)
    print(k)

1 个答案:

答案 0 :(得分:3)

您应该在while循环调用中添加参数parallel_iterations=1。 如:

result, outputs_ta, states_ta = tf.while_loop(
    condition, body, [time,outputs_ta,states_ta], parallel_iterations=1)

这是必需的,因为在体内,您在同一张量阵列(states_ta)上执行读写操作。在并行循环执行的情况下(parallel_iterations> 1),某些线程可能会尝试从tensorArray读取信息,而不是另一个线程写入的信息。

我已经在tensorflow 0.12.1上使用parallel_iterations = 1来测试您的代码段,并且它按预期工作。