我按照这里(here)的一个可用教程,尝试在我自己的数据集上运行它。代码能够编译并开始训练,但是我得到了下面的结果:
该模型似乎没有在每次迭代后保存。我尝试训练了 100 个 epoch(轮),但没有任何变化,它给出的仍然是第一次迭代的输出。您知道问题可能出在哪里吗?(我知道代码写得不太好,抱歉)
def train(model, epochs, log_string):
    """Train the RNN, log summaries, and checkpoint the best epoch.

    Each epoch runs one pass over the training batches (with the optimizer)
    and one pass over the validation batches (without it, keep_prob = 1).

    A checkpoint is written ONLY when the epoch's average validation loss is
    not greater than the lowest value seen so far ("New Record!"), and it is
    always written to the same path for a given ``log_string``.  If the
    validation loss never improves after the first epoch, the file on disk
    therefore still holds the first epoch's weights.  Training stops early
    once three consecutive epochs bring no improvement.

    Args:
        model: Object exposing the graph nodes used below: ``inputs``,
            ``labels``, ``keep_prob``, ``initial_state``, ``final_state``,
            ``merged``, ``cost``, ``accuracy`` and ``optimizer``.
        epochs: Maximum number of epochs to run.
        log_string: Identifier inserted into the summary log directories and
            the checkpoint file name.

    Note:
        ``x_train``, ``y_train``, ``x_valid``, ``y_valid``, ``batch_size``,
        ``dropout``, ``get_batches``, ``tf``, ``np`` and ``tqdm`` are not
        parameters; they must be available from the enclosing scope.
    """
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Used to determine when to stop the training early
        valid_loss_summary = []

        # Keep track of which batch iteration is being trained
        iteration = 0

        # Consecutive epochs without a new validation-loss minimum.
        # Initialised up front so the "No Improvement" branch can never hit
        # an unbound local.
        stop_early = 0

        print()
        print("Training Model: {}".format(log_string))

        train_writer = tf.summary.FileWriter(
            './logs/3/train/{}'.format(log_string), sess.graph)
        valid_writer = tf.summary.FileWriter(
            './logs/3/valid/{}'.format(log_string))

        try:
            for e in range(epochs):
                state = sess.run(model.initial_state)

                # Record progress with each epoch
                train_loss = []
                train_acc = []
                val_acc = []
                val_loss = []

                with tqdm(total=len(x_train)) as pbar:
                    for x, y in get_batches(x_train, y_train, batch_size):
                        feed = {model.inputs: x,
                                model.labels: y[:, None],
                                model.keep_prob: dropout,
                                model.initial_state: state}
                        # Running model.optimizer is what updates the weights;
                        # the final state is fed back in for the next batch.
                        summary, loss, acc, state, _ = sess.run(
                            [model.merged,
                             model.cost,
                             model.accuracy,
                             model.final_state,
                             model.optimizer],
                            feed_dict=feed)

                        # Record the loss and accuracy of each training batch
                        train_loss.append(loss)
                        train_acc.append(acc)

                        # Record the progress of training
                        train_writer.add_summary(summary, iteration)

                        iteration += 1
                        pbar.update(batch_size)

                # Average the training loss and accuracy of each epoch
                avg_train_loss = np.mean(train_loss)
                avg_train_acc = np.mean(train_acc)

                val_state = sess.run(model.initial_state)
                with tqdm(total=len(x_valid)) as pbar:
                    for x, y in get_batches(x_valid, y_valid, batch_size):
                        feed = {model.inputs: x,
                                model.labels: y[:, None],
                                model.keep_prob: 1,
                                model.initial_state: val_state}
                        # No optimizer op here: validation must not train.
                        summary, batch_loss, batch_acc, val_state = sess.run(
                            [model.merged,
                             model.cost,
                             model.accuracy,
                             model.final_state],
                            feed_dict=feed)

                        # Record the validation loss and accuracy of each epoch
                        val_loss.append(batch_loss)
                        val_acc.append(batch_acc)
                        pbar.update(batch_size)

                # Average the validation loss and accuracy of each epoch
                avg_valid_loss = np.mean(val_loss)
                avg_valid_acc = np.mean(val_acc)
                valid_loss_summary.append(avg_valid_loss)

                # Record the validation data's progress (summary of the last
                # validation batch of this epoch)
                valid_writer.add_summary(summary, iteration)

                # Print the progress of each epoch
                print("Epoch: {}/{}".format(e, epochs),
                      "Train Loss: {:.3f}".format(avg_train_loss),
                      "Train Acc: {:.3f}".format(avg_train_acc),
                      "Valid Loss: {:.3f}".format(avg_valid_loss),
                      "Valid Acc: {:.3f}".format(avg_valid_acc))

                # Stop training if the validation loss does not decrease
                # after 3 epochs
                if avg_valid_loss > min(valid_loss_summary):
                    print("No Improvement.")
                    stop_early += 1
                    if stop_early == 3:
                        break

                # Reset stop_early if the validation loss finds a new low
                # Save a checkpoint of the model (overwrites the previous
                # checkpoint for this log_string)
                else:
                    print("New Record!")
                    stop_early = 0
                    checkpoint = "sauvegarde/controverse_{}.ckpt".format(log_string)
                    saver.save(sess, checkpoint)
        finally:
            # Flush and release the event files even on early stop or error,
            # otherwise buffered summaries may never reach disk.
            train_writer.close()
            valid_writer.close()
非常感谢您的回答:)