Second feed_dict in sess.run(): evaluating while training

Asked: 2018-03-25 23:52:54

Tags: python tensorflow conv-neural-network

The infamous:

    InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [100,8]
         [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[100,8], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

I have been training a CNN for a while. I had separate Python scripts for the training and evaluation methods and want to combine them: the network should train for N iterations, then run an evaluation, check whether the model is better than every previous one, save it if so, and then resume training.

Since combining them, the network trains fine, but as soon as it enters the "validation" part of the loop it fails at the sess.run call where I have to feed in the validation data.

There are some debugging print("hi")s around the sess.run calls; the error occurs on the line after the third "hi" but before the fourth, which is exactly where the feed_dict with the validation data happens.

I have tried using just a single graph, and also building a separate graph with its own v_image_out, v_label_out, etc., to no avail. I have been fighting this error for days and have changed many things. Any help is greatly appreciated. The code was adapted to my purposes from: https://github.com/yeephycho/tensorflow_input_image_by_tfrecord/blob/master/src/flower_train_cnn.py

*Edit:* the exact line that raises the error is:

     File "train_cnn_4_pools.py", line 295, in train
        v_accuracy_out, v_logits_batch_out, v_summary = sess.run(
            [accuracy, logits_batch, merged_summary_op],
            feed_dict={image_batch_placeholder: v_image_out, label_tensor_placeholder: v_label_out})

Caused by:

  File "train_cnn_4_pools.py", line 317, in <module>
    train()
  File "train_cnn_4_pools.py", line 226, in train
    label_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, NUM_CLASSES])

This is extremely strange, because the call that fails in the validation loop uses label_tensor_placeholder, yet the error complains about label_batch_placeholder, which that call never feeds! Why is it trying to use the wrong placeholder?!
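(In hindsight, per the answer below, the mechanism is a dependency chain through the summaries. Here is a minimal sketch, assuming TensorFlow 1.x, with made-up names and not taken from the actual script, showing how fetching a merged summary can demand a seemingly unrelated placeholder: every summary collected by tf.summary.merge_all() drags its own input dependencies into the run call.)

    import tensorflow as tf

    # A placeholder used only by the training loss:
    labels = tf.placeholder(tf.float32, shape=[None], name="train_labels")
    loss = tf.reduce_mean(tf.square(labels))

    # The scalar summary inherits the dependency on `labels`...
    tf.summary.scalar("loss", loss)
    # ...and merge_all() collects every summary in the graph:
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        # Fails with "You must feed a value for placeholder tensor 'train_labels'",
        # even though `labels` never appears in this run call:
        sess.run(merged)

The full train() function in question: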

def train():
    image_batch_out, label_batch_out, filename_batch = input(if_eval = False)
    v_image_batch_out, v_label_batch_out, v_filename_batch = input(if_eval = True)

    image_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, None, None, 3])
    image_batch = tf.reshape(image_batch_out, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
    v_image_batch = tf.reshape(v_image_batch_out, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))

    label_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, NUM_CLASSES])
    label_tensor_placeholder = tf.placeholder(tf.int64, shape=[BATCH_SIZE])
    label_offset = -tf.ones([BATCH_SIZE], dtype=tf.int64, name="label_batch_offset")
    label_batch_one_hot = tf.one_hot(tf.add(label_batch_out, label_offset), depth=NUM_CLASSES, on_value=1.0, off_value=0.0)
    label_batch = tf.add(label_batch_out, label_offset)
    v_label_batch = tf.add(v_label_batch_out, label_offset)

    with tf.variable_scope("inference") as scope:
        logits_out = network(image_batch)
        scope.reuse_variables()
        v_logits_out = network(v_image_batch)

    logits_batch = tf.to_int64(tf.arg_max(v_logits_out, dimension = 1))
    #loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=label_batch_one_hot, logits=logits_out))
    prediction_op = tf.nn.softmax(logits_out)
    v_prediction_op = tf.nn.softmax(v_logits_out)

    correct_prediction = tf.equal(logits_batch, label_tensor_placeholder)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    loss = tf.losses.mean_squared_error(labels=label_batch_placeholder, predictions=prediction_op)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    variable_summaries(prediction_op, name="Predictions")
    surity_summary(prediction_op, name="Certainty")
    #surity = surity_calc(prediction_op)
    tf.summary.scalar("loss", loss)
    merged_summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        # Visualize the graph through tensorboard.
        #file_writer = tf.summary.FileWriter("C:/logs", sess.graph)
        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        accuracy_accu = 0
        best_accu = 0
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, chk_path)
        saver.restore(sess, chk_path)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess = sess)
        for epoch in range(1, 1000):
            for i in range(5):
                image_out, label_out, label_batch_one_hot_out, filename_out = sess.run([image_batch, label_batch_out, label_batch_one_hot, filename_batch])

                _, infer_out, loss_out, summary, predict_out, global_step_out = sess.run([train_step, logits_out, loss, merged_summary_op, prediction_op, global_step], feed_dict={image_batch_placeholder: image_out, label_batch_placeholder: label_batch_one_hot_out})

                #print(image_out.shape)
                print("label_out: ")
                #print(filename_out)
                print(label_out)
                #print(label_batch_one_hot_out)
                #print("infer_out: ")
                #print(infer_out)
                print("prediction: ")
                print(predict_out)
                print("loss: " + str(loss_out))
                print("local step: " + str(i))
                print("global step: " + str(global_step_out - 1))
                print("epoch: " + str(epoch))
                if(i%10 == 0):
                    summary_writer.add_summary(summary, global_step_out)
            print("hi")
            for p in range(int(v_TRAINING_SET_SIZE/BATCH_SIZE)):
                print("hi")
                v_image_out, v_label_out, v_filename_out = sess.run([v_image_batch, v_label_batch, v_filename_batch])
                print("hi")
                v_accuracy_out, v_logits_batch_out, v_summary = sess.run([accuracy, logits_batch, merged_summary_op], feed_dict={image_batch_placeholder: v_image_out, label_tensor_placeholder: v_label_out})
                print("hi")
                accuracy_accu += v_accuracy_out

                print(p)
                print(v_image_out.shape)
                print("label_out: ")
                print(v_filename_out)
                print(v_label_out)
                print(v_logits_batch_out)
                print("accuracy: ", v_accuracy_out)
                summary_writer.add_summary(v_summary, p)

            print("Accuracy: ")
            print((accuracy_accu/TRAINING_SET_SIZE)*100)
            if(accuracy_accu > best_accu):
                saver.save(sess, chk_path)
                best_accu = accuracy_accu
        coord.request_stop()
        coord.join(threads)
        sess.close()

train()

1 answer:

Answer 0 (score: 0):

Answered my own question: for anyone else who hits this, it turns out the tf.summary part of the graph requires the training label placeholder! Removing it (merged_summary_op) from the validation loop's sess.run fixed the error immediately, and adding an adjusted summary op solved the problem. Sorry for the clutter, SO.
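A minimal sketch of that fix, reusing the op names from the question (the validation-accuracy summary here is an illustrative assumption, not necessarily what the asker actually added): merge the training and validation summaries separately instead of relying on tf.summary.merge_all(), so the validation fetch no longer depends on label_batch_placeholder.

    # Training-only summaries: the loss summary depends on label_batch_placeholder.
    loss_summary = tf.summary.scalar("loss", loss)
    train_summary_op = tf.summary.merge([loss_summary])

    # Validation-only summaries: depend only on tensors fed in the validation loop.
    v_acc_summary = tf.summary.scalar("validation_accuracy", accuracy)
    v_summary_op = tf.summary.merge([v_acc_summary])

    # Training step fetches train_summary_op with the training feed:
    _, summary = sess.run(
        [train_step, train_summary_op],
        feed_dict={image_batch_placeholder: image_out,
                   label_batch_placeholder: label_batch_one_hot_out})

    # Validation fetches v_summary_op, which never touches label_batch_placeholder:
    v_accuracy_out, v_summary = sess.run(
        [accuracy, v_summary_op],
        feed_dict={image_batch_placeholder: v_image_out,
                   label_tensor_placeholder: v_label_out})

Helpers like variable_summaries and surity_summary would likewise need to return their summary ops so each one can be added to the appropriate merge list rather than swept up by merge_all().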