带有tf.nn.rnn_cell.MultiRNNCell的tf.contrib.layers.layer_norm

时间:2017-06-17 22:55:31

标签: tensorflow

我现在有多个RNN图层设置如下:

    stack = tf.nn.rnn_cell.MultiRNNCell([
        tf.nn.rnn_cell.GRUCell(num_hidden, activation=clipped_relu)
        for _ in range(num_rnn_layers)
    ])

但我正在尝试使用https://www.tensorflow.org/api_docs/python/tf/contrib/layers/layer_norm向RNN图层添加图层规范化。我尝试了很多不同的设置,但无法让模型编译。

有没有人这样做过?如果是这样,你是如何实现它的?

2 个答案:

答案 0 :(得分:0)

我认为您需要定义自己的图层类,以便在调用函数内进行规范化。你试过吗?

答案 1 :(得分:0)

这里有一个图层规范化实现:

tf.contrib.rnn.LayerNormBasicLSTMCell

可以在MultiRNNCell函数中使用。