Tensorflow - Am I restoring the model correctly?

Time: 2017-03-27 22:58:29

Tags: python tensorflow neural-network conv-neural-network

I have the following code, which runs without errors. My question is whether I am restoring the model correctly. In particular, I don't see any output from the statement print(v_).

So, I want to know whether I am doing the following correctly:

  1. Restoring the model
  2. Using the restored model

    import tensorflow as tf

    data, labels = cifar_tools.read_data('C:\\Users\\abc\\Desktop\\Testing')
    
    x = tf.placeholder(tf.float32, [None, 150 * 150])
    y = tf.placeholder(tf.float32, [None, 2])
    
    w1 = tf.Variable(tf.random_normal([5, 5, 1, 64]))
    b1 = tf.Variable(tf.random_normal([64]))
    
    w2 = tf.Variable(tf.random_normal([5, 5, 64, 64]))
    b2 = tf.Variable(tf.random_normal([64]))
    
    w3 = tf.Variable(tf.random_normal([38*38*64, 1024]))
    b3 = tf.Variable(tf.random_normal([1024]))
    
    w_out = tf.Variable(tf.random_normal([1024, 2]))
    b_out = tf.Variable(tf.random_normal([2]))
    
    def conv_layer(x,w,b):
        conv = tf.nn.conv2d(x,w,strides=[1,1,1,1], padding = 'SAME')
        conv_with_b = tf.nn.bias_add(conv,b)
        conv_out = tf.nn.relu(conv_with_b)
        return conv_out
    
    def maxpool_layer(conv,k=2):
        return tf.nn.max_pool(conv, ksize=[1,k,k,1], strides=[1,k,k,1], padding='SAME')
    
    def model():
        x_reshaped = tf.reshape(x, shape=[-1, 150, 150, 1])
    
        conv_out1 = conv_layer(x_reshaped, w1, b1)
        maxpool_out1 = maxpool_layer(conv_out1)
        norm1 = tf.nn.lrn(maxpool_out1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
        conv_out2 = conv_layer(norm1, w2, b2)
        norm2 = tf.nn.lrn(conv_out2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
        maxpool_out2 = maxpool_layer(norm2)
    
        maxpool_reshaped = tf.reshape(maxpool_out2, [-1, w3.get_shape().as_list()[0]])
        local = tf.add(tf.matmul(maxpool_reshaped, w3), b3)
        local_out = tf.nn.relu(local)
    
        out = tf.add(tf.matmul(local_out, w_out), b_out)
        return out
    
    model_op = model()
    
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(model_op, y))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
    
    correct_pred = tf.equal(tf.argmax(model_op, 1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        onehot_labels = tf.one_hot(labels, 2, on_value=1.,off_value=0.,axis=-1)
        onehot_vals = sess.run(onehot_labels)
        batch_size = len(data)
        # Restore model
        saver = tf.train.import_meta_graph('C:\\Users\\abc\\Desktop\\Testing\\mymodel.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        all_vars = tf.get_collection('vars')
        for v in all_vars:
            v_ = sess.run(v)
            print(v_)
    
        for j in range(0, 5):
            print('EPOCH', j)
            for i in range(0, len(data), batch_size):
                batch_data = data[i:i+batch_size, :]
                batch_onehot_vals = onehot_vals[i:i+batch_size, :]
                _, accuracy_val = sess.run([train_op, accuracy], feed_dict={x: batch_data, y: batch_onehot_vals})
                print(i, accuracy_val)

            print('DONE WITH EPOCH')
    
EDIT 1

Would restoring this way work?

    saver = tf.train.Saver()
    saver = tf.train.import_meta_graph('C:\\Users\\Abder-Rahman\\Desktop\\Testing\\mymodel.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print('model restored')
    

EDIT 2

This is how I save my model:

    #Save model
    saver = tf.train.Saver()
    saved_path = saver.save(sess, 'C:\\Users\\abc\\Desktop\\Testing\\mymodel')
    print("The model is in this file: ", saved_path)
    

Thanks.

3 Answers:

Answer 0 (score: 2)

Your saver code is correct. However, the variables must be added to the collection before the collection is retrieved:

    tf.add_to_collection("vars", w1)
    tf.add_to_collection("vars", b1)
    ...

and then:

    all_vars = tf.get_collection('vars')
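
Below is a minimal, hypothetical sketch (TensorFlow 1.x, made-up variable names and checkpoint path) of the pattern this answer describes: register the variables in a named collection before saving, so that tf.get_collection('vars') has something to return after the graph is imported and restored.

    import tensorflow as tf

    w1 = tf.Variable(tf.random_normal([5, 5, 1, 64]), name='w1')
    b1 = tf.Variable(tf.random_normal([64]), name='b1')

    # Register the variables in the 'vars' collection before saving; the
    # collection is serialized into the .meta file together with the graph.
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', b1)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './mymodel')  # hypothetical checkpoint prefix

    # After tf.train.import_meta_graph('./mymodel.meta') and saver.restore(...),
    # tf.get_collection('vars') returns these variables again.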

Answer 1 (score: 1)

Usually I restore a TensorFlow model like this:

 with tf.Session(graph=graph) as session:
    if os.path.exists(save_path):
        # Restore variables from disk.
        saver.restore(session, save_path)
    else:
        tf.initialize_all_variables().run()
        print('Initialized')

    # do the work
    # ... 
 saver.save(session, save_path)   # save the model

The sample code can be fetched here.

I would need to know more about how you save your model. It looks like your model is restored before it has been saved, and that your model is not put into the tf.Graph and connected with the session.
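
As a rough illustration of the ordering this answer hints at (save in one run, restore in a later one), here is a hedged TensorFlow 1.x sketch; the graph, variable names and checkpoint path are placeholders rather than the asker's actual model.

    import os
    import tensorflow as tf

    save_path = './mymodel'  # placeholder checkpoint prefix

    graph = tf.Graph()
    with graph.as_default():
        w = tf.Variable(tf.random_normal([10, 2]), name='w')
        saver = tf.train.Saver()  # create the saver after all variables exist

    # First run: initialize, (train), then write the checkpoint to disk.
    with tf.Session(graph=graph) as session:
        session.run(tf.global_variables_initializer())
        saver.save(session, save_path)

    # Later run: restore from disk if a checkpoint exists, otherwise initialize.
    with tf.Session(graph=graph) as session:
        if os.path.exists(save_path + '.index'):  # V2 checkpoints write an .index file
            saver.restore(session, save_path)
        else:
            session.run(tf.global_variables_initializer())
        print(session.run(w))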

Answer 2 (score: 0)

I assume you have read my blog here. The mechanism behind model saving is quite simple: when you load a model, the parameter values and the relations (which are probably what you care about) are matched up by variable name.

For example:

all_vars = tf.get_collection('vars')

What confuses me is that you call tf.get_collection('vars') but never define a scope or collection named "vars". You should probably test with tf.all_variables() first.
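
A small sketch (TensorFlow 1.x, placeholder names and path) of the name-matching idea above: a plain tf.train.Saver maps checkpointed values onto the variables of the current graph by name, and tf.all_variables() lets you list what the graph actually contains.

    import tensorflow as tf

    # Rebuild the graph with the same variable names that were used at save time.
    w1 = tf.Variable(tf.zeros([5, 5, 1, 64]), name='w1')
    b1 = tf.Variable(tf.zeros([64]), name='b1')

    saver = tf.train.Saver()  # by default, maps every variable by its name
    with tf.Session() as sess:
        saver.restore(sess, './mymodel')  # values are matched by variable name
        for v in tf.all_variables():      # tf.global_variables() in newer 1.x
            print(v.name, sess.run(v).shape)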