下面是定义数据集的代码:
# TFRecord files for the two splits; only the first one is read below.
filenames = ['./train_data.tfrecords', './test_data.tfrecords']

# Training input pipeline: map every serialized record through `parser`
# (defined elsewhere), shuffle the individual examples, group them into
# batches, and then replay the whole batched dataset 50 times.
# NOTE(review): because `.repeat(50)` is the last stage, an iterator over
# this dataset raises tf.errors.OutOfRangeError only after all 50 passes,
# so a consumer cannot see where one epoch ends and the next begins. To
# observe epoch boundaries, drop the repeat and re-run `training_init_op`
# once per epoch from a Python loop.
# NOTE(review): `num_parallel_calls=batch_size` ties map parallelism to the
# batch size; a CPU-count-based value is more usual — confirm intent.
train_dataset = (
    tf.data.TFRecordDataset(filenames[0])
    .map(parser, num_parallel_calls=batch_size)
    .shuffle(buffer_size=min_queue_examples + 3 * batch_size)
    .batch(batch_size)
    .repeat(50)
)

# A reinitializable iterator: built only from the element structure, then
# bound to a concrete dataset by running the op from `make_initializer`.
iterator = tf.data.Iterator.from_structure(
    train_dataset.output_types,
    train_dataset.output_shapes,
)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(train_dataset)
下面是使用该数据集进行训练和验证的代码(删除了一些用于记录日志的行):
# Train until the training iterator is exhausted, then evaluate once on the
# test split. `cost`, `optimizer`, `accuracy`, `merged`, `keep_prob`,
# `train_summary_writer`, `testing_init_op` and `epochs` are defined outside
# this snippet.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    sess.run(training_init_op)

    # NOTE(review): the training dataset is built with `.repeat(50)`, so the
    # OutOfRangeError below fires only once, after all 50 passes. TensorFlow
    # does not expose an epoch counter for tf.data; to report per-epoch
    # results, build the dataset without `.repeat` and wrap this loop in
    # `for epoch in range(num_epochs):`, re-running `training_init_op` at the
    # start of each epoch. The OutOfRangeError then marks each epoch's end.
    summary_every = 20  # write training summaries every N steps
    train_iter = 0
    # Defaults so the report below cannot hit a NameError if the training
    # iterator yields no batches at all.
    loss_value = train_acc = float('nan')
    while True:
        # Only evaluate `merged` on steps whose summary is actually written.
        # It must be fetched in the same sess.run as the training step:
        # presumably the model consumes `next_element`, so a separate run
        # would advance the iterator by another batch.
        write_summary = (train_iter + 1) % summary_every == 0
        fetches = [cost, optimizer, accuracy]
        if write_summary:
            fetches.append(merged)
        try:
            results = sess.run(fetches, feed_dict={keep_prob: 0.7})
        except tf.errors.OutOfRangeError:
            break  # training iterator exhausted
        loss_value, _, train_acc = results[:3]
        train_iter += 1
        if write_summary:
            train_summary_writer.add_summary(results[3], train_iter)

    print("Epoch: {}, loss: {:.3f}, training accuracy: {:.2f}%".format(
        epochs, loss_value, train_acc * 100))

    # Validation: average the per-batch accuracy over the whole test split.
    # NOTE(review): this is an unweighted mean of batch accuracies, so a
    # smaller final batch counts as much as a full one.
    sess.run(testing_init_op)
    acc_sum = 0.0
    n_batches = 0
    while True:
        try:
            batch_acc = sess.run(accuracy, feed_dict={keep_prob: 1.0})
        except tf.errors.OutOfRangeError:
            break  # test iterator exhausted
        acc_sum += batch_acc
        n_batches += 1
    # Guard against an empty test split instead of dividing by zero.
    mean_acc = acc_sum / n_batches if n_batches else float('nan')
    print("Average validation set accuracy is {:.2f}%".format(mean_acc * 100))
也许可以根据批大小(batch size)和样本数量推算出每个 epoch 何时结束,但有没有办法利用 TensorFlow 内置的变量或机制,直接得知一个 epoch 已经完成?