Tensorflow tf.split()列表索引超出范围?

时间:2017-04-24 13:05:40

标签: tensorflow

以下是代码:

a = tf.constant([1,2,3,4])
b = tf.constant([4])
c = tf.split(a, tf.squeeze(b))

然后,结果证明是错误的:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jeff/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1203, in split
    num = size_splits_shape.dims[0]
IndexError: list index out of range

但为什么?

1 个答案:

答案 0 :(得分:2)

The docs州,

  

如果num_or_size_splits是张量,size_splits,则将值拆分为len(size_splits)个段。第i个棋子的形状与除了尺寸为size_splits [i]的尺寸轴之外的尺寸相同。

请注意,size_splits需要是可剪切的。

但是当你squeeze(b)时,因为它在你的例子中只有一个元素,它会返回一个没有维度的标量。标量不能被切片:

b_ = tf.squeeze(b)
b_[0] # error

因此你的错误。