I have pasted part of my code below (all the parts that touch TensorBoard). I only log the loss as a scalar summary, and I add the summary only once per epoch. I ran 3 epochs in total. Ideally this should produce a very small tfevents file, but the tfevents file is 1.3 GB. I'm not sure what is causing the file to be so large. Happy to share the rest of the code if needed.
def do_training(update_op, loss, summary_op):
    writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            step = 0
            while True:
                # write a summary once per epoch, otherwise just run the training step
                if step % (X_train.shape[0]/batch_size) == 0:
                    _, loss_value = sess.run((update_op, loss))
                    summary = sess.run(summary_op)
                    writer.add_summary(summary, global_step=step)
                    print('Step {} with loss {}'.format(step, loss_value))
                else:
                    _, loss_value = sess.run((update_op, loss))
                step += 1
        except tf.errors.OutOfRangeError:
            # we're through the dataset
            pass
        writer.close()
        saver.save(sess, save_path)
        print('Final loss: {}'.format(loss_value))
def serial_training(model_fn, dataset):
    iterator = dataset.make_one_shot_iterator()
    loss = model_fn(lambda: iterator.get_next())
    tf.summary.scalar("loss", loss)
    summary_op = tf.summary.merge_all()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0002)
    global_step = tf.train.get_or_create_global_step()
    update_op = optimizer.minimize(loss, global_step=global_step)
    do_training(update_op, loss, summary_op)
tf.reset_default_graph()
serial_training(training_model, training_dataset(epochs=3, batch_size=batch_size))
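
For reference, here is a minimal sketch of how the contents of the event file could be inspected to see which records account for the size (assuming TensorFlow 1.x, and that logs_path is the log directory used above; the file-selection logic is only illustrative):

import glob
import os

import tensorflow as tf

# Illustrative only: pick the newest events file in the log directory.
event_files = sorted(glob.glob(os.path.join(logs_path, 'events.out.tfevents.*')))

graph_bytes = 0
summary_bytes = 0
for event in tf.train.summary_iterator(event_files[-1]):
    # Each record in the file is an Event proto; tally its size by record type.
    if event.HasField('graph_def'):
        graph_bytes += event.ByteSize()
    elif event.HasField('summary'):
        summary_bytes += event.ByteSize()

print('graph_def events: {} bytes'.format(graph_bytes))
print('summary events:   {} bytes'.format(summary_bytes))

This only breaks the file down into the graph record versus the per-step scalar summaries; it does not change any of the training code above.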