此代码有什么问题?编辑:它在CPU上工作,但是在GPU上运行时失败。它会运行几次迭代,然后由于以下错误之一而失败(github issue here):
2019-12-02 12:59:29.727966: F tensorflow/core/framework/tensor_shape.cc:445] Check failed: end <= dims() (1 vs. 0)
Process finished with exit code -1073740791 (0xC0000409)
或
tensorflow.python.framework.errors_impl.InvalidArgumentError: Tried to set a tensor with incompatible shape at a list index. Item element shape: [3,3] list shape: [3]
[[{{node while/body/_1/TensorArrayV2Write/TensorListSetItem}}]] [Op:__inference_computeElement_73]
@tf.function
def computeElement_byBin():
c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
const = tf.cast(tf.constant([1, 2, 3]), tf.int64)
c = c.write(0, const)
c_c = c.concat()
return c_c
@tf.function
def computeElement():
c = tf.TensorArray(tf.int64, size=1, infer_shape=False, element_shape=(3,))
for x in tf.range(50):
byBinVariant = computeElement_byBin()
c = c.write(0, byBinVariant)
return c.concat()
k = 0
while True:
k += 1
r = computeElement()
print('iteration: %s, result: %s' % (k, r))
答案 0 :(得分:0)
我玩的更多,并且缩小了范围:
@tf.function
def computeElement():
tarr = tf.TensorArray(tf.int32, size=1,clear_after_read=False)
tarr = tarr.write(0, [1])
concat = tarr.concat()
# PROBLEM HERE
for x in tf.range(50):
concat = tarr.concat()
return concat
如果您设置tf.config.threading.set_inter_op_parallelism_threads(1)
,该错误将消失,这意味着与展开的tensorflow循环的并行化有关。知道在循环python变量而不是张量时tensorflow会静态展开,我可以确认此代码有效:
@tf.function
def computeElement(arr):
tarr = tf.TensorArray(tf.int32, size=1)
tarr = tarr.write(0, [1])
concat = tarr.concat()
a = 0
while a<arr:
concat = tarr.concat()
a+=1
return concat
k = 0
while True:
k += 1
r = computeElement(50)
所以目前的解决方案是遍历python变量而不是张量。