冻结和恢复用tf.train.shuffle_batch训练的张量流模型

时间:2018-06-08 17:23:03

标签: python tensorflow inference tfrecord

我正在学习如何将 tfrecords tf.train.shuffle_batch 一起使用,我还使用了队列和线程。我以这种方式训练我的模型(图像classificator)没有任何问题,因为有很多关于如何做到这一点的文档。我可以保存并恢复提供相同tfrecord管道的模型,但我现在只想在将模型保存到tfrecord后将模型与单个图像(不是.pb)一起使用格式(保护架构信息)。

这就是我的所作所为:

queue_loader=Queue_loader(inputname,batch_size)
model=MyModel()
model.build(inputimages)
model.loss(inputlabels)
train_op = model.train()
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoints) 

在每个纪元后我保存:

saver.save(sess, checkpoint, global_step=epoch)

queue_loader是一个将tf.train.shuffle_batch与tfrecords和队列一起调用的函数。  我不明白的是如何使用我的检查点文件与freeze.py冻结我的模型并在第二次使用它。我想做的是使用与使用占位符和feed_dictionary训练的模型相同的方法,但现在使用tf.train.shuffle_batch时我没有任何占位符,所以我不知道哪个图形的节点作为输入节点保存到我的图形中。事实上,我的模型中有这样的东西:

def build(self,X): conv=tf.nn.relu((tf.nn.conv2d(X, W, strides=[1, 1, 1, 1]) + b))

因此,使用tf.train.shuffle_batch定义的输入管道可以防止使用pkaceholders,因此在使用脚本freeze.py定义我感兴趣的节点时,我不知道输入节点在哪里。

另一个问题是,即使仅使用 freeze.py 保存输出节点,我也不能像以前一样使用我的模型,因为当使用sess.run(feed_dictionary)访问它时,它开始永远等待一些线程被输入。我真的很困惑,你能给我一个提示或示例代码吗?

非常感谢您提前

0 个答案:

没有答案