字符级别的BasicRnnCell中的权重如何更新

时间:2018-11-18 00:39:44

标签: tensorflow rnn

当我使用Tensorflow构建字符级RNN网络时,我对模型权重的变化感到困惑。我以为重量不会在一批中更新。

但是内核(我认为是Wk)正在发生变化。并且有6个内核。所以我很困惑为什么更改了它,为什么有6个内核。我可以使用Tensorflow直接获得W和U吗?这是我的代码。谢谢。

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()
h = [1, 0, 0, 0]
e = [0, 1, 0, 0]
l = [0, 0, 1, 0]
o = [0, 0, 0, 1]


with tf.variable_scope('two_sequances') as scope:
    # One cell RNN input_dim (4) -> output_dim (2). sequence: 5
    hidden_size = 2
    cell = tf.contrib.rnn.BasicRNNCell(num_units=hidden_size)
    x_data = np.array([[h, e, l, l, o]], dtype=np.float32)
    print(x_data.shape)
    print(x_data)
    outputs, _states = tf.nn.dynamic_rnn(cell, x_data, dtype=tf.float32)
    sess.run(tf.global_variables_initializer())
    results, state = sess.run([outputs, _states])

    variable_names = [v.name for v in tf.global_variables()]
    values = sess.run(variable_names)
    for k, v in zip(variable_names, values):
        print(k, v)

变量的输出如下。

two_sequances/rnn/basic_rnn_cell/kernel:0 
[[ 0.6147509   0.6268855 ]
 [ 0.34818882  0.8140872 ]
 [ 0.4074654   0.011693  ]
 [-0.5032909  -0.69920516]
 [ 0.62231725  0.18967694]
 [ 0.6888749   0.77280706]]
two_sequances/rnn/basic_rnn_cell/bias:0 
[0. 0.]

0 个答案:

没有答案