How do I get accuracy/loss statistics at the end of each epoch when using the TensorFlow Dataset module with repeat()?

Asked: 2018-12-11 08:20:25

Tags: python tensorflow tensorflow-datasets

Here is the part of the code that defines the dataset:

import tensorflow as tf  # TensorFlow 1.x API

# parser, batch_size and min_queue_examples are defined elsewhere
filenames = ['./train_data.tfrecords', './test_data.tfrecords']

train_dataset = tf.data.TFRecordDataset(filenames[0])
train_dataset = train_dataset.map(parser, num_parallel_calls=batch_size)
train_dataset = train_dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
train_dataset = train_dataset.batch(batch_size).repeat(50)

# Reinitializable iterator: the same iterator can be pointed at either
# the training or the test dataset via an initializer op.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(train_dataset)

And here is the part that uses it (some lines used for logging have been removed):

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    sess.run(training_init_op)
    train_iter = 0
    while True:
        try:
            l, _, acc, summary = sess.run([cost, optimizer, accuracy, merged], feed_dict={keep_prob: 0.7})
            train_iter += 1
            if train_iter % 20 == 0:
                train_summary_writer.add_summary(summary, train_iter)
        except tf.errors.OutOfRangeError:
            # Note: with .repeat(50) this branch is reached only once,
            # after all 50 epochs have been consumed.
            print("Epoch: {}, loss: {:.3f}, training accuracy: {:.2f}%".format(epochs, l, acc * 100))
            # testing_init_op is assumed to be defined analogously to
            # training_init_op, from a test dataset (not shown above).
            sess.run(testing_init_op)
            avg_acc, counter = 0.0, 0
            while True:
                try:
                    acc, summary = sess.run([accuracy, merged], feed_dict={keep_prob: 1})
                    avg_acc += acc
                    counter += 1
                except tf.errors.OutOfRangeError:
                    print("Average validation set accuracy is {:.2f}%".format((avg_acc / counter) * 100))
                    break
            break

I could presumably work out when each epoch finishes from the batch size and the number of samples, but is there a way to find this out using a built-in TensorFlow variable?
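The step-counting approach mentioned above can be sketched in plain Python. This is a minimal illustration, not TensorFlow API: `num_train_samples`, `batch_size`, and `repeat_count` are hypothetical values, and the `sess.run` call is elided.

```python
# Hypothetical values for illustration; substitute your own.
num_train_samples = 10000
batch_size = 100
repeat_count = 50  # matches .repeat(50) above

# .batch(batch_size) without drop_remainder emits a final partial batch,
# so steps per epoch is the ceiling of the division.
steps_per_epoch = -(-num_train_samples // batch_size)

epoch_stats = []
for train_iter in range(1, steps_per_epoch * repeat_count + 1):
    # ... sess.run([...]) would go here ...
    if train_iter % steps_per_epoch == 0:
        epoch = train_iter // steps_per_epoch
        epoch_stats.append(epoch)  # log/print per-epoch loss and accuracy here
```

An alternative is to drop `.repeat(50)` entirely and re-run `sess.run(training_init_op)` in an outer loop, so `OutOfRangeError` fires once per epoch rather than once after all 50.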

0 Answers:

There are no answers yet.