如何将CSV文件转换为TensorFlow的TFRecord格式?

时间:2016-10-03 14:33:55

标签: python tensorflow

大家好,我需要将CSV文件转换为TensorFlow的TFRecord格式,非常感谢大家的帮助。我需要转换的CSV文件示例如下:

Col1 Col2 Col3 Col4 目标

2.56 0.98 0.45 7.8 0.189

3.10 5.78 4.78 9.0 0.78

...

非常感谢!!!

1 个答案:

答案 0 :(得分:2)

以下代码可以从多个CSV文件创建一个TFRecords文件……但我目前还没能把数据读回来。

import pandas as pd
import numpy as np
import os
import tensorflow as tf
from tqdm import tqdm


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def _float_feature(value):
    """Wrap one float or an iterable of floats in a ``tf.train.Feature``.

    The sibling helpers in this file accept a single scalar, whereas this one
    originally required an iterable. Both forms are accepted here so the three
    helpers can be called the same way; iterables behave exactly as before.

    Args:
        value: A single number, or an iterable of numbers (list, tuple,
            1-D NumPy array, ...).

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the value(s).
    """
    try:
        values = list(value)
    except TypeError:
        # Not iterable, so treat it as a single scalar.
        values = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def make_q_list(filepathlist, filetype):
    """Collect data files from a list of directories, with one label per file.

    Each directory is scanned non-recursively. Every entry whose name ends
    with ``filetype`` is kept, and its label is the base name of the directory
    it was found in.

    Args:
        filepathlist: Iterable of directory paths to scan.
        filetype: Filename suffix to keep, e.g. ``'.csv'``.

    Returns:
        A tuple ``(filepaths, labels)`` of two parallel lists. Within each
        directory the files appear in sorted name order.
    """
    filepaths = []
    labels = []
    for path in filepathlist:
        # The label depends only on the directory, so compute it once.
        label = os.path.basename(os.path.normpath(path))
        # os.listdir returns entries in arbitrary order; sort so the result
        # is reproducible across runs and platforms.
        for entry in sorted(os.listdir(path)):
            if entry.endswith(filetype):
                filepaths.append(os.path.join(path, entry))
                labels.append(label)

    return filepaths, labels

def tables_to_TF(queue_list, tf_filename, file_type='csv'):
    """Serialize tabular files into a single TFRecords file.

    Every row of every input table becomes one ``tf.train.Example`` with two
    float-list features: ``"features"`` (all columns except the last) and
    ``"label"`` (the last column). The target variable therefore needs to be
    the last column of the data.

    Args:
        queue_list: Iterable of paths to the input tables.
        tf_filename: Path of the TFRecords file to write.
        file_type: ``'csv'`` (read with ``pd.read_csv``) or ``'hdf'`` (read
            with ``pd.read_hdf``).

    Returns:
        None. For an unsupported ``file_type`` a message is printed and
        nothing is written.
    """
    readers = {'csv': pd.read_csv, 'hdf': pd.read_hdf}
    read_table = readers.get(file_type)
    if read_table is None:
        # Check before opening the writer so an unsupported type does not
        # leave an empty output file behind.
        print(file_type, 'is not supported at this time...')
        return

    print('Writing', tf_filename)
    # The context manager closes the writer even if a table fails to load;
    # without closing, buffered records may never reach disk.
    with tf.python_io.TFRecordWriter(tf_filename) as writer:
        for table_path in tqdm(queue_list):
            data = read_table(table_path).values
            for row in data:
                # Row layout: feature 1 ... feature n, label
                features, label = row[:-1], row[-1]
                example = tf.train.Example()
                example.features.feature["features"].float_list.value.extend(features)
                example.features.feature["label"].float_list.value.append(label)
                writer.write(example.SerializeToString())

# Generate demo data: ten CSV files, each 100 rows x 50 uniformly random columns.
data_dir = './Data'
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(data_dir, exist_ok=True)
for i in range(10):
    filename = os.path.join(data_dir, 'random_csv' + str(i) + '.csv')
    frame = pd.DataFrame(np.random.uniform(0, 100, size=(100, 50)))
    # index=False keeps the row index out of the file; otherwise read_csv
    # reads it back as an extra leading column that ends up in the features.
    frame.to_csv(filename, index=False)

# Convert every generated CSV into one TFRecords file.
filepathlist = [data_dir]
q, _ = make_q_list(filepathlist, '.csv')
tffilename = 'Demo_TFR.tfrecords'
tables_to_TF(q, tffilename, file_type='csv')

03/18/2018编辑:删除了冗余的代码行data_file = data_file