有没有为tf.decode_csv设置record_defaults的一般方法?

时间:2017-04-28 07:34:20

标签: python tensorflow

我对TensorFlow相当新,因此遇到了一些困难。在完成数据的预处理后,because I don't know how to generate batches,我将数据保存为csv。然后我尝试在tensorflow中读取它,但在解码('tf.decode_csv')时,record_defaults参数是必要的。但是我的数据中有很多列,因此分配record_defaults确实需要时间。那么我们如何设置所有值0(假设我们不知道具体的列数)?

2 个答案:

答案 0 :(得分:0)

我试图找到一种方法来跳过/绕过它,但似乎没有办法忽略记录默认值。

但是,既然您无论如何都必须知道行中CSV的长度才能在Tensorflow中读取它,我到目前为止找到的最简单方法就是用这个单行预填充默认值:

rDefaults = [['a'] * num_cells_in_your_row]

所以我的数据是每行约800列,这样我就不必单独解决它们了。另外,在我的情况下,读入的数据需要是String格式,但您可以将初始值设置为零/等。而不是'万一你需要花车......

***更新:

如上所述,上述内容并不限制您在整个行中数据类型统一的情况。以下是将行中的特定单元格转换为所需数据类型的方法:

rDefaults = [['a'] * num_cells_in_your_row]

def read_from_csv(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=False)
    _, csv_row = reader.read(filename_queue)
    data = tf.decode_csv(csv_row, record_defaults=rDefaults)
    dateLbl = tf.slice(data, [0], [CD]) # portion of the row that is 'String'
    crossLbl = tf.slice(data, [CD], [CC])# Also 'String'
# this part converts the rest of from String to float:
    obs = tf.string_to_number(tf.slice(data, [CD + CC], [SEQLEN]), tf.float32)
    return dateLbl, crossLbl, obs

希望这会有所帮助......

答案 1 :(得分:0)

这是我发现的最优雅的方式:

return avg;