Tensorflow:使用输入队列训练和测试相同的图形

时间:2017-05-24 15:59:55

标签: testing tensorflow dataset training-data

我面临的一个问题无法解决我在互联网上找到的问题。

我已经构建了我的神经网络并将其连接到输入管道。 从tfrecord读取数据,使用tf.train.batch和queueRunners,Coords等。

我已经将我的NN构建到一个名为"模型"的python类中。我用的是:

  

model = Model(...这里所有超参数...)

...

  

model.predict()

  

model.step()

所有培训阶段都很有效。

但现在我想在每个X纪元/训练步骤中添加一个测试阶段。

我真的不知道该怎么做。 我有几点想法,但我找不到最好的想法:

  • 将代码复制到我的类中以获取:loss_train和loss_test,依此类推我的图形的每个节点? (使用火车和测试之间的共享变量)
  • 创建我的模型的2个实例:
  

model_train = Model(reuse = false)

     

model_test = Model(reuse = true)

  • 使用tf.make_template?我真的没有找到这个功能的任何好例子......
  • 任何其他解决方案?

我很感激任何建议,

1 个答案:

答案 0 :(得分:1)

我在尝试TFRecords数据集时遇到了同样的问题。有几种可能性。因为我想在只有一个GPU的计算机上执行此操作,我实现如下:

# Training Dataset
train_dataset = tf.contrib.data.TFRecordDataset(train_files)
train_dataset = train_dataset.map(parse_function)
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.batch(200)
# Validation Dataset
validation_dataset = tf.contrib.data.TFRecordDataset(val_files)
validation_dataset = validation_dataset.map(parse_function)
validation_dataset = validation_dataset.batch(200)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(handle,
 train_dataset.output_types, train_dataset.output_shapes)
next_element = iterator.get_next()

# Generate the Iterators
training_iterator = train_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_one_shot_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

然后,为了访问元素,你可以像:

img, lbl = sess.run(next_element, feed_dict={handle: training_handle})

根据您愿意做的ATM交换句柄。

但请记住,这不可并行化。通过此链接,您可以深入了解创建多个输入管道的不同方法Tensorflow | Reading Data