In TensorFlow

Time: 2019-01-13 15:25:06

Tags: csv tensorflow

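I am new to TensorFlow and am struggling to load the inputs and labels for my model. The inputs are CSV files and the labels are PNG files. More specifically, each CSV file is one training sample, and each file contains exactly one row.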

Below are some simplified examples of my CSV files. (In the real data, each row has 2^9 elements.)

C1.csv

1,2,3,4,5,6,7,0,0,0

C2.csv

0,1,4,6,9,0,93,24,45,7
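For reference, here is a minimal plain-Python sketch of parsing one such single-row file into a grid of integers. It assumes Python 3, and the helper name `parse_single_row_csv` is hypothetical (it does not appear in the post). For the real data, the row would hold 2^9 = 512 values, which matches the 16 x 32 reshape used in the loader.

```python
import csv
import io


def parse_single_row_csv(text, rows, cols):
    """Parse a one-line CSV string into a rows x cols nested list of ints."""
    # csv.reader copes with the trailing newline that a whole-file read keeps.
    reader = csv.reader(io.StringIO(text))
    values = [int(v) for v in next(reader)]
    if len(values) != rows * cols:
        raise ValueError("expected %d values, got %d" % (rows * cols, len(values)))
    # Slice the flat list into `rows` consecutive chunks of length `cols`.
    return [values[r * cols:(r + 1) * cols] for r in range(rows)]


# Contents of C1.csv from the example, viewed as a 2 x 5 grid.
grid = parse_single_row_csv("1,2,3,4,5,6,7,0,0,0\n", rows=2, cols=5)
print(grid)
```

With `rows=16, cols=32` the same helper would produce the layout the loader reshapes to, minus the trailing channel dimension.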

I list the paths of the CSV and PNG files together, one pair per line, in a TXT file.

../dataset/train.txt

../dataset/chartHis/C1.csv      ../dataset/legend/L1.png
../dataset/chartHis/C2.csv      ../dataset/legend/L2.png
../dataset/chartHis/C3.csv      ../dataset/legend/L3.png
../dataset/chartHis/C4.csv      ../dataset/legend/L4.png
../dataset/chartHis/C5.csv      ../dataset/legend/L5.png
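The listing above shows the two paths separated by whitespace, while the loader code splits each line on a tab character. The post does not make clear which separator the real file uses. Here is a small sketch, assuming the paths themselves contain no spaces, that accepts either separator and shuffles each pair as a unit so that index i in both lists still refers to the same sample. The function name `split_pairs` and the `seed` parameter are hypothetical.

```python
import random


def split_pairs(lines, seed=None):
    """Return (csv_paths, png_paths), shuffled together so index i still matches."""
    # str.split() with no argument splits on any run of whitespace (tabs or spaces).
    pairs = [tuple(line.split()) for line in lines if line.strip()]
    for pair in pairs:
        if len(pair) != 2:
            raise ValueError("expected two paths per line, got: %r" % (pair,))
    # Shuffle whole pairs, never the two columns independently.
    random.Random(seed).shuffle(pairs)
    csv_paths = [p[0] for p in pairs]
    png_paths = [p[1] for p in pairs]
    return csv_paths, png_paths


lines = [
    "../dataset/chartHis/C1.csv      ../dataset/legend/L1.png",
    "../dataset/chartHis/C2.csv\t../dataset/legend/L2.png",
]
csv_paths, png_paths = split_pairs(lines, seed=0)
print(list(zip(csv_paths, png_paths)))
```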

Here is my code:

import random

import tensorflow as tf

train_file = '../dataset/train.txt'

def data_loader(batch_size=1, file=train_file, resize=None):

    paths = open(file, 'r').read().splitlines()

    random.shuffle(paths)

    image_paths = [p.split('\t')[0] for p in paths]
    label_paths = [p.split('\t')[1] for p in paths]

    # create batch input
    # convert to tensor list
    img_list = tf.convert_to_tensor(image_paths, dtype=tf.string)
    lab_list = tf.convert_to_tensor(label_paths, dtype=tf.string)

    # create data queue
    data_queue = tf.train.slice_input_producer([img_list, lab_list],
                                               shuffle=False, capacity=batch_size * 128, num_epochs=None)

    # decode image
    record_defaults = [[0]] * (1 << 9)
    image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))
    label = tf.image.decode_png(tf.read_file(data_queue[1]), channels=3)

    # resize to define image shape
    if resize is None:
        image = tf.reshape(image, [16, 32, 1])
        label = tf.reshape(label, [40, 1024, 3])
    else:
        image = tf.image.resize_images(image, resize)
        label = tf.image.resize_images(label, resize)

    print(image)
    # convert to float data type
    image = tf.cast(image, dtype=tf.float32)
    label = tf.cast(label, dtype=tf.float32)

    # data pre-processing, normalize
    image = tf.divide(image, tf.constant(255.0))
    label = tf.divide(label, tf.constant(255.0))


    # create batch data
    images, labels = tf.train.shuffle_batch([image, label],
                                            batch_size=batch_size, num_threads=1,
                                            capacity=batch_size * 128, min_after_dequeue=batch_size * 0)


    return {'images': images, 'labels': labels}


# Unit test
if __name__ == '__main__':
    data_dict = data_loader()

    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    images = data_dict['images']
    labels = data_dict['labels']
    print(images.shape, labels.shape)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer()))

    # coordinator for queue runner
    coord = tf.train.Coordinator()

    # start queue
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    batch_im, batch_gt = sess.run([images, labels]) 
    print(batch_im)
    print(batch_gt)

    coord.request_stop()
    coord.join(threads)

However, I get an error when I combine tf.decode_csv with tf.read_file in the following line (very similar to this post):


image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))

Following the instructions in that post, I combined tf.decode_csv with tf.TextLineReader instead, as follows:

    # create data queue
    data_queue = tf.train.slice_input_producer([img_list, lab_list],
                                               shuffle=False, capacity=batch_size * 128, num_epochs=None)

    # decode image
    record_defaults = [[0]] * (1 << 9)
    reader = tf.TextLineReader()
    _, line = reader.read(data_queue[0])
    image = tf.stack(tf.decode_csv(line, record_defaults=record_defaults))
    # image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))
    label = tf.image.decode_png(tf.read_file(data_queue[1]), channels=3)

However, this caused another problem, and I could not find any solution for it on the Internet.

Traceback (most recent call last):
  File "<input>", line 4, in <module>
  File "/home/amax/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/amax/linping/DLColor/cnn/data_loader.py", line 64, in <module>
    data_dict = data_loader()
  File "/home/amax/linping/DLColor/cnn/data_loader.py", line 28, in data_loader
    _, line = reader.read(data_queue[0])
  File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/ops/io_ops.py", line 164, in read
    return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name)
  File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 941, in reader_read_v2
    queue_handle=queue_handle, name=name)
  File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 533, in _apply_op_helper
    (prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'queue_handle' of 'ReaderReadV2' Op has type string that does not match expected type of resource.

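I ran some experiments and found that the read() method of TextLineReader works well with tf.train.string_input_producer in place of tf.train.slice_input_producer. However, I think I need tf.train.slice_input_producer, because the features in the CSV files and the labels in the PNG files must correspond to each other.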

In summary, I am looking for a way to read the CSV files using tf.decode_csv, tf.read_file and tf.train.slice. Any other solution is welcome too.
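As one example of what "any other solution" might look like, here is a plain-Python sketch that reads each CSV and its PNG in the same loop iteration, so features and labels cannot drift apart. It assumes Python 3 and does no TensorFlow decoding at all: the PNG is returned as raw, still-encoded bytes. The name `iter_samples` is hypothetical, and the demo files are throwaway placeholders, not the real dataset.

```python
import csv
import os
import tempfile


def iter_samples(pairs):
    """Yield (features, png_bytes) for each (csv_path, png_path) pair."""
    for csv_path, png_path in pairs:
        with open(csv_path, newline="") as f:
            # Each CSV holds a single row; read it as a list of ints.
            features = [int(v) for v in next(csv.reader(f))]
        with open(png_path, "rb") as f:
            png_bytes = f.read()  # raw, still-encoded PNG data
        yield features, png_bytes


# Demo with throwaway files (placeholder contents, not a valid image).
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "C1.csv")
png_path = os.path.join(tmp_dir, "L1.png")
with open(csv_path, "w") as f:
    f.write("1,2,3,4,5,6,7,0,0,0\n")
with open(png_path, "wb") as f:
    f.write(b"\x89PNG placeholder bytes")

samples = list(iter_samples([(csv_path, png_path)]))
print(samples[0][0])
```

A generator like this could presumably be wired into a model through placeholders or tf.data.Dataset.from_generator, but that step is not shown or tested here.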

Thanks in advance! :D

0 answers:

No answers