TensorFlow decode_csv形状错误

时间:2018-03-25 08:37:31

标签: python tensorflow tensorflow-datasets

我使用*.csvtf.data.TextLineDataset文件中读取并在其上应用map

dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
                          num_parallel_calls=num_parallel_calls)

解析函数parse_record_fn如下所示:

def parse_record(raw_record, is_training):
    default_record = ["./", -1]
    filename, label = tf.decode_csv([raw_record], default_record)
    # do something
    return image, label

但是在解析函数中ValueError引发了tf.decode_csv

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [1], [], [].

我的*.csv文件示例:

/data/1.png, 5
/data/2.png, 7

问题

  1. 哪里出错?
  2. shapes: [1], [], []是什么意思?
  3. 重现

    此错误可在此代码中重现:

    import tensorflow as tf
    import os
    
    def parse_record(raw_record, is_training):
        default_record = ["./", -1]
        filename, label = tf.decode_csv([raw_record], default_record)
    
        # do something
    
        return image, label
    
    with tf.Session() as sess:
        csv_path = './labels.txt'
    
    
        dataset = tf.data.TextLineDataset(csv_path)
    
        dataset = dataset.map(lambda value: parse_record(value, True))
    
    
    sess.run(dataset)
    

1 个答案:

答案 0 :(得分:3)

查看tf.decode_csv的文档,它说明了默认记录:

  

record_defaults:具有特定类型的Tensor对象列表。   可接受的类型是float32,float64,int32,int64,string。一   输入记录的每列张量,具有标量默认值   该列的值,如果需要该列,则为空。

我相信你得到的错误源于你如何定义张量default_record。您的default_record肯定是张量对象(或可转换为张量的对象)的列表,但我认为错误消息告诉它们应该是1级张量,而不是像你的情况那样的0级张量。

您可以通过将默认记录排名为1张张来解决问题。请参阅以下玩具示例:

import tensorflow as tf

my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now

decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
    d = sess.run(decoded_1)
    print(d)

# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)

最后一行产生的错误很常见:

  

ValueError:Shape必须为1级,但对于' DecodeCSV_1'为0。 (OP:   ' DecodeCSV')输入形状:[],[],[]。

在消息中,输入形状(三个括号[])指的是{{1}的输入参数recordsrecord_defaultsfield_delim的形状}}。在您的情况下,由于您输入tf.decode_csv,因此第一个形状为[1]。我同意这个案子的信息不是很有用......