import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow.contrib.eager as tfe
# Enable TensorFlow eager execution (TF 1.x contrib API) so ops run immediately
# and their results can be used directly from Python instead of through a
# graph + Session; the script below relies on this to consume dataset output.
tfe.enable_eager_execution()
# Root folder holding one sub-directory of images per class. A raw string
# avoids invalid-escape warnings for "\T", "\W" and "\c"; the trailing
# separator is appended separately because a raw string cannot end in a
# single backslash. The resulting value is identical to the original literal.
cwd = r'E:\Tensorflow\Wenshan_Cai_Nanoletters\classes' + '\\'
# Class names. A tuple (not a set) so the order -- and therefore the integer
# label enumerate() assigns to each class when records are written -- is the
# same on every run instead of depending on string hash randomization.
classes = ('cats', 'dogs', 'horses', 'humans')
def convert_to_tfrecord(filename):
    """Write every image found under ``cwd/<class name>/`` to a TFRecord file.

    Each image is converted to RGB, resized to 64x64 and serialized as a
    ``tf.train.Example`` with two features:

    * ``'img_raw'`` -- the raw pixel bytes from PIL's ``Image.tobytes()``
      (not an encoded JPEG/PNG), i.e. 64 * 64 * 3 uint8 values;
    * ``'label'``   -- the integer index of the image's class.

    Args:
        filename: path of the TFRecord file to write.
    """
    # The context manager closes (and flushes) the writer even if reading
    # an image raises part-way through.
    with tf.python_io.TFRecordWriter(filename) as writer:
        # Sorting makes the class -> label mapping deterministic; iterating a
        # set of strings directly depends on hash randomization.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(cwd, name)
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip sub-directories inside a class folder
                with Image.open(img_path) as img:
                    # Force 3-channel RGB so every record holds exactly
                    # 64*64*3 bytes; grayscale, RGBA or palette images would
                    # otherwise produce a different byte count.
                    img_raw = img.convert('RGB').resize((64, 64)).tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img_raw])),
                }))
                writer.write(example.SerializeToString())
def dataset_input_fn(filenames=None):
    """Build a shuffled, batched (image, label) pipeline over TFRecord files.

    The records are expected in the format written by ``convert_to_tfrecord``:
    an ``'img_raw'`` bytes feature with raw uint8 RGB pixels for a 64x64
    image, and an int64 ``'label'`` feature.

    Args:
        filenames: list of TFRecord paths to read. Defaults to the single
            ``mytrain.tfrecords`` file in the project directory (the path
            this function previously hard-coded).

    Returns:
        A tuple ``(images, labels)`` from a one-shot iterator's ``get_next()``:
        ``images`` is a uint8 tensor of shape ``[batch, 64, 64, 3]`` and
        ``labels`` an int32 tensor of shape ``[batch]``, with batch size up
        to 2. Under eager execution these hold the first batch's values.
    """
    if filenames is None:
        filenames = [r'E:\Tensorflow\Wenshan_Cai_Nanoletters\mytrain.tfrecords']
    dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
        """Decode one serialized tf.train.Example into (image, label)."""
        # Feature keys must match the writer's ('img_raw', not 'image_data'):
        # a missing key silently falls back to the '' default, which is what
        # produced the "Expected image ..., got empty file" error.
        keys_to_features = {
            'img_raw': tf.FixedLenFeature((), tf.string, default_value=''),
            'label': tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        # The payload is raw pixel bytes (PIL Image.tobytes()), not an encoded
        # JPEG, so reinterpret the bytes as uint8 instead of using decode_jpeg.
        image = tf.decode_raw(parsed['img_raw'], tf.uint8)
        image = tf.reshape(image, [64, 64, 3])
        label = tf.cast(parsed['label'], tf.int32)
        return image, label

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size=2)
    dataset = dataset.repeat(1)
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels
if __name__ == '__main__':
    # Write the records to the same absolute path that dataset_input_fn()
    # reads by default; a bare 'mytrain.tfrecords' only matches when the
    # script happens to be launched from the project folder.
    tfrecord_path = r'E:\Tensorflow\Wenshan_Cai_Nanoletters\mytrain.tfrecords'
    convert_to_tfrecord(tfrecord_path)

    # dataset_input_fn() returns ONE batch as an (images, labels) tuple, so
    # iterate over the examples inside that batch, not over the tuple itself.
    images, labels = dataset_input_fn()
    for image, label in zip(images, labels):
        # convert_to_tfrecord stores raw pixel bytes, so there is nothing to
        # JPEG-decode here; show the uint8 array directly.
        plt.imshow(image.numpy())
        plt.show()
        print(label.numpy())
# Traceback (most recent call last): File "E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py", line 59, in <module> output_file = dataset_input_fn() File "E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py", line 55, in dataset_input_fn images, labels = iterator.get_next()
# tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got empty file [[{{node DecodeJpeg}} = DecodeJpeg[acceptable_fraction=1, channels=0, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]] [Op:IteratorGetNextSync]