我想修改一个预训练的模型然后微调它。我能够在tensorflow中加载图形。但是当我写新图层时,我的图形形状会意外地改变。代码很长但是这里是
with tf.Session() as persisted_sess:
graph_file = os.path.join("./tmp/my-model", "input_graph.pb")
with tf.gfile.FastGFile(graph_file,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name="")
# inputs
x = tf.placeholder(tf.float32, [None, width*height],name="first_var")
y_ = tf.placeholder(tf.float32, [None, nClass],name="second_var")
z_ = tf.placeholder(tf.float32, [None, nClass],name="third_var")
W_conv1 = persisted_sess.graph.get_tensor_by_name("W_conv1:0")
b_conv1 = persisted_sess.graph.get_tensor_by_name("b_conv1:0")
x_image = tf.reshape(x, [-1,width,height,1],name="input_node")
W_conv2=persisted_sess.graph.get_tensor_by_name("W_conv2:0")
b_conv2 = persisted_sess.graph.get_tensor_by_name("b_conv2:0")
W_fc1= persisted_sess.graph.get_tensor_by_name("W_fc1:0")
b_fc1 = persisted_sess.graph.get_tensor_by_name("b_fc1:0")
W_fc2= persisted_sess.graph.get_tensor_by_name("W_fc2:0")
b_fc2 = persisted_sess.graph.get_tensor_by_name("b_fc2:0")
W_fc3=tf.get_variable("W_fc3", shape=[ nClass + nClass, nClass],initializer=tf.contrib.layers.xavier_initializer())
b_fc3 = tf.get_variable("b_fc3", shape=[nClass],initializer=tf.contrib.layers.xavier_initializer())
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
h_pool2_flat = tf.reshape(h_pool2, [-1, (width/4) * (height/4) * nFeatures2])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc2=tf.matmul(h_fc1, W_fc2) + b_fc2
keep_prob=tf.placeholder(tf.float32)
h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)
fc3_z_flat=tf.concat(1,[h_fc2_drop,z_])
h_fc3=tf.matmul(fc3_z_flat,W_fc3)+b_fc3
y=tf.nn.softmax(h_fc3,name="output_node")
N_training=100000
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
#tf.summary.scalar("cross_entropy",cross_entropy)
# quadratic_cost= tf.scalar_mul(1.0/(N_training*2.0),tf.reduce_sum(tf.squared_difference(y,y_)))
# tf.summary.scalar("quadratic_cost",quadratic_cost)
# define training step which minimises cross entropy
optimizer = tf.train.AdamOptimizer(learning_rate=2e-5,epsilon=0.004)
#optimizer = tf.train.GradientDescentOptimizer (2e-5,False,'GradientDescent')
train_step = optimizer.minimize(cross_entropy)
tf.summary.scalar("learning_rate",optimizer._lr)
# argmax gives index of highest entry in vector (1st axis of 1D tensor)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# get mean of all entries in correct prediction, the higher the better
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
tf.summary.scalar("accuracy",accuracy)
train_writer = tf.summary.FileWriter("./tmp/trainlog", persisted_sess.graph)
merged = tf.summary.merge_all()
persisted_sess.run(tf.global_variables_initializer())
checkpoint_file = os.path.join("./tmp/ckpt", "saved_checkpoint4999-0")
saver = tf.train.Saver( [W_conv1,b_conv1,W_conv2,b_conv2,W_fc1,b_fc1,W_fc2,b_fc2])
saver.restore(persisted_sess, checkpoint_file)
# associate the "label" and "image" objects with the corresponding features read from
# a single example in the training data file
label, image, question = getImage("./../Data/train-00000-of-00001")
# and similarly for the validation data
vlabel, vimage, vquestion = getImage("./../Data/validation-00000-of-00001")
#print('Size %s' % vlabel)
# associate the "label_batch" and "image_batch" objects with a randomly selected batch---
# of labels and images respectively
imageBatch, labelBatch, questionBatch = tf.train.shuffle_batch(
[image, label, question], batch_size=1500,
capacity=2000,
min_after_dequeue=1500)
# and similarly for the validation data
vimageBatch, vlabelBatch, vquestionBatch = tf.train.shuffle_batch(
[vimage, vlabel, vquestion], batch_size=1500,
capacity=2000,
min_after_dequeue=1500)
tf.train.write_graph(persisted_sess.graph.as_graph_def(), FLAGS.model_dir, 'new_graph',True)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=persisted_sess,coord=coord)
tf.global_variables_initializer().run
所有图层几乎相同,只是添加了W_fc3和W_fc2。我在哪里做错了?