我现在有多个RNN图层设置如下:
stack = tf.nn.rnn_cell.MultiRNNCell([
tf.nn.rnn_cell.GRUCell(num_hidden, activation=clipped_relu)
for _ in range(num_rnn_layers)
])
但我正在尝试使用https://www.tensorflow.org/api_docs/python/tf/contrib/layers/layer_norm向RNN图层添加图层规范化。我尝试了很多不同的设置,但无法让模型编译。
有没有人这样做过?如果是这样,你是如何实现它的?
答案 0 :(得分:0)
我认为您需要定义自己的图层类,以便在调用函数内进行规范化。你试过吗?
答案 1 :(得分:0)
这里有一个图层规范化实现:
tf.contrib.rnn.LayerNormBasicLSTMCell
可以在MultiRNNCell
函数中使用。