Tensorflow 2:嵌套TensorArray

时间:2019-12-02 17:02:01

标签: tensorflow

此代码有什么问题?编辑:它在CPU上工作,但是在GPU上运行时失败。它会运行几次迭代,然后由于以下错误之一而失败(github issue here):

2019-12-02 12:59:29.727966: F tensorflow/core/framework/tensor_shape.cc:445] Check failed: end <= dims() (1 vs. 0)

Process finished with exit code -1073740791 (0xC0000409)

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Tried to set a tensor with incompatible shape at a list index. Item element shape: [3,3] list shape: [3]
     [[{{node while/body/_1/TensorArrayV2Write/TensorListSetItem}}]] [Op:__inference_computeElement_73]

@tf.function
def computeElement_byBin():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    const = tf.cast(tf.constant([1, 2, 3]), tf.int64)
    c = c.write(0, const)
    c_c = c.concat()
    return c_c

@tf.function
def computeElement():
    c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
    for x in tf.range(50):
        byBinVariant = computeElement_byBin()
        c = c.write(0, byBinVariant)
    return c.concat()

k = 0
while True:
    k += 1
    r = computeElement()
    print('iteration: %s, result: %s' % (k, r))

1 个答案:

答案 0 :(得分:0)

我玩的更多,并且缩小了范围:

@tf.function
def computeElement():
    tarr = tf.TensorArray(tf.int32, size=1,clear_after_read=False)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()

    # PROBLEM HERE
    for x in tf.range(50):
        concat = tarr.concat()

    return concat

如果您设置tf.config.threading.set_inter_op_parallelism_threads(1),该错误将消失,这意味着与展开的tensorflow循环的并行化有关。知道在循环python变量而不是张量时tensorflow会静态展开,我可以确认此代码有效:

@tf.function
def computeElement(arr):
    tarr = tf.TensorArray(tf.int32, size=1)
    tarr = tarr.write(0, [1])
    concat = tarr.concat()

    a = 0
    while a<arr:
        concat = tarr.concat()
        a+=1

    return concat

k = 0
while True:
    k += 1
    r = computeElement(50)

所以目前的解决方案是遍历python变量而不是张量。