使用TFRecord时,如何运行中间验证检查? (更好的方法?)

时间:2017-05-30 19:13:57

标签: python tensorflow

我们说我定义了一个网络Net,下面的示例代码运行良好。

# ... input processing using TFRecord ...     # reading from TFRecord
x, y = tf.train.batch([image, label])         # encode batch
net = Net(x,y)                                # connect to network

# ... initialize and session ...
for iteration:
    loss, _ = sess.run([net.loss, net.train_op])

Net没有tf.placeholder,因为输入是由TFRecord提供商的张量提供的。 如果我想运行验证集,例如,每500步,该怎么办?如何切换输入流?

x, y = tf.train.batch([image, label], ...)      # training set
vx, vy = tf.train.batch([vimage, vlabel], ...)  # validation set
net = Net(x,y)

for iteration:
    loss, _ = sess.run([net.loss, net.train_op])

    if step % 500 == 0:
         # graph is already defined from input to loss.
         # how can I run net.loss with vx and vy??

我能想象的只有一件事是,将Net修改为占位符,每次都像

一样运行
sess.run([...], feed_dict = {Net.x:sess.run(x), Net.y:sess.run(y)})
sess.run([...], feed_dict = {Net.x:sess.run(vx), Net.y:sess.run(vy)})

然而,在我看来,我失去了使用TFRecord的好处(例如,完整的TF集成)。在计算流程的中间,我必须停止流程,运行tf.sess,然后继续(通过强制在中间使用CPU来解决这个较低的速度?)

我想知道,

  1. 如果有更好的方法。
  2. 如果我的解决方案没有我想象的那么糟糕。
  3. 提前致谢。

1 个答案:

答案 0 :(得分:1)

有一种更好的方法(比占位符)。我在TensorFlow中使用CIFAR10教程遇到了这个问题,我调整了这个问题,每500个批次左右同时检查测试集的准确性。这是共享变量派上用场的地方。

x, y = tf.train.batch([image, label], ...) # training set
vx, vy = tf.train.batch([vimage, vlabel], ...) # validation set

with tf.variable_scope("model") as scope:
  net = Net(x,y)
  scope.reuse_variables()
  vnet = Net(vx,vy) 

for iteration:
  loss, _ = sess.run([net.loss, net.train_op])

  if step % 500 == 0:
    loss, acc = sess.run([vnet.loss, vnet.accuracy])

通过在第二次调用Net()时将范围设置为重用变量,您将使用在第一次调用中创建但具有不同输入集的相同张量和值。只需确保vimage和vlabel不会重复使用图像和标签中的张量(这可能通过创建自己的变量范围来解决)。