我先在急切的执行模式下然后在图形模式下编写了相同的代码。现在,我还不能弄清楚为什么在急切模式下可以正常运行时,为什么GRU状态没有保留在图形模式下。
以下是紧急模式代码:
import tensorflow as tf
import xxhash
import numpy as np
tf.enable_eager_execution()
rnn_units = 1024
def hash_code(arr):
return xxhash.xxh64(arr).hexdigest()
model = tf.keras.Sequential([tf.keras.layers.GRU(rnn_units,
return_sequences=True,
stateful=True,
recurrent_initializer='glorot_uniform', batch_input_shape=[1, None, 256])])
lstm_wt = np.load('lstm_wt.npy', allow_pickle=True) # fixed weights for comparison
lstm_re_wt = np.load('lstm_re_wt.npy', allow_pickle=True)
lstm_bias = np.load('lstm_bias.npy', allow_pickle=True)
model.layers[0].set_weights([lstm_wt, lstm_re_wt, lstm_bias])
op_embed = np.load('op_embed.npy', allow_pickle=True) # fixed input
op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))
op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))
model.layers[0].reset_states() # now reset the state, you'll get back the initial output.
op_lstm = model(op_embed)
print(hash_code(op_lstm.numpy()))
此代码的输出:
d092fdb4739588a3
cdfdf8b8e292c6e8
d092fdb4739588a3
现在,图形模型代码:
import tensorflow as tf
import xxhas
import numpy as np
# checking lstm
op_embed = np.load('op_embed.npy', allow_pickle=True)
# load op_embed, lstm weights
lstm_wt = np.load('lstm_wt.npy', allow_pickle=True)
lstm_re_wt = np.load('lstm_re_wt.npy', allow_pickle=True)
lstm_bias = np.load('lstm_bias.npy', allow_pickle=True)
rnn_units = 1024
layers = tf.keras.layers.GRU(rnn_units,
return_sequences=True,
stateful=True,
recurrent_initializer='glorot_uniform')
x_placeholder = tf.placeholder(shape=op_embed.shape, dtype=tf.float32)
op_lstm = layers(x_placeholder)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
layers.set_weights([lstm_wt, lstm_re_wt, lstm_bias])
tf.assign(layers.weights[0],lstm_wt ).eval(sess)
tf.assign(layers.weights[1], lstm_re_wt).eval(sess)
tf.assign(layers.weights[2], lstm_bias).eval(sess)
print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
输出:
keras op hash d092fdb4739588a3
keras op hash d092fdb4739588a3
在图形模式下如何解决此歧义并保留状态的任何见解? 之前有人问过类似的问题,但尚未回答。 Statefulness in eager mode vs non-eager mode
答案 0 :(得分:0)
即使是在问题中提供的Link中,也要在此处(答案部分)指定解决方案,以社区的利益。
Recurrent Neural Network
(RNN
或GRU
或LSTM
)在state
/ Non-Eager-Mode
中执行时丢失了Graph-Mode
默认。
如果要保留state
,则需要在Initial State
调用期间传递RNN
,如下所示:
current_state = np.zeros((1,1))
state_placeholder = tf.placeholder(tf.float32, shape=[1, 1])
output, state = rnn(x, initial_state=state_placeholder)
然后,在执行输出时,除了State
的{{1}}之外,我们还需要传递Input
。
代码,
feed_dict
可以替换为
print('keras op hash',xxhash.xxh64(sess.run(op_lstm, feed_dict={x_placeholder:op_embed})).hexdigest())
希望这会有所帮助。学习愉快!