Using a random initial state for MultiRNNCell in TensorFlow

Date: 2018-02-22 09:58:07

Tags: python tensorflow state rnn

I have a MultiRNN cell created this way:

def get_cell(cell_type, num_units, training):
    if cell_type == "RNN":
        cell = tf.contrib.rnn.BasicRNNCell(num_units)
    elif cell_type == "LSTM":
        cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units)

    if training:
        cell = tf.contrib.rnn.DropoutWrapper(cell,
                                input_keep_prob=params["dropout_input_keep_prob"],
                                output_keep_prob=params["dropout_output_keep_prob"],
                                state_keep_prob=params["dropout_state_keep_prob"])

    return cell

final_cell_structure = tf.contrib.rnn.MultiRNNCell([get_cell(cell_type, num_units, (mode == tf.estimator.ModeKeys.TRAIN)) for _ in range(num_layers)])

I'm trying to initialize its state to random values. I tried this:

initial_state = state = final_cell_structure.zero_state(batch_size, tf.float32)
if mode == tf.estimator.ModeKeys.PREDICT:
    state = state + tf.random_normal(shape=tf.shape(state), mean=0.0, stddev=0.6)

but I keep getting the error

Expected state to be a tuple of length 3, but received: Tensor("Reshape:0", shape=(3, 1, 10), dtype=float32)

when I use it in

output, state = final_cell_structure(inputs, state)

Update: I tried using

state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]

as Pop suggested. It works for basic RNN cells and GRU cells, but when I use it with LSTM cells I get the following error:

Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn

Solved: the LSTM cell state is made of tuples, so I found that this solution works:

state_placeholder = tf.random_normal(shape=(num_layers, 2, batch_size, num_units), mean=0.0, stddev=1.0)
l = tf.unstack(state_placeholder, axis=0)
state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) for idx in range(num_layers)])
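The solution above draws one random tensor of shape (num_layers, 2, batch_size, num_units) and splits it into one (c, h) pair per layer. The NumPy sketch below mirrors that bookkeeping outside TensorFlow so the resulting structure can be checked. It is an illustration, not TensorFlow code: `LSTMStateTuple` here is a stand-in namedtuple with the same (c, h) fields as `tf.nn.rnn_cell.LSTMStateTuple`, and `random_lstm_state` plus the sizes are hypothetical choices for the example.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple, which is a namedtuple with fields (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


def random_lstm_state(num_layers, batch_size, num_units, stddev=1.0, seed=0):
    """Hypothetical helper: build a MultiRNNCell-style LSTM state of per-layer (c, h) pairs."""
    rng = np.random.default_rng(seed)
    # One draw covering every layer; axis 1 indexes c (0) and h (1).
    noise = rng.normal(0.0, stddev, size=(num_layers, 2, batch_size, num_units))
    # Equivalent of tf.unstack(noise, axis=0): one (2, batch, units) slice per layer.
    layers = [noise[i] for i in range(num_layers)]
    # Wrap each slice as (c, h) so every layer gets the tuple an LSTM cell expects.
    return tuple(LSTMStateTuple(layer[0], layer[1]) for layer in layers)


state = random_lstm_state(num_layers=3, batch_size=1, num_units=10)
print(len(state), type(state[0]).__name__, state[0].c.shape)  # 3 LSTMStateTuple (1, 10)
```

The key point is the outer `tuple(...)` of namedtuples: MultiRNNCell indexes the state per layer, and each LSTM layer then unpacks its own (c, h) pair.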

1 Answer:

Answer 0 (score: 0)

The point is that `state` is a tuple.

So you need to update it this way:

state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]
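The list comprehension above only goes one level deep, which is enough for RNN/GRU layers (one tensor per layer) but not for LSTM layers, where each entry is itself a (c, h) tuple. A structure-preserving map handles both cases. The sketch below is a pure NumPy illustration; `map_structure` is a hypothetical, simplified re-implementation written for this example, and `LSTMStateTuple` is a stand-in namedtuple with the same fields as TensorFlow's.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple with fields c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


def map_structure(fn, structure):
    """Hypothetical helper: apply fn to every leaf array, keeping tuple nesting intact."""
    if isinstance(structure, tuple):
        mapped = [map_structure(fn, s) for s in structure]
        # namedtuples are rebuilt with positional args; plain tuples from the list.
        if hasattr(structure, "_fields"):
            return type(structure)(*mapped)
        return tuple(mapped)
    if isinstance(structure, list):
        return [map_structure(fn, s) for s in structure]
    return fn(structure)


rng = np.random.default_rng(0)


def add_noise(x, stddev=0.6):
    # Per-leaf noise with the leaf's own shape, like tf.random_normal(tf.shape(st)).
    return x + rng.normal(0.0, stddev, size=x.shape)


batch_size, num_units, num_layers = 1, 10, 3
# RNN/GRU-style MultiRNNCell state: one (batch, units) array per layer.
gru_state = tuple(np.zeros((batch_size, num_units)) for _ in range(num_layers))
# LSTM-style MultiRNNCell state: one (c, h) pair per layer.
lstm_state = tuple(
    LSTMStateTuple(np.zeros((batch_size, num_units)), np.zeros((batch_size, num_units)))
    for _ in range(num_layers)
)

noisy_gru = map_structure(add_noise, gru_state)
noisy_lstm = map_structure(add_noise, lstm_state)
print(type(noisy_lstm[0]).__name__, noisy_lstm[0].h.shape)  # LSTMStateTuple (1, 10)
```

TensorFlow ships a utility with the same idea (`nest.map_structure`, exposed as `tf.nest.map_structure` in later 1.x releases and as `tf.contrib.framework.nest.map_structure` in earlier ones), so in the graph something like `map_structure(lambda t: t + tf.random_normal(tf.shape(t), stddev=0.6), state)` should work for RNN, GRU and LSTM cells alike; check which path your TensorFlow version exports.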

It should work.

With your approach, you created a single tensor of shape (2, b, k) instead of a tuple of two tensors of the same size (b, k).
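The stacking effect can be seen with NumPy, used here as an analogy: when a tuple of equally shaped tensors is passed where a single tensor is expected (as in `tf.shape(st)` or `st + noise`), TensorFlow packs the tuple into one tensor along a new leading axis, much like `np.asarray` does below. The sizes are assumptions chosen to match the question's error message.

```python
import numpy as np

batch_size, num_units = 1, 10
c = np.zeros((batch_size, num_units))
h = np.zeros((batch_size, num_units))
lstm_layer_state = (c, h)  # what one LSTM layer expects: two (b, k) arrays

# Treating the tuple as one array stacks it along a new first axis.
stacked = np.asarray(lstm_layer_state)
print(stacked.shape)  # (2, 1, 10)

# The same happens one level up: three per-layer GRU states become one (3, b, k) array,
# which matches the shape (3, 1, 10) in the question's first error message.
gru_state = tuple(np.zeros((batch_size, num_units)) for _ in range(3))
print(np.asarray(gru_state).shape)  # (3, 1, 10)
```

Once the state is a single stacked tensor, MultiRNNCell can no longer index it per layer, and an LSTM cell can no longer unpack it into (c, h), which is what both error messages in the question report.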