这是一个问题,我希望我的代码可以保存模型每100步,我的TRAIN_STEPS是3000,所以应该保存近30个模型,但它只保存最后5个模型。检查点的详细信息是:
model_checkpoint_path: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2500"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2600"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2700"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2800"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900"
只保存那5个型号。我不知道为什么。有人告诉我?这是我的代码
# coding=utf-8
from color_1 import read_and_decode, get_batch, get_test_batch
import color_inference
import cv2
import os
import time
import numpy as np
import tensorflow as tf
batch_size=128
TRAIN_STEPS=3000
crop_size=56
MOVING_AVERAGE_DECAY=0.99
num_examples=50000
LEARNING_RATE_BASE=0.8
LEARNING_RATE_DECAY=0.99
MODEL_SAVE_PATH="/home/vrview/tensorflow/example/char/tfrecords/color/"
MODEL_NAME="model.ckpt"
def train(batch_x,batch_y):
image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
image_input = tf.reshape(image_holder, [-1, 56, 56, 3])
y=color_inference.inference(image_holder)
global_step=tf.Variable(0,trainable=False)
def loss(logits, labels):
labels = tf.cast(labels, tf.int64)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels, name='cross_entropy_per_example')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
tf.add_to_collection('losses', cross_entropy_mean)
return tf.add_n(tf.get_collection('losses'), name='total_loss')
loss = loss(y, label_holder)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
saver=tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(TRAIN_STEPS):
image_batch, label_batch = sess.run([batch_x, batch_y])
_, loss_value,step = sess.run([train_op, loss,global_step], feed_dict={image_holder: image_batch,
label_holder:label_batch})
if i % 100 == 0:
format_str=('After %d step,loss on training batch is: %.2f')
print (format_str%(i,loss_value))
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=i)
coord.request_stop()
coord.join(threads)
def main(argv=None):
image, label = read_and_decode('train.tfrecords')
batch_image, batch_label = get_batch(image, label, batch_size, crop_size) # batch 生成测试
train(batch_image,batch_label)
if __name__=='__main__':
tf.app.run()
答案 0 :(得分:1)
将max_to_keep=30
添加到您的保护程序的构造函数中,默认值为5,这就是您只保存5次的原因