我正在关注wide_deep教程,但我很难再现正确阅读CSV的示例。
以下是我生成虚拟CSV的代码:
data = pd.DataFrame({
'y': [1,2,3],
'x1':[4,5,6],
'x2':[7.0,8.0,9.0],
'x3':['ten','eleven','twelve']
})
file_path = 'tmp.csv'
data.to_csv(file_path, index=False, header=False)
这就是CSV的样子:
然后我尝试用以下文件读取文件:
def parse_csv(line):
_CSV_COLUMNS = ['x1','x2','x3','y']
defaults = [[0],[0.0],[''],[0]]
columns = tf.decode_csv(line, record_defaults=defaults)
features = dict(zip(_CSV_COLUMNS, columns))
labels = features.pop('y')
return features, tf.equal(labels, 3)
dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)
iterator = dataset.make_one_shot_iterator()
for i in range(3):
features, labels = iterator.get_next()
for k,v in features.items():
print(k, v.eval())
print('-'*50)
输出如下:
x1 4
x2 8.0
x3 b'twelve'
--------------------------------------------------
<error message: OutOfRangeError (see above for traceback): End of sequence>
为什么不是4, 7.0, 'ten'
?
答案 0 :(得分:2)
您面临的问题是由于v.eval()
将推进所有组件的迭代器。来自(DOCS):
请注意,评估next1,next2或next3中的任何一个将推进所有组件的迭代器。迭代器的典型使用者将在单个表达式中包含所有组件。
获得所需信息的一种方法是:
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
for i in range(3):
for k, v in sess.run(features).items():
print(k, v)
print('-' * 50)
import tensorflow as tf
sess = tf.InteractiveSession()
data = pd.DataFrame({
'y': [1, 2, 3],
'x1': [4, 5, 6],
'x2': [7.0, 8.0, 9.0],
'x3': ['ten', 'eleven', 'twelve']
})
file_path = 'tmp.csv'
data.to_csv(file_path, index=False, header=False)
def parse_csv(line):
_CSV_COLUMNS = ['x1', 'x2', 'x3', 'y']
defaults = [[0], [0.0], [''], [0]]
columns = tf.decode_csv(line, record_defaults=defaults)
features = dict(zip(_CSV_COLUMNS, columns))
labels = features.pop('y')
return features, tf.equal(labels, 3)
dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
for i in range(3):
for k, v in sess.run(features).items():
print(k, v)
print('-' * 50)
x1 4
x2 7.0
x3 b'ten'
--------------------------------------------------
x1 5
x2 8.0
x3 b'eleven'
--------------------------------------------------
x1 6
x2 9.0
x3 b'twelve'