我对使用TensorArray有疑问。
问题:
我想要使用tf.while_loop
访问TensorArray的元素。请注意,我可以使用例如u1.read(0)
来阅读TensorArray的内容。
我目前的代码:
以下是我到目前为止的情况:
embeds_raw = tf.constant(np.array([
[1, 1],
[1, 1],
[2, 2],
[3, 3],
[3, 3],
[3, 3]
], dtype='float32'))
embeds = tf.Variable(initial_value=embeds_raw)
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable')
sen_len = tf.placeholder('int32', shape=[None], name='sen_len')
# max_l = tf.reduce_max(sen_len)
current_size = tf.shape(sen_len)[0]
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT')
added_container_variable = tf.add(container_variable, padded_sen_len)
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False)
u1 = u1.split(embeds, added_container_variable)
sentences = []
i = 0
def condition(_i, _t_array):
return tf.less(_i, current_size)
def body(_i, _t_array):
sentences.append(_t_array.read(_i))
return _i + 1, _t_array
idx, arr = tf.while_loop(condition, body, [i, u1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
print(sents)
错误消息:
Traceback(最近一次调用最后一次):文件 " /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;, 第267行, init fetch,allow_tensor = True,allow_operation = True))File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", 第2584行,在as_graph_element中 return self._as_graph_element_locked(obj,allow_tensor,allow_operation)文件 " /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py" ;, 第2673行,在_as_graph_element_locked中 %(type(obj)。 name ,types_str))TypeError:无法将TensorArray转换为Tensor或Operation。
在处理上述异常期间,发生了另一个异常:
Traceback(最近一次调用最后一次):文件 " /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py" ;, 第191行,in main()File" /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", 第187行,主要 variable_container()File" /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", 第179行,在variable_container中 sents = sess.run(arr,feed_dict = {sen_len:[2,1,3]})文件" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session。 PY&#34 ;, 789行,在运行中 run_metadata_ptr)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", 第984行,在_run self._graph,fetches,feed_dict_string,feed_handles = feed_handles)文件 " /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;, 第410行,在 init 中 self._fetch_mapper = _FetchMapper.for_fetch(fetches)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", 第238行,在for_fetch中 return _ElementFetchMapper(fetches,contraction_fn)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", 第271行,在 init 中 %(fetch,type(fetch),str(e)))TypeError:Fetch参数的类型无效,必须是 字符串或张量。 (无法将TensorArray转换为Tensor或 操作)。
答案 0 :(得分:2)
我没有足够的声誉发表评论,所以我会写一个答案。
我不太明白你的代码打算做什么,但例外是因为sess.run()返回Tensor
s,而arr
是TensorArray
。你可以做,例如:
sents = sess.run(arr.concat(), feed_dict={sen_len: [2, 1, 3]})
当然,这只会解除你的分裂。如果你想获得所有的价值,可能:
sents = sess.run([arr.read(i) for i in range(512)], feed_dict={sen_len: [2, 1, 3]})
但我确信必须有比硬编码512更清洁的方法。大概是你的while_loop意味着做某事。