如何在Tensorflow中使用tf.while_loop访问TensorArray元素?

时间:2017-08-02 13:32:00

标签: python tensorflow

我对使用TensorArray有疑问。

问题:
我想要使​​用tf.while_loop访问TensorArray的元素。请注意,我可以使用例如u1.read(0)来阅读TensorArray的内容。

我目前的代码:
以下是我到目前为止的情况:

embeds_raw = tf.constant(np.array([
    [1, 1],
    [1, 1],
    [2, 2],
    [3, 3],
    [3, 3],
    [3, 3]
], dtype='float32'))
embeds = tf.Variable(initial_value=embeds_raw)
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable')
sen_len = tf.placeholder('int32', shape=[None], name='sen_len')
# max_l = tf.reduce_max(sen_len)
current_size = tf.shape(sen_len)[0]
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT')
added_container_variable = tf.add(container_variable, padded_sen_len)
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False)
u1 = u1.split(embeds, added_container_variable)

sentences = []
i = 0

def condition(_i, _t_array):
    return tf.less(_i, current_size)

def body(_i, _t_array):
    sentences.append(_t_array.read(_i))
    return _i + 1, _t_array

idx, arr = tf.while_loop(condition, body, [i, u1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
    print(sents)

错误消息:

  

Traceback(最近一次调用最后一次):文件   " /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,   第267行, init       fetch,allow_tensor = True,allow_operation = True))File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py",   第2584行,在as_graph_element中       return self._as_graph_element_locked(obj,allow_tensor,allow_operation)文件   " /usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py" ;,   第2673行,在_as_graph_element_locked中       %(type(obj)。 name ,types_str))TypeError:无法将TensorArray转换为Tensor或Operation。

在处理上述异常期间,发生了另一个异常:

  

Traceback(最近一次调用最后一次):文件   " /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py" ;,   第191行,in       main()File" /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py",   第187行,主要       variable_container()File" /home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py",   第179行,在variable_container中       sents = sess.run(arr,feed_dict = {sen_len:[2,1,3]})文件" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session。 PY&#34 ;,   789行,在运行中       run_metadata_ptr)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",   第984行,在_run       self._graph,fetches,feed_dict_string,feed_handles = feed_handles)文件   " /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,   第410行,在 init 中       self._fetch_mapper = _FetchMapper.for_fetch(fetches)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",   第238行,在for_fetch中       return _ElementFetchMapper(fetches,contraction_fn)File" /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",   第271行,在 init 中       %(fetch,type(fetch),str(e)))TypeError:Fetch参数的类型无效,必须是   字符串或张量。 (无法将TensorArray转换为Tensor或   操作)。

1 个答案:

答案 0 :(得分:2)

我没有足够的声誉发表评论,所以我会写一个答案。

我不太明白你的代码打算做什么,但例外是因为sess.run()返回Tensor s,而arrTensorArray。你可以做,例如:

sents = sess.run(arr.concat(), feed_dict={sen_len: [2, 1, 3]})

当然,这只会解除你的分裂。如果你想获得所有的价值,可能:

sents = sess.run([arr.read(i) for i in range(512)], feed_dict={sen_len: [2, 1, 3]})

但我确信必须有比硬编码512更清洁的方法。大概是你的while_loop意味着做某事。