我尝试使用 TensorFlow Dataset API 读取数据。我先将图像文件名和标签文件名加载到数组中,再用这些数组构建数据集。然后我尝试通过映射函数把这些文件名映射为实际的图像文件,但得到的错误似乎表明,映射函数接收到的输入是占位符,而不是实际的张量。
class DatasetReader(object):
    """Build a repeating, batched dataset of decoded image/annotation pairs.

    Each record's file paths are converted to string tensors, sliced into one
    dataset element per record, decoded by ``_input_parser``, then batched and
    repeated indefinitely.
    """

    # Keys read from every record dict; each becomes one tensor in self.records.
    _RECORD_KEYS = ("image", "filename", "annotation")

    def __init__(self, records_list, batch_size=1):
        """Create the dataset pipeline.

        Args:
            records_list: iterable of dicts, each providing 'image',
                'filename' and 'annotation' entries (presumably path strings;
                confirm against the caller).
            batch_size: number of elements per batch.

        Raises:
            ValueError: if records_list contains no records.
        """
        # Materialize once: the original iterated records_list three times,
        # which silently exhausts a generator after the first key.
        records_list = list(records_list)
        if not records_list:
            # An empty list would become a float32 tensor and fail later with a
            # confusing dtype/shape error inside the map function.
            raise ValueError("records_list must contain at least one record")

        self.batch_size = batch_size
        self.records = {
            key: tf.convert_to_tensor([record[key] for record in records_list])
            for key in self._RECORD_KEYS
        }
        # from_tensor_slices slices along dimension 0, so each element handed
        # to _input_parser holds one entry per key.
        self.dataset = Dataset.from_tensor_slices(self.records)
        self.dataset = self.dataset.map(self._input_parser)
        self.dataset = self.dataset.batch(batch_size)
        self.dataset = self.dataset.repeat()

    def _input_parser(self, record):
        """Read and decode the image and annotation files for one element.

        Args:
            record: one dataset element, a dict keyed like ``self.records``.
                tf.read_file requires each value used here to be a rank-0
                (scalar) string tensor.

        Returns:
            The result of ``self._augment_image(image, annotation)``.
        """
        # NOTE(review): the image is read from record['filename'] while
        # record['image'] is never used here; confirm which key holds the
        # image path.
        image = tf.image.decode_image(tf.read_file(record['filename']))
        annotation = tf.image.decode_image(tf.read_file(record['annotation']))
        return self._augment_image(image, annotation)
错误发生在 `image = tf.image.decode_image(tf.read_file(filename))` 这一行。堆栈跟踪如下:
File "FCN.py", line 269, in <module>
tf.app.run()
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "FCN.py", line 179, in main
train_records, valid_records, image_options_train, image_options_val, FLAGS.batch_size, FLAGS.batch_size)
File "/home/ubuntu/FCN.tensorflow/TFReader.py", line 89, in from_records
train_reader = DatasetReader(train_records, train_image_options, train_batch_size)
File "/home/ubuntu/FCN.tensorflow/TFReader.py", line 34, in __init__
self.dataset = self.dataset.map(self._input_parser)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 964, in map
return MapDataset(self, map_func, num_threads, output_buffer_size)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1735, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 449, in add_to_graph
self._create_definition_if_needed()
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/data/python/framework/function.py", line 168, in _create_definition_if_needed
outputs = self._func(*inputs)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1723, in tf_map_func
ret = map_func(nested_args)
File "/home/ubuntu/FCN.tensorflow/TFReader.py", line 42, in _input_parser
image = tf.image.decode_image(tf.read_file(filename))
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 223, in read_file
result = _op_def_lib.apply_op("ReadFile", filename=filename, name=name)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/data/python/framework/function.py", line 80, in create_op
data_types, **kwargs)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 665, in create_op
**kwargs)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2632, in create_op
set_shapes_for_outputs(ret)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1911, in set_shapes_for_outputs
shapes = shape_func(op)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1861, in call_with_requiring
return call_cpp_shape_fn(op, require_shape_fn=True)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 595, in call_cpp_shape_fn
require_shape_fn)
File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 659, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [?].
答案 0(得分:1)
不能将秩为 1(rank-1)的张量传递给 `tf.read_file`:它只接受单个文件名,即 Python 字符串或秩为 0(标量)的字符串张量。以下是一些例子:
import tensorflow as tf
# Examples of which argument forms tf.read_file accepts.
# tf.read_file takes a single filename: a Python string or a rank-0 (scalar)
# string tensor.
# NOTE(review): the two "Wrong" lines are expected to raise while the op is
# being built (the traceback above shows a ValueError for a rank-1 input), so
# running this snippet top to bottom stops at the first of them.
# Correct: input can be a string.
tf.image.decode_image(tf.read_file("filename"))
# Correct: input can be a rank-0 tensor.
tf.image.decode_image(tf.read_file(tf.convert_to_tensor("filename")))
# Wrong: input cannot be a list.
# (Presumably the list is converted to a rank-1 string tensor, so it is
# rejected the same way as the next case.)
tf.image.decode_image(tf.read_file(["filename"]))
# Wrong: input cannot be a rank-1 tensor
tf.image.decode_image(tf.read_file(tf.convert_to_tensor(["filename"])))
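在您的代码中,`self.records["filename"]` 似乎是一个秩为 1 的张量;您可能在 `_input_parser` 中错误地把一个秩为 1 的张量作为参数传给了 `tf.read_file`。由于 `Dataset.from_tensor_slices` 会沿第 0 维切分张量,映射函数收到的 `record['filename']` 本应是标量。错误信息中的输入形状 `[?]` 表明它实际上仍是一维张量,因此请检查传入 `_input_parser` 的数据。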