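I'm new to TensorFlow and am struggling to load my inputs and labels for a model. The inputs are CSV files and the labels are PNG files. More specifically, each CSV file is one training sample and contains exactly one row.
Here are some simplified examples of my CSV files (in the real data, each row has 2^9 elements):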
C1.csv
1,2,3,4,5,6,7,0,0,0
C2.csv
0,1,4,6,9,0,93,24,45,7
I list the paths of the CSV and PNG files together in a TXT file, one sample per line:
../dataset/train.txt
../dataset/chartHis/C1.csv ../dataset/legend/L1.png
../dataset/chartHis/C2.csv ../dataset/legend/L2.png
../dataset/chartHis/C3.csv ../dataset/legend/L3.png
../dataset/chartHis/C4.csv ../dataset/legend/L4.png
../dataset/chartHis/C5.csv ../dataset/legend/L5.png
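To make the file layout concrete, here is a minimal plain-Python sketch of how each line could be split into a (CSV path, PNG path) pair. The helper name `parse_path_list` is hypothetical, and the sketch assumes the two paths are separated by whitespace (a tab or spaces) and that the paths themselves contain no spaces.

```python
def parse_path_list(lines):
    """Split each non-empty line into a (csv_path, png_path) pair."""
    pairs = []
    for line in lines:
        fields = line.split()  # splits on tabs and spaces alike
        if not fields:
            continue  # skip blank lines
        if len(fields) != 2:
            raise ValueError('expected 2 paths per line, got: %r' % line)
        pairs.append((fields[0], fields[1]))
    return pairs


# Sample lines in the same layout as train.txt (one tab, one space separator).
sample_lines = [
    '../dataset/chartHis/C1.csv ../dataset/legend/L1.png',
    '../dataset/chartHis/C2.csv\t../dataset/legend/L2.png',
    '',
]
pairs = parse_path_list(sample_lines)
print(pairs)
```

Splitting on any whitespace tolerates both separators; splitting on '\t' alone would raise an IndexError on a line that uses spaces.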
Here is my code:
import random
import tensorflow as tf

train_file = '../dataset/train.txt'


def data_loader(batch_size=1, file=train_file, resize=None):
    paths = open(file, 'r').read().splitlines()
    random.shuffle(paths)
    image_paths = [p.split('\t')[0] for p in paths]
    label_paths = [p.split('\t')[1] for p in paths]
    # create batch input
    # convert to tensor list
    img_list = tf.convert_to_tensor(image_paths, dtype=tf.string)
    lab_list = tf.convert_to_tensor(label_paths, dtype=tf.string)
    # create data queue
    data_queue = tf.train.slice_input_producer([img_list, lab_list],
                                               shuffle=False, capacity=batch_size * 128, num_epochs=None)
    # decode image
    record_defaults = [[0]] * (1 << 9)
    image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))
    label = tf.image.decode_png(tf.read_file(data_queue[1]), channels=3)
    # resize to define image shape
    if resize is None:
        image = tf.reshape(image, [16, 32, 1])
        label = tf.reshape(label, [40, 1024, 3])
    else:
        image = tf.image.resize_images(image, resize)
        label = tf.image.resize_images(label, resize)
    print(image)
    # convert to float data type
    image = tf.cast(image, dtype=tf.float32)
    label = tf.cast(label, dtype=tf.float32)
    # data pre-processing, normalize
    image = tf.divide(image, tf.constant(255.0))
    label = tf.divide(label, tf.constant(255.0))
    # create batch data
    images, labels = tf.train.shuffle_batch([image, label],
                                            batch_size=batch_size, num_threads=1,
                                            capacity=batch_size * 128, min_after_dequeue=batch_size * 0)
    return {'images': images, 'labels': labels}


# Unit test
if __name__ == '__main__':
    data_dict = data_loader()
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    images = data_dict['images']
    labels = data_dict['labels']
    print images.shape, labels.shape
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    # coordinator for queue runner
    coord = tf.train.Coordinator()
    # start queue
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    batch_im, batch_gt = sess.run([images, labels])
    print batch_im
    print batch_gt
    coord.request_stop()
    coord.join(threads)
However, I get an error when combining tf.decode_csv with tf.read_file in the following line (very similar to this post):
image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))
Following the advice in that post, I switched to combining tf.decode_csv with tf.TextLineReader, as follows:
# create data queue
data_queue = tf.train.slice_input_producer([img_list, lab_list],
                                           shuffle=False, capacity=batch_size * 128, num_epochs=None)
# decode image
record_defaults = [[0]] * (1 << 9)
reader = tf.TextLineReader()
_, line = reader.read(data_queue[0])
image = tf.stack(tf.decode_csv(line, record_defaults=record_defaults))
# image = tf.stack(tf.decode_csv(tf.read_file(data_queue[0]), record_defaults=record_defaults))
label = tf.image.decode_png(tf.read_file(data_queue[1]), channels=3)
However, this raises a different problem, and I could not find a solution for it anywhere online:
Traceback (most recent call last):
File "<input>", line 4, in <module>
File "/home/amax/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/home/amax/linping/DLColor/cnn/data_loader.py", line 64, in <module>
data_dict = data_loader()
File "/home/amax/linping/DLColor/cnn/data_loader.py", line 28, in data_loader
_, line = reader.read(data_queue[0])
File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/ops/io_ops.py", line 164, in read
return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name)
File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 941, in reader_read_v2
queue_handle=queue_handle, name=name)
File "/home/amax/linping/python2/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 533, in _apply_op_helper
(prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'queue_handle' of 'ReaderReadV2' Op has type string that does not match expected type of resource.
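After some experiments, I found that the read() method of TextLineReader works fine when I use tf.train.string_input_producer instead of tf.train.slice_input_producer. However, I believe I need tf.train.slice_input_producer, because the features in the CSV files and the labels in the PNG files must stay paired with each other.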
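To make the pairing requirement concrete, here is a plain-Python analogue (not TensorFlow code) of the behavior I want: both path lists are indexed with one shared shuffled order, so each CSV always travels with its own PNG. The function name `paired_batches` and the `seed` parameter are hypothetical, introduced only for this illustration.

```python
import random


def paired_batches(csv_paths, png_paths, batch_size, seed=None):
    """Yield (csv_batch, png_batch) lists drawn with one shared shuffled order."""
    if len(csv_paths) != len(png_paths):
        raise ValueError('both lists must have the same length')
    order = list(range(len(csv_paths)))
    random.Random(seed).shuffle(order)  # a single permutation for both lists
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [csv_paths[i] for i in idx], [png_paths[i] for i in idx]


# Short stand-in file names; the real entries are the paths from train.txt.
csv_paths = ['C%d.csv' % i for i in range(1, 6)]
png_paths = ['L%d.png' % i for i in range(1, 6)]
batches = list(paired_batches(csv_paths, png_paths, batch_size=2, seed=0))
for csv_batch, png_batch in batches:
    print(csv_batch, png_batch)
```

Two independent queues, one per list, would give no such guarantee once shuffling or multiple reader threads are involved, which is why a single shared index matters here.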
In summary, I'm looking for a way to read the CSV files using tf.decode_csv, tf.read_file, and tf.train.slice_input_producer. That said, any other solution is welcome.
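One alternative I am considering, since every CSV file holds a single row: parse the rows in plain Python before building the graph. The sketch below uses only the standard library; the helper names `parse_single_row` and `to_grid` are hypothetical, and the sample text is synthetic rather than one of my real files. It also checks the arithmetic behind the reshape, since 2^9 = 512 = 16 * 32.

```python
import csv


def parse_single_row(text, expected_len):
    """Parse one-line CSV text into a list of ints, checking its length."""
    # csv.reader accepts any iterable of lines; empty lines come back as [].
    rows = [row for row in csv.reader(text.splitlines()) if row]
    if len(rows) != 1:
        raise ValueError('expected exactly one row, got %d' % len(rows))
    values = [int(field) for field in rows[0]]
    if len(values) != expected_len:
        raise ValueError('expected %d values, got %d'
                         % (expected_len, len(values)))
    return values


def to_grid(values, height, width):
    """Reshape a flat list into `height` rows of `width` values (row-major)."""
    if len(values) != height * width:
        raise ValueError('cannot reshape %d values to %dx%d'
                         % (len(values), height, width))
    return [values[r * width:(r + 1) * width] for r in range(height)]


# Synthetic stand-in for one real file: 2**9 = 512 values on a single line.
text = ','.join(str(i % 256) for i in range(1 << 9)) + '\n'
values = parse_single_row(text, expected_len=1 << 9)
grid = to_grid(values, height=16, width=32)
print(len(grid), len(grid[0]))
```

My assumption, which I have not verified, is that the parsed rows could then be stacked into one numeric array of shape [N, 512] and passed to tf.train.slice_input_producer in place of the CSV path tensor, so that each row is still sliced together with its PNG path.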
Thanks in advance! :D