我有一个要解析的 tfrecords 数据集。
我正在使用以下代码对其进行解析:
image_size = [224, 224]


def read_tfrecord(tf_record, image_dir=""):
    """Parse one serialized TFRecord example into an (image, label) pair.

    The example is parsed with three features: a ``filename`` bytestring,
    a ``fun`` bytestring and a variable-length int64 ``label`` list.

    NOTE(review): the original code passed ``filename`` straight to
    ``decode_jpeg``. The resulting error ("unknown format starting with
    'B2.jpg'") shows that this feature holds the *name* of the image file,
    not its encoded bytes, so the file contents are read with
    ``tf.io.read_file`` first. If the records actually store the encoded
    image under another key, decode that feature instead.

    Args:
        tf_record: A scalar string tensor holding one serialized example.
        image_dir: Prefix prepended to the stored filename to form the path
            passed to ``tf.io.read_file`` (include a trailing path
            separator). Defaults to "", i.e. the stored name is used as-is,
            relative to the working directory.

    Returns:
        A tuple ``(image, label)``: ``image`` is a float32 tensor reshaped to
        ``[*image_size, 3]`` with values scaled into [0, 1]; ``label`` is a
        dense float32 vector.
    """
    features = {
        "filename": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "fun": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(tf_record, features)
    # `filename` is a file name (e.g. b"B2.jpg"), not image data: load the bytes first.
    image_bytes = tf.io.read_file(tf.strings.join([image_dir, example["filename"]]))
    image = tf.image.decode_jpeg(image_bytes, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*image_size, 3])  # explicit size will be needed for TPU
    # VarLenFeature parses to a SparseTensor; densify it so batching and
    # `.numpy()` work downstream.
    label = tf.cast(tf.sparse.to_dense(example["label"]), tf.float32)
    return image, label
def load_dataset(filenames):
    """Build a dataset of parsed records from a collection of TFRecord files.

    The files are read in parallel through ``interleave`` and every record is
    parsed with ``read_tfrecord``. Deterministic ordering is switched off, so
    elements may arrive in any order in exchange for throughput.

    Args:
        filenames: The TFRecord file paths to read.

    Returns:
        A ``tf.data.Dataset`` yielding whatever ``read_tfrecord`` returns.
    """
    unordered = tf.data.Options()
    unordered.experimental_deterministic = False
    return (
        tf.data.Dataset.from_tensor_slices(filenames)
        .with_options(unordered)
        .interleave(tf.data.TFRecordDataset, cycle_length=32, num_parallel_calls=AUTO)  # faster
        .map(read_tfrecord, num_parallel_calls=AUTO)
    )
# Build one parsed dataset per split (train / validation / test).
train_data, val_data, test_data = map(
    load_dataset, (train_filenames, val_filenames, test_filenames)
)
运行此代码后,查看 train_data 得到:
train_data
<DatasetV1Adapter shapes: ((224, 224, 3), (?,)), types: (tf.float32, tf.float32)>
我试图通过以下方式查看数据集中的图像:
def display_9_images_from_dataset(dataset):
    """Plot up to the first nine (image, label) pairs of ``dataset``.

    Each image is titled with ``CLASSES[argmax(label)]`` and drawn through
    ``display_one_flower``; the figure is then laid out and shown.

    Args:
        dataset: A dataset whose elements are ``(image, label)`` pairs.
    """
    position = 331  # presumably a matplotlib subplot code; advanced by display_one_flower
    plt.figure(figsize=(13, 13))
    images, labels = dataset_to_numpy_util(dataset, 9)
    for idx, img in enumerate(images[:9]):
        label_index = np.argmax(labels[idx], axis=-1)
        position = display_one_flower(img, CLASSES[label_index], position)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
def dataset_to_numpy_util(dataset, N):
    """Return the first batch of ``N`` elements of ``dataset`` as NumPy arrays.

    Works both in eager mode and in TF1 graph mode (via a ``tf.Session``).

    Args:
        dataset: A dataset whose elements are ``(images, labels)`` pairs of
            dense tensors.
        N: Batch size, i.e. the maximum number of elements returned.

    Returns:
        A tuple ``(numpy_images, numpy_labels)`` holding the first batch; it
        may contain fewer than ``N`` elements if the dataset is shorter.

    Raises:
        ValueError: In eager mode, if ``dataset`` yields no elements.
    """
    dataset = dataset.batch(N)
    if tf.executing_eagerly():
        # In eager mode, take the first batch straight from the dataset.
        for images, labels in dataset:
            return images.numpy(), labels.numpy()
        # The loop body never ran: the dataset was empty.
        raise ValueError("dataset is empty; nothing to convert")
    # In graph (non-eager) mode, build the op that yields the next item and
    # evaluate it in a tf.Session.
    get_next_item = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        numpy_images, numpy_labels = sess.run(get_next_item)
    return numpy_images, numpy_labels
# Preview the first nine training examples.
display_9_images_from_dataset(dataset=train_data)
但是我得到了错误:
InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with 'B2.jpg'
[[{{node DecodeJpeg}}]]
[[IteratorGetNext_3]]
我有点困惑:错误信息说文件是 jpg 格式,而函数要求的是 jpeg,据我所知两者是相同的。
另外,我也不确定该如何查看这些图像,甚至不确定自己是否正确解析了它们。
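答案 0(得分:0)
“.jpg” 与 “.jpeg” 扩展名对 tf.image.decode_jpeg 没有任何区别。该函数解码的是传入张量中的原始字节,并不查看文件名或扩展名。
错误信息说数据“以 'B2.jpg' 开头”,说明传给 decode_jpeg 的是 "filename" 特征中保存的文件名字符串,而不是 JPEG 图像字节。因此把 .jpg 重命名为 .jpeg 并不能解决问题。
正确做法有两种:如果记录中有存放图像字节的特征,应解码该特征;如果记录只保存了文件名,则先用 tf.io.read_file(文件路径) 读取文件内容,再把结果传给 tf.image.decode_jpeg。
另外,"label" 是 VarLenFeature,解析后得到的是 SparseTensor。需要先用 tf.sparse.to_dense 转成稠密张量,才能调用 .numpy()。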