import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
if __name__ == '__main__':
    # Reproduction script (legacy TensorFlow 0.x API): one BasicLSTMCell is
    # passed to three RNN builders, each under a different variable scope,
    # and the last output of each is printed for comparison.
    np.random.seed(1234)

    # The integers 1..120 arranged as a (4, 6, 5) float array; fed to x0.
    X = np.array(np.array(range(1, 121)).reshape(4, 6, 5), dtype=float)

    x0 = tf.placeholder(tf.float32, [4, 6, 5])
    # Flatten to (24, 5), then split along axis 0 into a list of 4 tensors
    # (legacy argument order: split_dim, num_split, value).
    x = tf.reshape(x0, [-1, 5])
    x = tf.split(0, 4, x)

    lstm = tf.nn.rnn_cell.BasicLSTMCell(5, state_is_tuple=True)

    # NOTE: each scope below ('sen', 'par', 'sen2') gets its own set of LSTM
    # variables, so the three RNNs do not share weights even though the same
    # cell object is passed to all of them.
    with tf.variable_scope('sen'):
        outputs, states = tf.nn.rnn(lstm, x, dtype=tf.float32)
    with tf.variable_scope('par'):
        # time_major=True: the leading axis of x0 is treated as time.
        output2, states2 = tf.nn.dynamic_rnn(
            lstm, x0, dtype=tf.float32, time_major=True)
    with tf.variable_scope('sen2'):
        outputs3, states3 = tf.nn.rnn(lstm, x, dtype=tf.float32)

    with tf.Session() as sess:
        for i in range(3):
            # Variables are re-initialised on every iteration.
            sess.run(tf.initialize_all_variables())
            result1, result2, result3 = sess.run(
                [outputs[-1], output2[-1], outputs3[-1]], {x0: X})
            print(result1)
            print('---------------------------------------')
            print(result2)
            print('---------------------------------------')
            print(result3)
            print('------------------------------------------------------------------------------')
我认为 result1、result2 和 result3 应该始终相同,但它们并不相同,而且每次运行时 result2 都会发生变化。问题出在哪里?
答案 0 :(得分:2)
问题在于:尽管您使用的是同一个 LSTM 单元,但您是在三个不同的变量作用域(variable scope)中创建这 3 个 RNN 的,因此它们不会共享参数。可以打印所有可训练变量来确认这一点:
# List every trainable variable by name; variables created under different
# variable scopes show up with different name prefixes.
for var in tf.trainable_variables():
    print(var.name)
要显式地共享同一组参数,请尝试调用 scope.reuse_variables(),其中 scope 是各个输出所共用的同一个变量作用域。
我想出了以下内容:
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
if __name__ == '__main__':
    # Variant of the script that builds all three RNNs inside one variable
    # scope and calls scope.reuse_variables() so they share parameters
    # (legacy TensorFlow 0.x API).
    np.random.seed(1234)

    # The integers 1..120 arranged as a (4, 6, 5) float array; fed to x0.
    X = np.array(np.array(range(1, 121)).reshape(4, 6, 5), dtype=float)

    x0 = tf.placeholder(tf.float32, [4, 6, 5])
    # Flatten to (24, 5), then split along axis 0 into a list of 4 tensors
    # (legacy argument order: split_dim, num_split, value).
    x = tf.reshape(x0, [-1, 5])
    x = tf.split(0, 4, x)

    with tf.variable_scope('lstm') as scope:
        lstm = tf.nn.rnn_cell.BasicLSTMCell(5, state_is_tuple=True)
        # The first builder creates the LSTM variables in this scope.
        outputs, states = tf.nn.rnn(lstm, x, dtype=tf.float32)
        # From here on, builders in this scope reuse the existing variables
        # instead of creating new ones.
        scope.reuse_variables()
        outputs2, states2 = tf.nn.dynamic_rnn(
            lstm, x0, dtype=tf.float32, time_major=True)
        outputs3, states3 = tf.nn.rnn(lstm, x, dtype=tf.float32)

    print(outputs3)

    with tf.Session() as sess:
        # Initialise once, before the loop, so every pass sees the same
        # weights.
        init = tf.initialize_all_variables()
        sess.run(init)

        for var in tf.trainable_variables():
            print(var.name)

        for i in range(3):
            result1, result2, result3 = sess.run(
                [outputs, outputs2, outputs3], feed_dict={x0: X})
            print(result1)
            print('---------------------------------------')
            print(result2)
            print('---------------------------------------')
            print(result3)
            print('---------------------------------------')
似乎工作正常。