TensorFlow dataset structure

Time: 2018-03-26 08:35:40

Tags: python tensorflow tensorflow-datasets

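I am trying to figure out how to use TensorFlow's dataset module by studying the official cifar10 example (https://github.com/tensorflow/models/blob/master/official/resnet/cifar10_main.py).

To build the dataset myself, I replaced the following code in the function 'input_fn':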

filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

with

dataset = creat_dataset()

where 'creat_dataset' is defined as:

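import cPickle  # Python 2 module
import numpy as np
import tensorflow as tf

def creat_dataset():
    def unpickle(file):
        # Each CIFAR-10 python batch is a pickled dict with 'data' (N x 3072 uint8)
        # and 'labels' (a list of N ints)
        with open(file, 'rb') as fo:
            batch = cPickle.load(fo)
        ll = batch['labels']
        return batch['data'], np.array(ll).reshape(len(ll), 1)

    dir = './cifar_10/data_batch_'
    data = None
    label = None
    for i in range(1, 6):
        data_, label_ = unpickle(dir + str(i))
        if data is None:
            data, label = data_, label_
        else:
            data = np.concatenate((data, data_), 0)
            label = np.concatenate((label, label_))

    # Prepend the label column: each row is 1 label byte followed by 3072 pixel bytes
    data = np.concatenate((label, data), 1)
    data = tf.constant(data, tf.uint8)

    # Each element of this dataset is a uint8 vector of shape (3073,), not a string
    dataset = tf.data.Dataset.from_tensor_slices(data)
    return dataset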
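To see what creat_dataset actually produces, the sketch below repeats its array manipulation on small synthetic batches (the batch sizes and values are made up and stand in for the real pickled CIFAR-10 files). Prepending the (N, 1) label column to the (N, 3072) pixel array gives an (N, 3073) array, and because tf.data.Dataset.from_tensor_slices slices along the first axis, each dataset element is one such row: a uint8 vector rather than a byte string.

```python
import numpy as np

# Two hypothetical mini-batches standing in for unpickled CIFAR-10 batches:
# 'data' has shape (N, 3072) and dtype uint8, 'labels' is a list of N ints.
batches = [
    {'data': np.zeros((4, 3072), dtype=np.uint8), 'labels': [0, 1, 2, 3]},
    {'data': np.ones((4, 3072), dtype=np.uint8), 'labels': [4, 5, 6, 7]},
]

data = np.concatenate([b['data'] for b in batches], axis=0)
label = np.concatenate(
    [np.array(b['labels']).reshape(-1, 1) for b in batches], axis=0)

# Prepend the label column as creat_dataset does, then cast like tf.constant(..., tf.uint8).
records = np.concatenate((label, data), axis=1).astype(np.uint8)

# from_tensor_slices would yield records[0], records[1], ... one row at a time.
first_element = records[0]
print(records.shape)        # (8, 3073)
print(first_element.dtype)  # uint8
```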

But I get the following error message:

Traceback (most recent call last):
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 260, in <module>
    tf.app.run(argv=[sys.argv[0]] + unparsed)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 244, in main
    resnet.resnet_main(FLAGS, cifar10_model_fn, input_function)
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 766, in resnet_main
    classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 352, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 809, in _train_model
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 668, in _get_features_and_labels_from_input_fn
    result = self._call_input_fn(input_fn, mode)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 760, in _call_input_fn
    return input_fn(**kwargs)
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 764, in input_fn_train
    flags.multi_gpu)
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 162, in input_fn
    examples_per_epoch=num_images, multi_gpu=multi_gpu)
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 104, in process_record_dataset
    num_parallel_calls=num_parallel_calls)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 792, in map
    return ParallelMapDataset(self, map_func, num_parallel_calls)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1628, in __init__
    super(ParallelMapDataset, self).__init__(input_dataset, map_func)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1597, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
    self._create_definition_if_needed()
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1562, in tf_map_func
    ret = map_func(nested_args)
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 103, in <lambda>
    dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
  File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 69, in parse_record
    record_vector = tf.decode_raw(raw_record, tf.uint8)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_parsing_ops.py", line 195, in decode_raw
    little_endian=little_endian, name=name)
  File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 533, in _apply_op_helper
    (prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'bytes' of 'DecodeRaw' Op has type uint8 that does not match expected type of string.

Can someone explain to me how to fix this error?
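The error message points at the cause: tf.decode_raw expects a string tensor. In the original pipeline, FixedLengthRecordDataset yields each record as a scalar string of _RECORD_BYTES raw bytes, which decode_raw turns into a uint8 vector, whereas creat_dataset yields that uint8 vector directly. The NumPy sketch below mirrors this, using np.frombuffer as a stand-in for tf.decode_raw(..., tf.uint8); the record contents are made up, and the 3073-byte size assumes the CIFAR-10 layout of one label byte plus 32*32*3 pixel bytes.

```python
import numpy as np

RECORD_BYTES = 1 + 32 * 32 * 3  # assumed: 1 label byte + 3072 pixel bytes

# A made-up record as FixedLengthRecordDataset would yield it: raw bytes.
raw_record = bytes([7]) + bytes(RECORD_BYTES - 1)

# np.frombuffer stands in for tf.decode_raw(raw_record, tf.uint8).
decoded = np.frombuffer(raw_record, dtype=np.uint8)

# The same record as one row of creat_dataset's uint8 array: nothing to decode.
row = np.zeros(RECORD_BYTES, dtype=np.uint8)
row[0] = 7

print(decoded.shape, decoded.dtype)   # (3073,) uint8
print(np.array_equal(decoded, row))   # True
print(row.tobytes() == raw_record)    # True: tobytes() goes back to the raw form
```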

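2 answers:

Answer 0 (score: 0)

Simply changing the expression record_vector = tf.decode_raw(raw_record, tf.uint8) to record_vector = raw_record solves the problem. The elements of a dataset built with tf.data.Dataset.from_tensor_slices from a uint8 tensor are already uint8 vectors, not raw byte strings like the records from FixedLengthRecordDataset, so there is nothing left for decode_raw to decode (it only accepts string input).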
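After that change, the rest of the parsing only has to slice and reshape the uint8 vector. The NumPy sketch below mirrors those steps on a single row; parse_row is a hypothetical name, and the layout (first byte is the label, the remaining 3072 bytes are a 32x32 image stored channel-major, i.e. all red values, then all green, then all blue) is the standard CIFAR-10 one, assumed here to match what the example's parse_record expects. The corresponding TensorFlow operations are tf.one_hot, tf.reshape and tf.transpose.

```python
import numpy as np

HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES = 32, 32, 3, 10

def parse_row(record_vector):
    """Hypothetical NumPy mirror of parse_record, given an already-decoded uint8 row."""
    # First byte: the label, converted to a one-hot vector.
    label = np.eye(NUM_CLASSES, dtype=np.float32)[int(record_vector[0])]
    # Remaining bytes: [depth * height * width] -> [depth, height, width].
    depth_major = record_vector[1:].reshape(NUM_CHANNELS, HEIGHT, WIDTH)
    # [depth, height, width] -> [height, width, depth], as float32.
    image = depth_major.transpose(1, 2, 0).astype(np.float32)
    return image, label

# A made-up row: label 3 followed by 3072 pixel bytes.
row = (np.arange(1 + HEIGHT * WIDTH * NUM_CHANNELS) % 256).astype(np.uint8)
row[0] = 3
image, label = parse_row(row)
print(image.shape, label.argmax())  # (32, 32, 3) 3
```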

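Answer 1 (score: 0)

I ran into the same error as you, and the answer above is a good one. To make it clearer: perhaps your code contains something like the following: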

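import tensorflow as tf  # TF 1.x queue-based input pipeline

# Hypothetical file name, for illustration only
filename_queue = tf.train.string_input_producer(['train.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
img_features = tf.parse_single_example(serialized=serialized_example,
                                       features={
                                           'image':tf.FixedLenFeature([], tf.float32),
                                           'label':tf.FixedLenFeature([], tf.int64)
                                       })
# image = tf.decode_raw(img_features['image'], tf.uint8)  # fails: 'image' is already float32
image = img_features['image']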

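Now look at this line:
'image':tf.FixedLenFeature([], tf.float32),
Most tutorials you will find online use this instead:
'image':tf.FixedLenFeature([], tf.string),
and then the following line works fine:
image = tf.decode_raw(img_features['image'], tf.uint8)
However, tf.decode_raw only accepts a string tensor as input; float16, float32, float64, int32, uint16, uint8, int16, int8 and int64 are the output types it can produce (in TF 1.x). A tf.train.Example feature can only be parsed as tf.string, tf.float32 or tf.int64, so when the FixedLenFeature is declared as tf.float32 or tf.int64 the value is already numeric, and calling decode_raw on it causes exactly this error. In that case, just use:
image = img_features['image']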
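The difference between the two kinds of feature can be mirrored in NumPy: a tf.string feature holds the serialized bytes of the array, which must be decoded with the matching dtype (the job tf.decode_raw does), while a tf.float32 or tf.int64 feature already holds the numbers themselves. In the sketch below the 4x4 image is made up, np.frombuffer stands in for tf.decode_raw, and the Python list of floats stands in for a FloatList feature. Note also that a FixedLenFeature's shape must match the number of stored values, so a 16-value image would need shape [16] rather than [].

```python
import numpy as np

image = np.arange(16, dtype=np.float32).reshape(4, 4)  # made-up 4x4 "image"

# Like a tf.string (BytesList) feature: the raw bytes of the array.
as_bytes = image.tobytes()
# Decoding is required, and the dtype must match the stored one
# (np.frombuffer stands in for tf.decode_raw here).
decoded = np.frombuffer(as_bytes, dtype=np.float32).reshape(4, 4)
# Decoding with the wrong dtype silently gives the wrong number of values.
wrong = np.frombuffer(as_bytes, dtype=np.uint8)

# Like a tf.float32 (FloatList) feature: the values themselves, no decoding step.
as_floats = image.ravel().tolist()
restored = np.asarray(as_floats, dtype=np.float32).reshape(4, 4)

print(np.array_equal(decoded, image), np.array_equal(restored, image))  # True True
print(wrong.size)  # 64: 16 float32 values are 64 bytes
```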
By the way, if you use Jupyter Notebook, remember to run Restart & Clear Output on the kernel after modifying the code, and then run the program again step by step.