Saving and restoring checkpoints in a convolutional neural network

Time: 2016-07-14 17:34:35

Tags: python-2.7 image-processing tensorflow conv-neural-network

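I am trying to save my trained network to a checkpoint so that I don't have to retrain it every time I run a test. However, when I run the test file, the network trains again, and I can't find what is wrong in my code. Can anybody help me?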

This is the training file:

# The Saver must be created in the same graph as the variables it saves.
with graph.as_default():
  saver = tf.train.Saver()


with tf.Session(graph=graph) as session:

  num_steps = 1001

  # Initialize all model variables before training.
  tf.initialize_all_variables().run()
  print('Initialized')

  for step in range(num_steps):
    # Pick the next minibatch, wrapping around the training set.
    offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
    batch_data = train_dataset[offset:(offset + batch_size), :, :, :]
    batch_labels = train_labels[offset:(offset + batch_size), :]
    print("batch_labels", batch_labels)
    feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}

    _, l, predictions = session.run(
      [optimizer, loss, train_prediction], feed_dict=feed_dict)

    if (step % 50 == 0):
      print('Minibatch loss at step %d: %f' % (step, l))
      print('Minibatch accuracy: %.1f%%' % accuracy(predictions, batch_labels))
      print('Validation accuracy: %.1f%%' % accuracy(valid_prediction.eval(), valid_labels))

  # Save while the session is still open.
  save_path = saver.save(session, "/home/owner/tensorflow/tensorflow/models/image/mnist/new_dataset/models.ckpt")
  print("Model saved in file: %s" % save_path)
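One possible cause of the retraining (an assumption, since the imports of the two files are not shown in full) is that the test script imports the training script, whose training loop runs at module level and therefore executes on every import. Below is a minimal sketch of the usual guard. It uses a hypothetical stand-in model (fitting y = 2x + 1 by gradient descent) and `pickle` instead of the CNN and `tf.train.Saver`, and the names `train`, `save_checkpoint` and `model.pkl` are illustrative only.

```python
import os
import pickle
import tempfile


def train(num_steps=500, lr=0.5):
    """Hypothetical stand-in for the CNN training loop: fit y = 2x + 1."""
    xs = [i / 10.0 for i in range(10)]
    ys = [2 * x + 1 for x in xs]
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(num_steps):
        # Gradients of the mean squared error with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return {"w": w, "b": b}


def save_checkpoint(params, path):
    """Write the learned parameters to disk (stand-in for saver.save)."""
    with open(path, "wb") as f:
        pickle.dump(params, f)


if __name__ == "__main__":
    # This block runs only when the file is executed directly,
    # not when another script imports it.
    ckpt_path = os.path.join(tempfile.gettempdir(), "model.pkl")
    save_checkpoint(train(), ckpt_path)
    print("Model saved in file: %s" % ckpt_path)
```

With this guard, importing the file (for example as a hypothetical `train_model.py`) only defines the functions. Training runs only through `python train_model.py`. In the TensorFlow code above, the equivalent change would be to move the `with tf.Session(graph=graph) as session:` block under the same `if __name__ == "__main__":` check.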

This is the test file:

from __future__ import print_function

import time
from datetime import datetime

import numpy as np
import tensorflow as tf
from six.moves import cPickle as pickle
from six.moves import range

# The model graph (the same variables as in training), test_prediction,
# test_labels and accuracy() must already be defined at this point;
# tf.train.Saver() needs those variables to exist in the current graph.
saver = tf.train.Saver()
init = tf.initialize_all_variables()  # not run: restore() sets the variable values

with tf.Session() as session:
  saver.restore(session, "/home/owner/tensorflow/tensorflow/models/image/mnist/new_dataset/models.ckpt")
  print("Model restored.")
  print('Test accuracy: %.1f%%' % accuracy(test_prediction.eval(), test_labels, force=False))
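On the restoring side, the test script should only load the saved values and evaluate them, without initializing or training. In TensorFlow 1.x, this means the test script must build the same variables in the graph before it creates `tf.train.Saver()`, because a Saver with no variables to save raises a `ValueError`. It must also call `saver.restore()` instead of running the initializer. The sketch below shows the same control flow in plain Python with `pickle` as a stand-in, not the TensorFlow API. The checkpoint format and the names `load_checkpoint` and `evaluate` are hypothetical.

```python
import os
import pickle
import tempfile


def load_checkpoint(path):
    """Return the saved parameters (stand-in for saver.restore)."""
    if not os.path.exists(path):
        raise IOError("No checkpoint at %s; run the training script first" % path)
    with open(path, "rb") as f:
        return pickle.load(f)


def evaluate(params, xs, ys):
    """Mean absolute error of the hypothetical model y = w * x + b.

    Only the restored parameters are used: nothing is initialized or trained here.
    """
    errors = [abs(params["w"] * x + params["b"] - y) for x, y in zip(xs, ys)]
    return sum(errors) / len(errors)
```

A test script would call `evaluate(load_checkpoint(path), test_xs, test_ys)` and never call the training function, so running it cannot trigger another round of training.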

0 answers