I have a MultiRNN cell created like this:
def get_cell(cell_type, num_units, training):
    if cell_type == "RNN":
        cell = tf.contrib.rnn.BasicRNNCell(num_units)
    elif cell_type == "LSTM":
        cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units)
    if training:
        cell = tf.contrib.rnn.DropoutWrapper(
            cell,
            input_keep_prob=params["dropout_input_keep_prob"],
            output_keep_prob=params["dropout_output_keep_prob"],
            state_keep_prob=params["dropout_state_keep_prob"])
    return cell

final_cell_structure = tf.contrib.rnn.MultiRNNCell(
    [get_cell(cell_type, num_units, (mode == tf.estimator.ModeKeys.TRAIN))
     for _ in range(num_layers)])
I am trying to initialize its state to random values. I tried this:
initial_state = state = final_cell_structure.zero_state(batch_size, tf.float32)
if mode == tf.estimator.ModeKeys.PREDICT:
    state = state + tf.random_normal(shape=tf.shape(state), mean=0.0, stddev=0.6)
But I keep getting this error message:
Expected state to be a tuple of length 3, but received: Tensor("Reshape:0", shape=(3, 1, 10), dtype=float32)
when I use the state here:
output, state = final_cell_structure(inputs, state)
Update: I tried using
state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]
as Pop suggested. It works for basic RNN cells and GRU cells, but when I use it with LSTM cells I get the following error:
Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn
Solved: the state of each LSTM cell is itself a tuple, so I found that this solution works:
state_placeholder = tf.random_normal(shape=(num_layers, 2, batch_size, num_units), mean=0.0, stddev=1.0)
l = tf.unstack(state_placeholder, axis=0)
state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) for idx in range(num_layers)])
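To make the shape bookkeeping in that snippet easier to follow, here is a framework-free sketch using nested Python lists instead of tensors. The names random_block and to_lstm_state are made up for illustration, and the local LSTMStateTuple is only a stand-in for tf.nn.rnn_cell.LSTMStateTuple, assuming that class is a namedtuple with fields c and h.

```python
import random
from collections import namedtuple

# Stand-in (assumption) for tf.nn.rnn_cell.LSTMStateTuple: a namedtuple (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def random_block(num_layers, batch_size, num_units, stddev=1.0, seed=0):
    """Nested lists shaped (num_layers, 2, batch_size, num_units) of Gaussian samples."""
    rng = random.Random(seed)
    return [[[[rng.gauss(0.0, stddev) for _ in range(num_units)]
              for _ in range(batch_size)]
             for _ in range(2)]
            for _ in range(num_layers)]


def to_lstm_state(block):
    """Split along axis 0, then wrap each layer's two slices as a (c, h) pair."""
    return tuple(LSTMStateTuple(c=layer[0], h=layer[1]) for layer in block)


state = to_lstm_state(random_block(num_layers=3, batch_size=1, num_units=10))
print(len(state), len(state[0].c), len(state[0].c[0]))  # prints: 3 1 10
```

The result is a tuple with one entry per layer, and each entry holds two separate (batch_size, num_units) matrices, which is the layout the MultiRNNCell of LSTM cells expects. Note that this draws a fresh random state rather than perturbing the zero state, and it only fits LSTM cells; RNN and GRU layers carry a single matrix per layer.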
Answer 0 (score: 0)
The key point is that state is a tuple, so you need to update it element by element:
state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]
That should work.
With your approach, you created a single tensor of shape (2, b, k) instead of a tuple of two tensors, each of shape (b, k).
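One way to avoid collapsing the tuple is to add the noise leaf by leaf and rebuild the same tuple layout afterwards. The sketch below shows the idea without TensorFlow, using lists of lists as stand-ins for (b, k) tensors. map_state, add_noise and zeros are hypothetical helpers, and the local LSTMStateTuple only mimics tf.nn.rnn_cell.LSTMStateTuple under the assumption that it is a namedtuple of c and h.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple of c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

rng = random.Random(0)  # seeded only so the sketch is repeatable


def zeros(batch_size, num_units):
    """A (batch_size, num_units) matrix of zeros, standing in for one state tensor."""
    return [[0.0] * num_units for _ in range(batch_size)]


def add_noise(matrix, stddev=0.6):
    """Add Gaussian noise to every entry of one matrix."""
    return [[v + rng.gauss(0.0, stddev) for v in row] for row in matrix]


def map_state(fn, state):
    """Apply fn to each leaf matrix while keeping the tuple layout intact."""
    if isinstance(state, tuple):
        mapped = [map_state(fn, s) for s in state]
        if hasattr(state, "_fields"):  # namedtuple such as LSTMStateTuple
            return type(state)(*mapped)
        return tuple(mapped)
    return fn(state)  # a leaf: one (batch_size, num_units) matrix


num_layers, batch_size, num_units = 3, 1, 10

# RNN/GRU layout: one matrix per layer.
gru_state = tuple(zeros(batch_size, num_units) for _ in range(num_layers))
# LSTM layout: one (c, h) pair per layer.
lstm_state = tuple(
    LSTMStateTuple(zeros(batch_size, num_units), zeros(batch_size, num_units))
    for _ in range(num_layers)
)

noisy_gru = map_state(add_noise, gru_state)
noisy_lstm = map_state(add_noise, lstm_state)
```

The same helper handles all three cell types because it only descends into tuples and never stacks them. With real tensors the leaf function would be something like st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6), and tf.nest.map_structure (tf.contrib.framework.nest.map_structure in older 1.x releases) should play the role of map_state, though that is worth checking against the TensorFlow version in use.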