Tensorflow:读取CSV

时间:2017-11-18 21:58:53

标签: tensorflow

filename_queue = tf.train.string_input_producer([csv_file_path], shuffle=False)
reader = tf.TextLineReader()
_, serialized_example = reader.read(filename_queue)
filename = tf.decode_csv(serialized_example, record_defaults=[[""]], field_delim=',')

# Input
png = tf.read_file(filename)

我正在阅读带有一列的CSV文件。 我正在关注 errorValueError: **Shape** must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [1]. 有人可以告诉我这个问题吗?

1 个答案:

答案 0 :(得分:2)

tf.read_file()需要标量输入(即,只有一个字符串),但tf.decode_csv的结果将返回“等级1”上下文,即1-D列表。您需要取消引用结果:

filename = tf.decode_csv(serialized_example, record_defaults=[[""]], field_delim=',')
filename = filename[0]   # <-- add this.
png = tf.read_file(filename)

有关更多详细信息,请参阅tf.decode_csv的文档 - 请注意,返回类型是Tensor对象的列表。