(tensorflow版本:'0.12.head')
TensorArray.close
的文档说它关闭当前的TensorArray 。对TensorArray
的状态意味着什么?我尝试以下代码
import tensorflow as tf
sess = tf.InteractiveSession()
a1 = tf.TensorArray(tf.int32, 2)
a1.close().run()
a2 = a1.write(0, 0)
a2.close().run()
print(a2.read(0).eval())
并且没有错误。 close
的用法是什么?
Learning-to-learn includes TensorArray.close
in the reset operations of the network。我无法弄清楚评论Empty array as part of the reset process的含义。
更新
例如,
import tensorflow as tf
sess = tf.InteractiveSession()
N = 3
def cond(i, arr):
return i < N
def body(i, arr):
arr = arr.write(i, i)
i += 1
return i, arr
arr = tf.TensorArray(tf.int32, N)
_, result_arr = tf.while_loop(cond, body, [0, arr])
reset = arr.close() # corresponds to https://github.com/deepmind/learning-to-learn/blob/6ee52539e83d0452051fe08699b5d8436442f803/meta.py#L370
NUM_EPOCHS = 3
for _ in range(NUM_EPOCHS):
reset.run() # corresponds to https://github.com/deepmind/learning-to-learn/blob/6ee52539e83d0452051fe08699b5d8436442f803/util.py#L32
print(result_arr.stack().eval())
为什么arr.close()
不会使while循环失败?在每个纪元的开头调用arr.close()有什么好处?
答案 0 :(得分:1)
这是一个包含本机操作系统的Python操作系统,它们都有帮助字符串,但本机操作帮助字符串提供了更多信息。如果您查看inspect.getsourcefile(fx_array.close)
,它会指向tensorflow/python/ops/tensor_array_ops.py
。在实现中,您会看到它遵循_tensor_array_close_v2
。所以你可以这样做
> from tensorflow.python.ops import gen_data_flow_ops
> help(gen_data_flow_ops._tensor_array_close_v2)
Delete the TensorArray from its resource container. This enables
the user to close and release the resource in the middle of a step/run.
同一个文档字符串也位于TensorArrayCloseV2
查看tensorflow/core/kernels/tensor_array_ops.cc您看到TensorArrayCloseOp
是TensorArrayCloseV2
注册的实施,并且有更多信息
// Delete the TensorArray from its resource container. This enables
// the user to close and release the resource in the middle of a step/run.
// TODO(ebrevdo): decide whether closing the grad op should happen
// here or on the python side.
class TensorArrayCloseOp : public OpKernel {
public:
explicit TensorArrayCloseOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
TensorArray* tensor_array;
OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array));
core::ScopedUnref unref(tensor_array);
// Instead of deleting this TA from the ResourceManager, we just
// clear it away and mark it as closed. The remaining memory
// consumed store its mutex and handle Tensor. This will be
// cleared out at the end of the step anyway, so it's fine to keep
// it around until the end of the step. Further calls to the
// TensorArray will fail because TensorArray checks internally to
// see if it is closed or not.
描述似乎与您看到的行为不一致,可能是一个错误。
答案 1 :(得分:0)
学习与学习示例中关闭的TensorArray
不是传递给while循环的原始TensorArray
。
# original array (fx_array) declared here
fx_array = tf.TensorArray(tf.float32, size=len_unroll + 1,
clear_after_read=False)
# new array (fx_array) returned here
_, fx_array, x_final, s_final = tf.while_loop(
cond=lambda t, *_: t < len_unroll,
body=time_step,
loop_vars=(0, fx_array, x, state),
parallel_iterations=1,
swap_memory=True,
name="unroll")
从此处对fx_array.close()
的任何后续调用都会关闭while循环返回的新数组,而不是在第一次迭代中传递给循环的原始数组。
如果您想了解close
的行为方式,请运行:
session.run([reset, loss])
由于TensorArray has already been closed.
操作尝试在已关闭的数组上运行loss
,因此pack()
会失败。