如何在TensorArray中获取包含不同形状张量的值

时间:2019-06-27 08:26:52

标签: python tensorflow

我得到了一个TensorArray,其中包含一个通过tf.while_loop()的各种形状张量的列表,但是我不知道如何将它们作为带有张量的普通列表来获取。

例如:

TensorArray([[1,2], [1,2,3], ...]) -> [Tensor([1,2]), Tensor([1,2,3]), ...]
res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
res = res.write(0, (1, 2))
res = res.write(0, (1, 2, 3))
with tf.Session() as sess:                                                        
     print sess.run(res.stack())

我在sess.run(res.stack())中收到错误消息

  

TensorArray的形状不一致。索引0的形状为[2],但索引1的形状为[3]

1 个答案:

答案 0 :(得分:0)

通常,您无法在张量数组中列出张量列表,因为其大小仅在执行图形时才知道。但是,如果您事先知道大小,则可以自己列出读取操作的列表:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    def loop_body(i, res):
        # Must import the following in Python 2:
        # from __future__ import print_function
        with tf.control_dependencies([tf.print(res.read(i))]):
            return i + 1, res
    i, res = tf.while_loop(
        lambda i, res: i < res.size(),
        loop_body,
        (tf.constant(0, tf.int32), res))
    print(sess.run(i))
    # [1 2]
    # [1 2 3]
    # 2

否则,您仍然可以使用while循环来迭代张量数组。例如,您可以这样打印其内容:

new Promise(function (resolve)
{   
    new Promise(function (resolve)
    {   
            $( '#button' ).addClass( '_loading_' );
            resolve(1);
    } ).then(function (value)
    {

        /*
        //very large code
        */

    } );
    resolve(1);

} ).then(function (value)
{
    setTimeout( function()
    {
        $( '#button' ).removeClass( '_loading_' );
    },100 );
} );