tf.train.shuffle_batch不适合我

时间:2016-09-04 23:47:36

标签: python tensorflow

我正在尝试使用TensorFlow清理方式(tf.train.shuffle_batch)处理我的输入数据,我从教程中收集了大部分代码,稍作修改,如decode_jpeg函数。

size = 32,32
classes = 43
train_size = 12760
batch_size = 100
max_steps = 10000

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
        })
    label = tf.cast(features['image/class/label'], tf.int32)
    reshaped_image = tf.image.decode_jpeg(features['image/encoded'])
    reshaped_image = tf.image.resize_images(reshaped_image, size[0], size[1], method = 0)
    reshaped_image = tf.image.per_image_whitening(reshaped_image)
    return reshaped_image, label

def inputs(train, batch_size, num_epochs):
    subset = "train"
    tf_record_pattern = os.path.join(FLAGS.train_dir + '/GTSRB', '%s-*' % subset)
    data_files = tf.gfile.Glob(tf_record_pattern)
    filename_queue = tf.train.string_input_producer(
        data_files, num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)
    return images, sparse_labels

当我尝试运行时

batch_x, batch_y = inputs(True, 100,100)

我收到以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-543290a0c903> in <module>()
----> 1 batch_x, batch_y = inputs(True, 100,100)

<ipython-input-5-a8c07c7fc263> in inputs(train, batch_size, num_epochs)
     73         capacity=1000 + 3 * batch_size,
     74         # Ensures a minimum amount of shuffling of examples.
---> 75         min_after_dequeue=1000)
     76     #return image, label
     77     return images, sparse_labels

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/input.pyc in shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads, seed, enqueue_many, shapes, allow_smaller_final_batch, shared_name, name)
    800     queue = data_flow_ops.RandomShuffleQueue(
    801         capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
--> 802         dtypes=types, shapes=shapes, shared_name=shared_name)
    803     _enqueue(queue, tensor_list, num_threads, enqueue_many)
    804     full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.pyc in __init__(self, capacity, min_after_dequeue, dtypes, shapes, names, seed, shared_name, name)
    580     """
    581     dtypes = _as_type_list(dtypes)
--> 582     shapes = _as_shape_list(shapes, dtypes)
    583     names = _as_name_list(names, dtypes)
    584     # If shared_name is provided and an op seed was not provided, we must ensure

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.pyc in _as_shape_list(shapes, dtypes, unknown_dim_allowed, unknown_rank_allowed)
     70   if not unknown_dim_allowed:
     71     if any([not shape.is_fully_defined() for shape in shapes]):
---> 72       raise ValueError("All shapes must be fully defined: %s" % shapes)
     73   if not unknown_rank_allowed:
     74     if any([shape.dims is None for shape in shapes]):

ValueError: All shapes must be fully defined: [TensorShape([Dimension(32), Dimension(32), Dimension(None)]), TensorShape([])]

我不确定是什么导致了这个错误,我想这与我正在处理我的图像的方式有关,因为它表明当它们应该有3个通道(RGB)时它们没有尺寸。

1 个答案:

答案 0 :(得分:4)

batching methods in TensorFlowtf.train.batch()tf.train.batch_join()tf.train.shuffle_batch()tf.train.shuffle_batch_join())要求批次中的每个元素都具有完全相同的形状*,所以他们可以被打包成密集的张量。在您的代码中,您传递给image的{​​{1}}张量的第三个维度似乎未知大小。这对应于每个图像中的通道数,对于单色图像为1,对于彩色图像为3,或者对于具有alpha通道的彩色图像为4。如果您传递明确的tf.train.shuffle_batch()(其中channels=N适当地为1,3或4),这将为TensorFlow提供有关图像张量形状的足够信息,以便继续。

*有一个例外:当您将N传递给dynamic_pad=Truetf.train.batch()时,元素可以具有不同的形状,但它们必须具有相同的等级。通常,这仅用于顺序数据,而不是图像数据(在图像边缘会有不良行为)。