我试图通过研究https://github.com/tensorflow/models/blob/master/official/resnet/cifar10_main.py上的cifar10官方示例来弄清楚如何使用Tensorflow的数据集模块
要自己构建数据集,我在函数'input_fn'中替换以下代码:
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
替换为:
dataset = creat_dataset()
其中'creat_dataset'定义为:
def creat_dataset():
    """Build a tf.data.Dataset from the five pickled CIFAR-10 training batches.

    Reads ./cifar_10/data_batch_1 through data_batch_5, prepends every
    example's label as the first column of its pixel row, and returns a
    dataset that yields one uint8 row per example.

    Note: the elements are uint8 tensors, not serialized byte strings. A
    downstream parser that feeds them to tf.decode_raw (which requires a
    string input) fails with a TypeError; such a parser must use the
    element directly instead of decoding it.
    """
    def unpickle(path):
        """Load one pickled batch and return (data, labels as a column vector)."""
        # Imported lazily, as in the original; cPickle exists on Python 2 only.
        import cPickle
        # NOTE(review): unpickling can execute arbitrary code; only load
        # batch files that come from a trusted source.
        with open(path, 'rb') as fo:
            batch = cPickle.load(fo)
        labels = batch['labels']
        # Reshape to (N, 1) so the labels can be joined to the data column-wise.
        return batch['data'], np.array(labels).reshape(len(labels), 1)

    prefix = './cifar_10/data_batch_'
    # Load every batch first, then concatenate once instead of growing the
    # arrays on each iteration.
    batches = [unpickle(prefix + str(i)) for i in range(1, 6)]
    data = np.concatenate([batch_data for batch_data, _ in batches], 0)
    label = np.concatenate([batch_label for _, batch_label in batches])
    # Label goes in column 0, followed by the pixel values
    # (presumably 3072 per image -- TODO confirm against the batch files).
    data = np.concatenate((label, data), 1)
    data = tf.constant(data, tf.uint8)
    return tf.data.Dataset.from_tensor_slices(data)
但我得到的错误信息如下:
Traceback (most recent call last):
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 260, in <module>
tf.app.run(argv=[sys.argv[0]] + unparsed)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
_sys.exit(main(argv))
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 244, in main
resnet.resnet_main(FLAGS, cifar10_model_fn, input_function)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 766, in resnet_main
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 352, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 809, in _train_model
input_fn, model_fn_lib.ModeKeys.TRAIN))
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 668, in _get_features_and_labels_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 760, in _call_input_fn
return input_fn(**kwargs)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 764, in input_fn_train
flags.multi_gpu)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 162, in input_fn
examples_per_epoch=num_images, multi_gpu=multi_gpu)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 104, in process_record_dataset
num_parallel_calls=num_parallel_calls)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 792, in map
return ParallelMapDataset(self, map_func, num_parallel_calls)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1628, in __init__
super(ParallelMapDataset, self).__init__(input_dataset, map_func)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1597, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
self._create_definition_if_needed()
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
outputs = self._func(*inputs)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1562, in tf_map_func
ret = map_func(nested_args)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 103, in <lambda>
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 69, in parse_record
record_vector = tf.decode_raw(raw_record, tf.uint8)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_parsing_ops.py", line 195, in decode_raw
little_endian=little_endian, name=name)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 533, in _apply_op_helper
(prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'bytes' of 'DecodeRaw' Op has type uint8 that does not match expected type of string.
有人可以向我解释如何修复此错误吗?
答案 0 :(得分:0)
只需将表达式record_vector = tf.decode_raw(raw_record, tf.uint8)
更改为record_vector = raw_record
即可解决此问题:用from_tensor_slices构建的数据集,其元素已经是uint8张量,而不是decode_raw所要求的字符串(原始字节)。
答案 1 :(得分:0)
我遇到了与您相同的错误,上面的答案很好。因此要明确一点,也许您的代码中存在以下情况:
# Read the next record; the second return value is the serialized example.
# (reader is presumably a TFRecord reader defined elsewhere -- confirm.)
_, serialized_example = reader.read(filename_queue)
# Parse the example against a spec where 'image' is declared as a scalar
# float32 feature (not tf.string) and 'label' as a scalar int64 feature.
img_features = tf.parse_single_example(serialized=serialized_example,
features={
'image':tf.FixedLenFeature([], tf.float32),
'label':tf.FixedLenFeature([], tf.int64)
})
# Disabled on purpose: tf.decode_raw requires a string input, but 'image'
# is declared as float32 above, so decoding it raises the TypeError.
# image = tf.decode_raw(img_features['image'], tf.uint8)
# Use the parsed feature directly instead.
image = img_features['image']
现在看看:
'image':tf.FixedLenFeature([], tf.float32),
您在网上看到的大多数教程写的是:
'image':tf.FixedLenFeature([], tf.string),
在这种情况下,运行下面这行代码是没有问题的:
image = tf.decode_raw(img_features['data'], tf.uint8)
但是,如果tfrecord中该特征在FixedLenFeature里声明的类型本身就是以下数值类型之一:float16, float32, float64, int32, uint16, uint8, int16, int8, int64,
那么就不需要decode_raw(对它调用decode_raw正是导致您遇到该错误的原因),只需直接取值:
image = img_features['image']
即可。
顺便说一句,如果您使用Jupyter Notebook,请记住在修改代码后对内核执行Restart & Clear Output,
然后再从头逐步运行程序。