通过tf.TextLineReader和tf.decode_csv从CSV行读取混合数据类型

时间:2017-02-10 22:05:13

标签: python tensorflow

我想阅读的不仅仅是输入CSV数据文件每行的数字要素和标签。我的功能和标签都是数字的,但我也想为每一行使用类似日期或类似字符串的标识符,以便我可以在随机随机批量读取后返回该行。

由于TensorFlow需要以编程方式预先填充的record_defaults数组,如何在一行中创建TF读取混合类型?如果我使用浮点数预填充record_defaults,则包含日期/字符串的单元格会出错。如果我使用字符串预填充record_defaults,以及稍后将数字部分转换为浮点数的视图,我会收到一个错误,即TF不支持字符串到浮点数转换。

文档声明:

具有以下类型的Tensor对象列表:float32,int32,int64,string。输入记录的每列一个张量,具有该列的标量默认值,如果需要该列,则为空

如上所述,我如何指定“空”'如果我只是想要所有列,并跳过整个默认值头痛?在那种情况下,TF将如何知道每个单元格自动选择哪些数据类型?我有很多专栏,因此按照文档中的一个示例逐一编写它是不切实际的:

record_defaults = [[1], [1], [1], [1], [1]]

(我有数百个细胞!)

我的代码看起来像这样:

rDefaults = [[0.02] for row in range((300))]
reader = tf.TextLineReader(skip_header_lines=False)
_, csv_row = reader.read(filename_queue)
data = tf.decode_csv(csv_row, record_defaults=rDefaults)
dateLbl = tf.slice(data, [0], [TD]) # <- this is where I'm having the problem
features = tf.slice(data, [TD], [TS])
label = tf.slice(data, [TS], [TL]) 

感谢您的任何意见!

0 个答案:

没有答案