如何修改预训练模型的计算图?(TensorFlow)

时间:2017-06-13 01:59:37

标签: tensorflow deep-learning resuming-training


我想修改一个预训练模型,然后对它进行微调。我已经能够在 TensorFlow 中加载计算图,但是当我添加新的层时,计算图的结构会发生意料之外的变化。完整代码很长,下面是相关部分:

with tf.Session() as persisted_sess:
    # Purpose: import a serialized pre-trained graph, re-wire its conv/fc
    # weights into a new forward pass with one extra fully connected layer
    # (W_fc3 / b_fc3), restore the old weights from a checkpoint and set up
    # the input queues for fine-tuning.

    # Read the serialized GraphDef of the pre-trained model.
    graph_file = os.path.join("./tmp/my-model", "input_graph.pb")
    with tf.gfile.FastGFile(graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # tf.Session() was created without a graph argument, so it runs the
    # current default graph; importing here therefore lands in
    # persisted_sess.graph.  (The original bare
    # `persisted_sess.graph.as_default()` call only built an unused context
    # manager and had no effect, so it was dropped.)
    # name="" imports the nodes without any name prefix.
    tf.import_graph_def(graph_def, name="")

    # Inputs.
    # NOTE(review): because the import above uses no prefix, any op created
    # below whose name already exists in the imported graph (e.g.
    # "first_var", "input_node", "output_node") will be silently renamed
    # with a "_1" suffix and the imported copy stays in the graph as well --
    # confirm against the contents of input_graph.pb; this may be the cause
    # of the unexpected graph structure.
    x = tf.placeholder(tf.float32, [None, width * height], name="first_var")
    y_ = tf.placeholder(tf.float32, [None, nClass], name="second_var")
    z_ = tf.placeholder(tf.float32, [None, nClass], name="third_var")

    # Pre-trained parameters, looked up by name in the imported graph.
    W_conv1 = persisted_sess.graph.get_tensor_by_name("W_conv1:0")
    b_conv1 = persisted_sess.graph.get_tensor_by_name("b_conv1:0")
    x_image = tf.reshape(x, [-1, width, height, 1], name="input_node")

    W_conv2 = persisted_sess.graph.get_tensor_by_name("W_conv2:0")
    b_conv2 = persisted_sess.graph.get_tensor_by_name("b_conv2:0")

    W_fc1 = persisted_sess.graph.get_tensor_by_name("W_fc1:0")
    b_fc1 = persisted_sess.graph.get_tensor_by_name("b_fc1:0")

    W_fc2 = persisted_sess.graph.get_tensor_by_name("W_fc2:0")
    b_fc2 = persisted_sess.graph.get_tensor_by_name("b_fc2:0")

    # New, trainable layer.  Its input is the concatenation of the dropped-out
    # fc2 output with z_ along axis 1, hence nClass + nClass input columns
    # (assumes W_fc2 produces nClass columns -- TODO confirm).
    W_fc3 = tf.get_variable(
        "W_fc3", shape=[nClass + nClass, nClass],
        initializer=tf.contrib.layers.xavier_initializer())
    b_fc3 = tf.get_variable(
        "b_fc3", shape=[nClass],
        initializer=tf.contrib.layers.xavier_initializer())

    # Forward pass; conv2d / max_pool_2x2 are helpers defined elsewhere.
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # Floor division keeps the flattened size an int under Python 3 as well;
    # true division would yield a float, which is not a valid dimension.
    h_pool2_flat = tf.reshape(
        h_pool2, [-1, (width // 4) * (height // 4) * nFeatures2])

    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    h_fc2 = tf.matmul(h_fc1, W_fc2) + b_fc2

    keep_prob = tf.placeholder(tf.float32)

    h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)

    # NOTE(review): (concat_dim, values) is the pre-1.0 argument order.
    # TensorFlow >= 1.0 expects tf.concat(values, axis); confirm the
    # installed version before changing this call.
    fc3_z_flat = tf.concat(1, [h_fc2_drop, z_])
    h_fc3 = tf.matmul(fc3_z_flat, W_fc3) + b_fc3

    y = tf.nn.softmax(h_fc3, name="output_node")
    N_training = 100000
    cross_entropy = tf.reduce_mean(
        -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    #tf.summary.scalar("cross_entropy",cross_entropy)
    #quadratic_cost= tf.scalar_mul(1.0/(N_training*2.0),tf.reduce_sum(tf.squared_difference(y,y_)))
    #tf.summary.scalar("quadratic_cost",quadratic_cost)

    # Define the training step, which minimises the cross entropy.
    optimizer = tf.train.AdamOptimizer(learning_rate=2e-5, epsilon=0.004)
    #optimizer = tf.train.GradientDescentOptimizer (2e-5,False,'GradientDescent')
    train_step = optimizer.minimize(cross_entropy)
    tf.summary.scalar("learning_rate", optimizer._lr)

    # argmax gives the index of the highest entry along axis 1.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

    # Mean of all entries in correct_prediction; the higher the better.
    accuracy = tf.reduce_mean(
        tf.cast(correct_prediction, tf.float32), name="accuracy")
    tf.summary.scalar("accuracy", accuracy)

    train_writer = tf.summary.FileWriter(
        "./tmp/trainlog", persisted_sess.graph)
    merged = tf.summary.merge_all()

    # Run the initializer first and restore afterwards, so the values read
    # from the checkpoint for the pre-trained parameters are what remains.
    persisted_sess.run(tf.global_variables_initializer())
    checkpoint_file = os.path.join("./tmp/ckpt", "saved_checkpoint4999-0")
    saver = tf.train.Saver(
        [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2])
    saver.restore(persisted_sess, checkpoint_file)

    # getImage (defined elsewhere) yields the label, image and question
    # objects for a single example of the given data file.
    label, image, question = getImage("./../Data/train-00000-of-00001")
    # And similarly for the validation data.
    vlabel, vimage, vquestion = getImage(
        "./../Data/validation-00000-of-00001")
    #print('Size %s' % vlabel)

    # Randomly shuffled batches of images, labels and questions.
    imageBatch, labelBatch, questionBatch = tf.train.shuffle_batch(
        [image, label, question], batch_size=1500,
        capacity=2000,
        min_after_dequeue=1500)

    # And similarly for the validation data.
    vimageBatch, vlabelBatch, vquestionBatch = tf.train.shuffle_batch(
        [vimage, vlabel, vquestion], batch_size=1500,
        capacity=2000,
        min_after_dequeue=1500)

    # Dump the combined (imported + new) graph in text format.
    tf.train.write_graph(
        persisted_sess.graph.as_graph_def(), FLAGS.model_dir,
        'new_graph', True)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=persisted_sess, coord=coord)
    # The original ended with `tf.global_variables_initializer().run`, which
    # lacked the call parentheses and therefore never executed anything (it
    # only added an unused init op to the graph).  Variables are already
    # initialised above, before the checkpoint restore, so it was removed
    # rather than turned into a real call.

所有层都与原模型几乎相同,只是新增了 W_fc3 和 b_fc3。我哪里做错了?

0 个答案:

没有答案