Updating the initial state of an LSTM

Time: 2019-07-30 07:58:52

Tags: python tensorflow

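I am trying to implement a stateful LSTM by saving the final state of each batch and using it as the initial state of the next batch.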
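
To make the idea of carrying state across batches concrete, here is a minimal plain-Python sketch. It uses a toy scalar recurrent cell rather than an LSTM, and the names `run_batch`, `w`, and `u` are hypothetical, chosen only for illustration. It assumes that "batch" here means a consecutive chunk of one longer sequence.

```python
import math


def run_batch(inputs, initial_state, w=0.5, u=0.9):
    """Toy scalar recurrent cell (not an LSTM): h_t = tanh(w * x_t + u * h_{t-1})."""
    h = initial_state
    for x in inputs:
        h = math.tanh(w * x + u * h)
    return h


sequence = [0.1, 0.4, -0.2, 0.3, 0.7, -0.5]
batches = [sequence[:3], sequence[3:]]  # two consecutive chunks of one sequence

# Stateful: the final state of one batch becomes the initial state of the next.
state = 0.0
for batch in batches:
    state = run_batch(batch, state)
stateful_final = state

# Reference: processing the whole sequence in a single pass.
full_final = run_batch(sequence, 0.0)

# Stateless: the last batch starts again from a zero state.
stateless_final = run_batch(batches[-1], 0.0)
```

In this sketch the stateful run reproduces the single-pass result, while the stateless run does not, which is the behaviour the state variables in the code below are meant to provide.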

To do this, I am using the code provided at https://stackoverflow.com/a/41240243/860160.

import tensorflow as tf  # TensorFlow 1.x; tf.contrib is not available in 2.x


def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

However, I get the following error, which points to tf.Variable(state_c, trainable=False):

ValueError: initial_value must have a shape specified: Tensor("encoder/MultiRNNCellZeroState/LSTMCellZeroState/zeros:0", shape=(?, 500), dtype=float32)

I don't know what I am doing wrong.
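
The `?` in `shape=(?, 500)` shows that the batch dimension of the zero state is unknown when the graph is built. This typically happens when `batch_size` is a tensor or placeholder rather than a Python integer. By default, `tf.Variable` requires its initial value to have a fully defined shape. The plain-Python sketch below mimics that check. It is not TensorFlow code, and `is_fully_defined` and `make_state_variable` are hypothetical helpers, with `None` standing in for the unknown dimension.

```python
def is_fully_defined(shape):
    """Return True when every dimension is a concrete integer (None marks an unknown one)."""
    return all(dim is not None for dim in shape)


def make_state_variable(shape):
    """Hypothetical stand-in for the shape validation done when creating a variable."""
    if not is_fully_defined(shape):
        raise ValueError("initial_value must have a shape specified: %s" % (shape,))
    # Build a zero-filled 2-D state of the requested static shape.
    return [[0.0] * shape[1] for _ in range(shape[0])]


# Unknown batch dimension, analogous to shape=(?, 500) in the error message.
try:
    make_state_variable((None, 500))
    dynamic_ok = True
except ValueError:
    dynamic_ok = False

# A fixed batch size (32 is an arbitrary example value) gives a static shape.
static_state = make_state_variable((32, 500))
```

Under this reading, calling `get_state_variables` with a fixed integer `batch_size` would give the zero state a static shape. The question does not say whether a fixed batch size fits the rest of the pipeline.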

0 answers:

No answers yet