I am trying to build a list of tensors and stack them together with a for loop in TensorFlow 2. I wrote a small test example and ran it as follows.
import tensorflow as tf

@tf.function
def test(x):
    tensor_list = []
    for i in tf.range(x):
        tensor_list.append(tf.ones(4) * tf.cast(i, tf.float32))
    return tf.stack(tensor_list)

result = test(5)
print(result)
But I get the following error:
Traceback (most recent call last):
File "test.py", line 10, in <module>
result = test(5)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
*args, **kwds))
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.InaccessibleTensorError: in converted code:
test.py:8 test *
return tf.stack(tensor_list)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/util/dispatch.py:180 wrapper
return target(*args, **kwargs)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1165 stack
return gen_array_ops.pack(values, axis=axis, name=name)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_array_ops.py:6304 pack
"Pack", values=values, axis=axis, name=name)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:793 _apply_op_helper
op_def=op_def)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:544 create_op
inp = self.capture(inp)
/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:603 capture
% (tensor, tensor.graph, self))
InaccessibleTensorError: The tensor 'Tensor("mul:0", shape=(4,), dtype=float32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_8, id=139870442952744); accessed from: FuncGraph(name=test, id=139870626510608).
Does anyone know what I am doing wrong? How can I build a list of tensors and stack them with a for loop in TensorFlow 2?
Answer 0 (score: 2)
Use tf.TensorArray instead of a Python list.
Answer 1 (score: 1)
Looping over the elements of a tensor is usually done with tf.map_fn. Here is a working solution:
import tensorflow as tf
import numpy as np

@tf.function
def test(x):
    # map_fn applies the lambda to every element of x and stacks the
    # results itself, so the output already has shape (len(x), 4).
    return tf.map_fn(lambda inp: tf.ones(4) * tf.cast(inp, tf.float32), x, dtype=tf.float32)

result = test(np.arange(5))
print(result)
However, this requires passing an actual array to test(). Alternatively, you can call tf.range() inside the tf.function to turn a scalar into a tensor.