从恢复的Tensorflow变量

时间:2016-06-19 18:41:02

标签: python-2.7 tensorflow checkpointing

我有一个简单的经常性网络示例,其中保存了tf.Saverweightbiasstate个变量。

当示例运行时没有选项时,它会初始化状态向量以包含零,但我想传递一个load_model选项,并使用状态向量的最后一个值作为session.run调用。

我看到的所有文档都坚持认为必须调用session.run来从变量中检索存储的值,但在这种情况下,我想检索这些值,以便初始化状态变量。我是否需要单独执行图表来检索初始化值?

以下示例代码:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3

def batch_vm2(m, x):
  [input_size, output_size] = m.get_shape().as_list()

  input_shape = tf.shape(x)
  batch_rank = input_shape.get_shape()[0].value - 1
  batch_shape = input_shape[:batch_rank]
  output_shape = tf.concat(0, [batch_shape, [output_size]])

  x = tf.reshape(x, [-1, input_size])
  y = tf.matmul(x, m)

  y = tf.reshape(y, output_shape)

  return y

def get_weight_and_biases():
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
    return weights, biases

def get_saver():
    with tf.variable_scope('h1') as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        saver = tf.train.Saver([weights, biases, state])
    return saver, scope


def load(sess, saver, checkpoint_dir = './'):

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

iteration = None

def iterate_state(prev_state_tuple, input):
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        print("input: ",input.get_shape())
        matmuladd = batch_vm2(weights, input) + biases
        matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration)
        print("prev state: ",prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = 0.99* unpacked_state
        prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
        state = state.assign( prev_state + 0.01*matmulpri )
        #output = tf.nn.relu(state)
        output = tf.nn.tanh(state)
        state = tf.Print(state, [state], message=" state -> ")
        output = tf.Print(output, [output], message=" output -> ")
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0,[state, output])
        print (" concat return: ", concat_result.get_shape())
        return concat_result

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files '
                     'to restore.')


network_scope = None

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    iteration = -1
    saver, network_scope = get_saver()
    initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
    initial_out = tf.zeros([HIDDEN_1],
                             name='initial_out')
    concat_tensor = tf.concat(0,[initial_state, initial_out])
    print(" init state: ",initial_state.get_shape())
    print(" init out: ",initial_out.get_shape())
    print(" concat: ",concat_tensor.get_shape())
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
    print ("scanout shape: ", scanout.get_shape())
    state, output = tf.split(1,2,scanout, name='split_scan_output')
    print(" end state: ",state.get_shape())
    print(" end out: ",output.get_shape())


    sess = tf.Session()
    # Run the Op to initialize the variables.

    sess.run(tf.initialize_all_variables())
    tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt')
    tf_weight, tf_bias = get_weight_and_biases()
    tf.histogram_summary('weights', tf_weight)
    tf.histogram_summary('bias', tf_bias)
    tf.histogram_summary('state', state)
    tf.histogram_summary('out', output)
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def)
    if FLAGS.load_model:
        load(sess, saver)
        # HOW DO I LOAD restored state values??????
        #st = state[BATCH_SIZE - 1,:]
        #st = sess.run([state], feed_dict={})
        print("LOADED last state vec: ", st)
    else:
        st = np.array([0.0 , 0.0])
    iter_ = data_iter()
    for i in xrange(0, 1):
        print ("iteration: ",i)
        iteration = i
        input_data = iter_.next()
        out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st })
        saver.save(sess, 'my-model', global_step=1+i)
        summary_writer.add_summary(summary_str, i)
        summary_writer.flush()
        print("input vec: ", input_data)
        print("state vec: ", st)
        st = st[-1]
        print("last state vec: ", st)
        print("output vec: ", out)
        print(" end state (runtime): ",st.shape)
        print(" end out (runtime): ",out.shape)
        print(" end scanout (runtime): ",so.shape)

请注意第124-126行注释行,以了解我尝试初始化Feed字典值的方法。他们都没有工作。

1 个答案:

答案 0 :(得分:1)

您有两个占位符:

  • inputs
  • initial_state

根据我的理解,你想要(取决于FLAGS.load_model):

  1. 使用一个充满零的初始状态

    • 这很简单,你只需要提供一个充满零的numpy数组
  2. 使用state上的最后一行,这是图中的Tensor,具体取决于两个占位符。

    • 您只想加载上一个检查点的值
  3. 完成这个分析之后,我的第一个假设是错误恰好来自于你在行中使用另一个名为state的张量的事实:

    state, output = tf.split(1,2,scanout, name='split_scan_output')
    

    因此,TensorFlow将尝试检索此state,这取决于两个占位符,而不是检索所需的变量state的值。只需重命名第二个就可以了。

    您可以尝试:

    if FLAGS.load_model:
        load(sess, saver)
        with tf.variable_scope('h1', reuse=True)
            state_saved = tf.get_variable('state')
        st = sess.run(state_saved)