在张量流中的测试时间循环数据集一次

时间:2016-02-25 19:46:25

标签: tensorflow

为了评估测试数据,对数据集进行单次传递的最佳方法是什么?我想避免在python中使用feed_dict编写数据加载脚本。相反,我想使用所有漂亮的TF基础设施进行排队,批处理等。

在cifar example中,测试示例的数量是硬编码的,代码需要num_test_examples/batch_size步才能进行评估。使用批处理基础架构似乎应该有更好的方法来实现这一点。

似乎标准模式是在捕获队列抛出的异常时停止运行。我已经尝试了一些事情,这样当没有更多的例子来填充队列时,队列会抱怨(即制作人不能再生产)。这不是您想要捕获的异常。当消费者没有任何东西要消耗时,你想要捕获,即队列是空的。我该怎么做?

另外,如果测试示例的数量不能被批量大小整除(例如,测试示例的数量是素数),您会怎么做?

其他信息:

在实践中,我们通常通过调用do_evaluation()函数在学习期间多次评估测试数据。如果您只想处理测试数据一次,Yaroslav的答案很有用。理想情况下,每次调用do_evaluation都会在测试数据集中的每个示例上运行一次。我们需要一些机制来重置批处理器,以便您可以再次单次通过它。这里有一些代码。不要使用limit_epochs命令。它需要一个没有洗牌的批处理器并指定测试集中的批次数(如果设置的示例数量不能被minibatchsize整除,则这不起作用)。该函数返回一个新的操作,用于获取当您在整个集合上运行时将抛出tf.errors.OutOfRangeError的数据。第二个返回值是应该调用以重置计数器的操作。这应该是do_evaluation()函数内的第一个调用。

def single_pass(source_batcher,num_batches):
    zero = tf.constant(0, dtype=tf.int64)
    batch_count = tf.Variable(zero, name="epochs", trainable=False)
    limiter = tf.count_up_to(batch_count,num_batches)
    with tf.control_dependencies([limiter]):
      batcher = tf.identity(source_batcher)

    reset = tf.assign(batch_count, zero)

    return batcher, reset

1 个答案:

答案 0 :(得分:1)

您可以为此使用tf.Data API。像这样

import tensorflow as tf

graph = tf.Graph()
sess = tf.Session(graph=graph)


def build_dataset(train_or_test):
    if train_or_test == 'test':
        dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 2]))
    elif train_or_test == 'train':
        dataset = tf.data.Dataset.from_tensor_slices(tf.ones([10, 2]))
    else:
        raise ValueError('wrong name')
    batch_size = 3
    dataset = dataset.batch(batch_size)
    return dataset


def build_inputs():
    train_dataset = build_dataset('train')
    test_dataset = build_dataset('test')
    iterator = tf.data.Iterator.from_structure(
        train_dataset.output_types,
        train_dataset.output_shapes,)
    data = iterator.get_next()
    tf.identity(data, name='data')
    iterator.make_initializer(train_dataset, name='train_init')
    iterator.make_initializer(test_dataset, name='test_init')


def model(inputs):
    return tf.add(inputs, 1, name='output')


def build_graph():
    with graph.as_default():
        build_inputs()
        data = graph.get_tensor_by_name('data:0')
        model(data)


def train():
    train_init = graph.get_operation_by_name('train_init')
    sess.run(train_init)
    out = graph.get_tensor_by_name('output:0')
    while True:
        try:
            network_out = sess.run(out)
            print(network_out.shape)
            print(network_out)
        except tf.errors.OutOfRangeError:
            break


def test():
    test_init = graph.get_operation_by_name('test_init')
    sess.run(test_init)
    out = graph.get_tensor_by_name('output:0')
    while True:
        try:
            network_out = sess.run(out)
            print(network_out.shape)
            print(network_out)
        except tf.errors.OutOfRangeError:
            break


def train_loop():
    cur_epoch = 0
    while cur_epoch < 1:
        print('Test epoch')
        test()
        print('Train epoch')
        train()
        cur_epoch += 1


def initialise_graph():
    with graph.as_default():
        sess.run(tf.global_variables_initializer())


build_graph()
initialise_graph()
train_loop()

这将输出:

Test epoch
(3, 2)
[[1. 1.]
 [1. 1.]
 [1. 1.]]
(1, 2)
[[1. 1.]]
Train epoch
(3, 2)
[[2. 2.]
 [2. 2.]
 [2. 2.]]
(3, 2)
[[2. 2.]
 [2. 2.]
 [2. 2.]]
(3, 2)
[[2. 2.]
 [2. 2.]
 [2. 2.]]
(1, 2)
[[2. 2.]]