我正在尝试读取以FloatList
形式存储在tfrecords中的numpy数组。这就是我正在使用的代码
def get_training_dataset():
dataset = load_dataset(TRAINING_FILENAMES)
dataset = dataset.repeat()
dataset = dataset.shuffle(2048)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO)
return dataset
def load_dataset(filenames):
tf_op = tf.data.Options()
tf_op.experimental_deterministic = False
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
print(dataset)
dataset = dataset.with_options(tf_op)
dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
return dataset
def read_tfrecord(example):
tfrec_format = {
"x" : tf.io.FixedLenFeature([], tf.float32),
"y" : tf.io.FixedLenFeature([], tf.float32)
}
print(example)
example = tf.io.parse_single_example(example, tfrec_format)
print(example['y'].shape)
x = tf.reshape(example['x'],[224,224,3])
y = to_categorical(tf.reshape(example['y'],[224,224,2]), num_classes=32)
y = {"output1":y[:,:,0], "output2":y[:,:,1]}
return x,y
print("Training data shapes:")
for image, label in get_training_dataset().take(3):
print(image.numpy().shape, label.numpy().shape)
print("Validation data shapes:")
for image, label in get_validation_dataset().take(3):
print(image.numpy().shape, label.numpy().shape)
印刷品的输出是
训练数据形状:
张量(“ args_0:0”,shape =(),dtype = string) ()
为什么形状显示为空? tfrecords不为空。我也收到以下错误
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-40-53f01a0d2c13> in <module>
1 print("Training data shapes:")
----> 2 for image, label in get_training_dataset().take(3):
3 print(image.numpy().shape, label.numpy().shape)
4 print("Validation data shapes:")
5 for image, label in get_validation_dataset().take(3):
<ipython-input-22-f4dcf023fbcc> in get_training_dataset()
1 def get_training_dataset():
----> 2 dataset = load_dataset(TRAINING_FILENAMES)
3 dataset = dataset.repeat()
4 dataset = dataset.shuffle(2048)
5 dataset = dataset.batch(BATCH_SIZE)
<ipython-input-39-16e1046af942> in load_dataset(filenames)
5 print(dataset)
6 dataset = dataset.with_options(tf_op)
----> 7 dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
8
9 return dataset
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
1626 num_parallel_calls,
1627 deterministic,
-> 1628 preserve_cardinality=True)
1629
1630 def flat_map(self, map_func):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
4018 self._transformation_name(),
4019 dataset=input_dataset,
-> 4020 use_legacy_function=use_legacy_function)
4021 if deterministic is None:
4022 self._deterministic = "default"
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
3219 with tracking.resource_tracker_scope(resource_tracker):
3220 # TODO(b/141462134): Switch to using garbage collection.
-> 3221 self._function = wrapper_fn.get_concrete_function()
3222
3223 if add_to_graph:
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
2530 """
2531 graph_function = self._get_concrete_function_garbage_collected(
-> 2532 *args, **kwargs)
2533 graph_function._garbage_collector.release() # pylint: disable=protected-access
2534 return graph_function
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
2494 args, kwargs = None, None
2495 with self._lock:
-> 2496 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2497 if self.input_signature:
2498 args = self.input_signature
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)
-> 2777 graph_function = self._create_graph_function(args, kwargs)
2778 self._function_cache.primary[cache_key] = graph_function
2779 return graph_function, args, kwargs
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2665 arg_names=arg_names,
2666 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667 capture_by_value=self._capture_by_value),
2668 self._function_attributes,
2669 # Tell the ConcreteFunction to clean up its graph once it goes out of
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
3212 attributes=defun_kwargs)
3213 def wrapper_fn(*args): # pylint: disable=missing-docstring
-> 3214 ret = _wrapper_helper(*args)
3215 ret = structure.to_tensor_list(self._output_structure, ret)
3216 return [ops.convert_to_tensor(t) for t in ret]
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
3154 nested_args = (nested_args,)
3155
-> 3156 ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
3157 # If `func` returns a list of tensors, `nest.flatten()` and
3158 # `ops.convert_to_tensor()` would conspire to attempt to stack
/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
263 except Exception as e: # pylint:disable=broad-except
264 if hasattr(e, 'ag_error_metadata'):
--> 265 raise e.ag_error_metadata.to_exception(e)
266 else:
267 raise
ValueError: in user code:
<ipython-input-37-7bad4da2b4f1>:9 read_tfrecord *
x = tf.reshape(example['x'],[224,224,3])
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:193 reshape **
result = gen_array_ops.reshape(tensor, shape, name)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py:8087 reshape
"Reshape", tensor=tensor, shape=shape, name=name)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:744 _apply_op_helper
attrs=attr_protos, op_def=op_def)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:595 _create_op_internal
compute_device)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3327 _create_op_internal
op_def=op_def)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1817 __init__
control_input_ops, op_def)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1657 _create_c_op
raise ValueError(str(e))
ValueError: Cannot reshape a tensor with 1 elements to shape [224,224,3] (150528 elements) for '{{node Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32](ParseSingleExample/ParseExample/ParseExampleV2, Reshape/shape)' with input shapes: [], [3] and with input tensors computed as partial shapes: input[1] = [224,224,3].
我在做什么错?我是tfrecords的新手,请帮助 这是我正在使用的数据集https://www.kaggle.com/gameatro/colorization-data