tensorflow保存库内的共享变量

时间:2016-05-03 17:08:49

标签: python tensorflow lstm

我组建了一个带有几个隐藏状态的小型LSTM单元。从Tensorflow howtos我能够保存和恢复使用tf.Variable声明的变量的状态。但是,当我调查rnn_cell.py时,我发现存在一个函数:

def linear(args, output_size, bias, bias_start=0.0, scope=None):

并且里面有一个共享变量访问

matrix = vs.get_variable("Matrix", [total_arg_size, output_size])

据我所知,这个矩阵存储权重W_i,W_o,W_f和W_o,因为在线性函数之后,出现:

new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j)
new_h = tanh(new_c) * sigmoid(o)

所以,我也愿意保存和恢复这个变量。我的问题是这可能发生的地方?

2 个答案:

答案 0 :(得分:0)

对于记录,可以通过深入到变量范围来获得矩阵。 get_variable也需要暗淡的信息:[2 * hidden_size, 4 * hidden_size]

        with tf.variable_scope("RNN", reuse=True):
          with tf.variable_scope("BasicLSTMCell", reuse=True):
            with tf.variable_scope("Linear", reuse=True):
                v1 = tf.get_variable("Matrix", [2 * hidden_size, 4 * hidden_size])
                print(v1.eval())

答案 1 :(得分:0)

您可以通过评估它来访问您的张量。例如,要获取您应该评估的matrix的值,并按以下方式进行评估: ar = sess.run(matrix) for row in ar: for col in row: # your method to save your data 并且您可以构建一个类,其中您的变量用作占位符,您只需使用之前保存的已加载模型来提供它们!