TensorFlow: how to save the model at the steps I want during training

Date: 2017-06-05 05:58:24

Tags: python tensorflow

Here is my problem: I want my code to save the model every 100 steps. TRAIN_STEPS is 3000, so it should save roughly 30 checkpoints, but only the last 5 are kept. The checkpoint file contains:

model_checkpoint_path: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2500"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2600"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2700"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2800"
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900"

Only those 5 models are saved, and I don't know why. Can anyone tell me? Here is my code:

# coding=utf-8
from  color_1 import read_and_decode, get_batch, get_test_batch
import color_inference
import cv2
import os
import time
import numpy as np
import tensorflow as tf

batch_size=128
TRAIN_STEPS=3000
crop_size=56
MOVING_AVERAGE_DECAY=0.99
num_examples=50000
LEARNING_RATE_BASE=0.8
LEARNING_RATE_DECAY=0.99
MODEL_SAVE_PATH="/home/vrview/tensorflow/example/char/tfrecords/color/"
MODEL_NAME="model.ckpt"

def train(batch_x,batch_y):
    image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
    image_input = tf.reshape(image_holder, [-1, 56, 56, 3])

    y=color_inference.inference(image_holder)
    global_step=tf.Variable(0,trainable=False)

    def loss(logits, labels):
        labels = tf.cast(labels, tf.int64)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='cross_entropy_per_example')

        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')

    loss = loss(y, label_holder)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    saver=tf.train.Saver()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(TRAIN_STEPS):
            image_batch, label_batch = sess.run([batch_x, batch_y])
            _, loss_value,step = sess.run([train_op, loss,global_step], feed_dict={image_holder: image_batch,
                                                                                   label_holder:label_batch})
            if i % 100 == 0:
                format_str = 'After %d steps, loss on training batch is: %.2f'
                print(format_str % (i, loss_value))
                saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=i)
        coord.request_stop()  
        coord.join(threads)
def main(argv=None):
    image, label = read_and_decode('train.tfrecords')
    batch_image, batch_label = get_batch(image, label, batch_size, crop_size)  # generate training batches
    train(batch_image,batch_label)
if __name__=='__main__':
    tf.app.run()

1 Answer:

Answer 0 (score: 1)

Add max_to_keep=30 to your Saver's constructor. The default is 5, which is why only the 5 most recent checkpoints are kept.
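
For example, a minimal sketch of the change inside train() (only the Saver construction needs to be touched; everything else in the question stays the same):

    # keep the 30 most recent checkpoints instead of the default 5
    saver = tf.train.Saver(max_to_keep=30)
    # or, to never delete old checkpoints at all:
    # saver = tf.train.Saver(max_to_keep=None)

With TRAIN_STEPS=3000 and a save every 100 steps, max_to_keep=30 keeps all of the checkpoints you expect; tf.train.Saver also accepts keep_checkpoint_every_n_hours if you only need periodic long-term snapshots.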