我在TensorFlow的csv阅读器中遗漏了什么?

时间:2015-11-13 05:27:40

标签: tensorflow

它主要是网站上教程的复制粘贴。我收到一个错误:

  

无效参数:ConcatOp:预期的连接维度   范围[0,0],但得到0 [[节点:concat = Concat [N = 4,T = DT_INT32,   _device =“/ job:localhost / replica:0 / task:0 / cpu:0”](concat / concat_dim,DecodeCSV,DecodeCSV:1,DecodeCSV:2,DecodeCSV:3)]]

我的csv文件的内容是:

  

3,4,1,8,4

 import tensorflow as tf


filename_queue = tf.train.string_input_producer(["test2.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
# print tf.shape(col1)

features = tf.concat(0, [col1, col2, col3, col4])
with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(1200):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])

  coord.request_stop()
  coord.join(threads)

2 个答案:

答案 0 :(得分:13)

由于程序中张量的形状,问题就出现了。 TL; DR 您应该使用tf.concat()代替tf.pack(),这会将四个标量col张量转换为1-D长度为4的张量。

在开始之前,请注意您可以在任何get_shape()对象上使用Tensor方法来获取有关该张量的静态形状信息。例如,代码中的注释掉的行可以是:

print col1.get_shape()
# ==> 'TensorShape([])' - i.e. `col1` is a scalar.

value返回的reader.read()张量是一个标量字符串。 tf.decode_csv(value, record_defaults=[...])record_defaults的每个元素生成与value形状相同的张量,即本例中的标量。标量是具有单个元素的0维张量。标量没有定义tf.concat(i, xs):它将N维张量列表(xs)连接到一个新的N维张量,沿着维i,其中0 <= i < N,如果i,则无效N = 0

tf.pack(xs)运算符旨在简单地解决此问题。它需要一个k N维张量列表(具有相同的形状)并将它们打包成第0维中大小为k的N + 1维张量。如果您将tf.concat()替换为tf.pack(),您的程序将会正常运行:

# features = tf.concat(0, [col1, col2, col3, col4])
features = tf.pack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  # ...

答案 1 :(得分:1)

我也坚持这个tutorial。当我更改with tf.Session() for:

时,我能够将另一个问题换成另一个问题
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

for i in range(2):
    #print i
    example, label = sess.run([features, col5])

coord.request_stop()
coord.join(threads)

sess.close()

错误消失了,TF开始运行,但看起来它被卡住了。如果取消注释# print,您将看到只运行一次迭代。最有可能这不是真的有用(因为我交换了一个错误无限执行)。