Tensorflow:tfrecord tf.data InvalidArgumentError:预期的图像(JPEG,PNG或GIF),空文件

时间:2018-10-09 21:28:04

标签: python tensorflow

import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

cwd = 'E:\Tensorflow\Wenshan_Cai_Nanoletters\classes\\'
classes = {'cats', 'dogs', 'horses', 'humans'}

def convert_to_tfrecord(filename):

    writer = tf.python_io.TFRecordWriter(filename)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name   

            img = Image.open(img_path)
            img = img.resize((64,64))
            img_raw = img.tobytes()       
            example = tf.train.Example(features = tf.train.Features(feature = {
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
                'img_raw': tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
            }))     
            writer.write(example.SerializeToString())  
    writer.close()

def dataset_input_fn():
    filenames = ['E:\Tensorflow\Wenshan_Cai_Nanoletters\mytrain.tfrecords']
    dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
        keys_to_features = {
            'image_data': tf.FixedLenFeature((), tf.string, default_value  = ''),
            'label': tf.FixedLenFeature((), tf.int64,
                                        default_value= tf.zeros([], dtype = tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        image = tf.image.decode_jpeg(parsed['image_data'])
        image = tf.reshape(image, [64, 64, 3])
        label = tf.cast(parsed['label'], tf.int32)

        return image, label

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size = 10000)    
    dataset = dataset.batch(batch_size = 2)
    dataset = dataset.repeat(1)              
    iterator = dataset.make_one_shot_iterator()  

    images, labels = iterator.get_next()
    return images, labels             # return a tuple

convert_to_tfrecord('mytrain.tfrecords')
output_file = dataset_input_fn()


for image, label in output_file:
    img_data = tf.image.decode_jpeg(image)
    plt.imshow(image)
    plt.show()
    print(label)

回溯(最近通话最近):   在第59行的“ E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py”文件中     output_file =数据集_input_fn()   第55行中的文件“ E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py”,在dataset_input_fn中     图片,标签= iterator.get_next()

tensorflow.python.framework.errors_impl.InvalidArgumentError:预期的图像(JPEG,PNG或GIF),文件为空      [[{{node DecodeJpeg}} = DecodeJpegacceptable_fraction = 1,channels = 0,dct_method =“”,fancy_upscaling = true,ratio = 1,try_recover_truncated = false]] [Op:IteratorGetNextSync]

0 个答案:

没有答案