如何在tensorflow中读取整个CSV文件作为一个训练示例?

时间:2017-06-07 18:10:08

标签: python csv tensorflow reader

我的数据集包含数百个.csv文件,这些文件具有固定数量的列和可变数量的行。问题是 - 如何将其读入tensorflow?

filename_queue = tf.train.string_input_producer(['file1.csv','file2.csv'])
features_reader = tf.WholeFileReader()
filename, value = features_reader.read(filename_queue)

现在有一些方法可以将值解码为其中的实际数字。有没有办法,或者我应该使用不同的读卡器?

1 个答案:

答案 0 :(得分:1)

所以实际上我通过创建tf.records - tensorflow二进制文件来解决这个问题,并且我认为这通常是一种方法。

虽然处理tf.records的官方文档并不令人满意,但这里有一个很好的解释:http://web.stanford.edu/class/cs20si/lectures/notes_09.pdf

首先需要读取文件并将其转换为二进制格式。在我的情况下,我只是将文件读取到一个numpy数组。

file = your_custom_reader(csv_file)
file = file.tobytes()

现在,在我的情况下,列数是常量,但数据集中的行数变量。这可能很棘手 - 当你阅读二进制文件时,它们会作为张量而没有预定义的形状(在注释的例子中,形状存储在二进制文件中,但这仍然意味着你需要在会​​话中对它进行评估,这使它成为无用于构建模型)。因此,在此步骤中,将张量填充到最大尺寸是有用的。

file = your_custom_reader(csv_file)
file = pad_to_max_size(file)
file = file.tobytes()

写入tf.record很容易。鉴于每个文件都有一个标签y:

writer = tf.python_io.TFRecordWriter(file_name)
example = tf.train.Example(features=tf.train.Features(feature={
    'features': tf.train.Feature(bytes_list=tf.train.BytesList(value=[file])),
    'y'       : tf.train.Feature(bytes_list=tf.train.BytesList(value=[y.tobytes()]))
    }))
writer.write(example.SerializeToString())
writer.close()

现在,二进制文件可以按如下方式加载

tfrecord_file_queue = tf.train.string_input_producer([file_name, file_name_2,...,file_name_N], name='queue')
reader = tf.TFRecordReader()
_, tfrecord_serialized = reader.read(tfrecord_file_queue)
tfrecord_features = tf.parse_single_example(tfrecord_serialized,
                    features={
                        'features': tf.FixedLenFeature([],tf.string),
                        'y'       : tf.FixedLenFeature([],tf.string)                                                   
                                },  name='tf_features')

正如我所说,对于其余的代码,了解张量的形状很重要。我的是SHAPE_1和SHAPE_2

features = tf.decode_raw(tfrecord_features['features'],tf.float32)
features = tf.reshape(audio_features, (SHAPE_1,SHAPE_2))
features.set_shape((SHAPE_1,SHAPE_2))
y = tf.decode_raw(tfrecord_features['y'],tf.float32)

上面我喜欢的斯坦福大学的演讲幻灯片中提供了将代码放入函数的更有条理的示例。我推荐这些幻灯片很多,特别是因为它们提供了更多的解释,缺乏这个答案。不过,我希望它有所帮助!