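TensorFlow: values shift across rows when parsing a CSV with an iterator

Time: 2017-12-30 00:04:50

Tags: python tensorflow

I'm following the wide_deep tutorial, but I can't reproduce its example of reading a CSV correctly.

Here is the code that generates my dummy CSV: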

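import pandas as pd

data = pd.DataFrame({
    'y': [1, 2, 3],
    'x1': [4, 5, 6],
    'x2': [7.0, 8.0, 9.0],
    'x3': ['ten', 'eleven', 'twelve']
}, columns=['x1', 'x2', 'x3', 'y'])  # column order must match _CSV_COLUMNS
file_path = 'tmp.csv'
data.to_csv(file_path, index=False, header=False)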

This is what the CSV looks like:

[Screenshot of tmp.csv not reproduced]
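As a text stand-in for the screenshot, here is a minimal sketch that writes the same three rows with the standard-library csv module instead of pandas. It assumes the column order x1, x2, x3, y that parse_csv expects, and shows the exact text tmp.csv should contain.

```python
import csv
import io

# Same dummy data as the DataFrame, already in x1, x2, x3, y order
# (assumed to match the file the question reads).
rows = [(4, 7.0, "ten", 1), (5, 8.0, "eleven", 2), (6, 9.0, "twelve", 3)]

buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator="\n")
writer.writerows(rows)  # no header row, like header=False in to_csv
csv_text = buffer.getvalue()
print(csv_text, end="")
```

Each line holds one record, so parse_csv should see "4,7.0,ten,1" first, then "5,8.0,eleven,2", then "6,9.0,twelve,3".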

Then I try to read the file with the following code:

def parse_csv(line):
    _CSV_COLUMNS = ['x1','x2','x3','y']
    defaults = [[0],[0.0],[''],[0]]
    columns = tf.decode_csv(line, record_defaults=defaults)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('y')
    return features, tf.equal(labels, 3)
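To make the role of record_defaults concrete, here is a plain-Python model (not TensorFlow code; decode_line and parse_csv_line are hypothetical names) of what parse_csv does to one line: the type of each default picks how the field is parsed, an empty field falls back to the default, and y is popped off as the label. Unlike the real tf.decode_csv, this sketch ignores quoting and returns Python str rather than bytes for the string column.

```python
from typing import Dict, List, Tuple

_CSV_COLUMNS = ["x1", "x2", "x3", "y"]
DEFAULTS: List[List[object]] = [[0], [0.0], [""], [0]]


def decode_line(line: str, record_defaults: List[List[object]]) -> List[object]:
    """Model of record_defaults: default's type = column type, empty field = default."""
    fields = line.split(",")  # simplification: no quoted fields
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    values = []
    for field, (default,) in zip(fields, record_defaults):
        values.append(type(default)(field) if field else default)
    return values


def parse_csv_line(line: str) -> Tuple[Dict[str, object], bool]:
    columns = decode_line(line, DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    label = features.pop("y")
    return features, label == 3  # mirrors tf.equal(labels, 3)


print(parse_csv_line("4,7.0,ten,1"))
```

Parsing is strictly per line, so nothing in parse_csv itself can mix values from different rows; the shift in the output has to come from how the iterator is consumed.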

import tensorflow as tf  # TensorFlow 1.x API

sess = tf.InteractiveSession()  # v.eval() below needs a default session

dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)

iterator = dataset.make_one_shot_iterator()

for i in range(3):
    features, labels = iterator.get_next()
    for k, v in features.items():
        print(k, v.eval())
    print('-' * 50)

The output is:

x1 4
x2 8.0
x3 b'twelve'
--------------------------------------------------
<error message: OutOfRangeError (see above for traceback): End of sequence>

Why isn't it 4, 7.0, 'ten'?

1 answer:

Answer 0 (score: 2)

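The problem you are facing is that every v.eval() call runs the iterator's get_next op, which advances the iterator for all components. In your loop, the three v.eval() calls of the first iteration each pull a new line, so x1 comes from the first row, x2 from the second and x3 from the third; the next v.eval() finds no data left and raises OutOfRangeError. From the docs (where next1, next2 and next3 are the tensors returned by a single iterator.get_next() call):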

  

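"Note that evaluating any of next1, next2, or next3 will advance the iterator for all components. A typical consumer of an iterator will include all components in a single expression."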
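The behaviour described in the docs can be reproduced with a small plain-Python model (not TensorFlow code; ModelIterator and buggy_loop are hypothetical names). Each "component" returned by get_next() is backed by one shared iterator, so evaluating any of them consumes a whole row, exactly as the question's v.eval() calls do.

```python
# Plain-Python model of the question's loop (NOT TensorFlow code).
from typing import Callable, Dict, List, Tuple

ROWS: List[Tuple[int, float, str]] = [
    (4, 7.0, "ten"),
    (5, 8.0, "eleven"),
    (6, 9.0, "twelve"),
]


class ModelIterator:
    """Hypothetical stand-in for a one-shot iterator."""

    def __init__(self, rows: List[Tuple[int, float, str]]) -> None:
        self._it = iter(rows)

    def get_next(self) -> Dict[str, Callable[[], object]]:
        def component(index: int) -> Callable[[], object]:
            def evaluate() -> object:
                row = next(self._it)  # runs the shared "get_next op" once
                return row[index]
            return evaluate

        return {"x1": component(0), "x2": component(1), "x3": component(2)}


def buggy_loop(rows: List[Tuple[int, float, str]]) -> List[Tuple[str, object]]:
    iterator = ModelIterator(rows)
    seen: List[Tuple[str, object]] = []
    try:
        for _ in range(3):
            features = iterator.get_next()
            for name, tensor in features.items():
                seen.append((name, tensor()))  # like v.eval(): one row per call
    except StopIteration:  # plays the role of tf.errors.OutOfRangeError
        seen.append(("error", "End of sequence"))
    return seen


print(buggy_loop(ROWS))
```

This prints [('x1', 4), ('x2', 8.0), ('x3', 'twelve'), ('error', 'End of sequence')], the same pattern as the question's output.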

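One way to get what you want is to call get_next() once, outside the loop, and fetch the whole features dict with a single sess.run() per iteration, so that all components come from the same row:

Code: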

iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

for i in range(3):
    for k, v in sess.run(features).items():
        print(k, v)
    print('-' * 50)
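In the same plain-Python terms (again a model, not TensorFlow code; ModelIterator and run are hypothetical), the fix amounts to building the "next element" handle once and executing it once per iteration for the whole dict, so every component comes from one pull.

```python
# Plain-Python model of the answer's fix (NOT TensorFlow code).
from typing import Callable, Dict, List, Tuple

ROWS: List[Tuple[int, float, str]] = [
    (4, 7.0, "ten"),
    (5, 8.0, "eleven"),
    (6, 9.0, "twelve"),
]


class ModelIterator:
    """Hypothetical stand-in for a one-shot iterator."""

    def __init__(self, rows: List[Tuple[int, float, str]]) -> None:
        self._it = iter(rows)

    def get_next(self) -> Callable[[], Dict[str, object]]:
        # Returns a handle to the "next element" op; nothing is pulled yet.
        def pull() -> Dict[str, object]:
            x1, x2, x3 = next(self._it)  # ONE pull yields every component
            return {"x1": x1, "x2": x2, "x3": x3}

        return pull


def run(fetch: Callable[[], Dict[str, object]]) -> Dict[str, object]:
    """Stand-in for sess.run(features): executes the op once for the whole dict."""
    return fetch()


iterator = ModelIterator(ROWS)
features = iterator.get_next()  # built once, outside the loop
results = []
for _ in range(3):
    row = run(features)
    results.append(row)
    for k, v in row.items():
        print(k, v)
    print("-" * 50)
```

Creating the handle outside the loop also mirrors the TF 1.x advice of not adding new get_next() ops to the graph on every iteration.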

Test code:

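import pandas as pd
import tensorflow as tf  # TensorFlow 1.x API

sess = tf.InteractiveSession()

data = pd.DataFrame({
    'y': [1, 2, 3],
    'x1': [4, 5, 6],
    'x2': [7.0, 8.0, 9.0],
    'x3': ['ten', 'eleven', 'twelve']
}, columns=['x1', 'x2', 'x3', 'y'])  # column order must match _CSV_COLUMNS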
file_path = 'tmp.csv'
data.to_csv(file_path, index=False, header=False)

def parse_csv(line):
    _CSV_COLUMNS = ['x1', 'x2', 'x3', 'y']
    defaults = [[0], [0.0], [''], [0]]
    columns = tf.decode_csv(line, record_defaults=defaults)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('y')
    return features, tf.equal(labels, 3)

dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)

iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

for i in range(3):
    for k, v in sess.run(features).items():
        print(k, v)
    print('-' * 50)

Results:

x1 4
x2 7.0
x3 b'ten'
--------------------------------------------------
x1 5
x2 8.0
x3 b'eleven'
--------------------------------------------------
x1 6
x2 9.0
x3 b'twelve'