如何在Tensorflow中取消堆叠可变大小张量?

时间:2018-04-12 09:45:56

标签: python tensorflow

我通过占位符提供输入数据,如下所示:

input = tf.placeholder(tf.float32, [None, batch_size]) 
inputs = tf.unstack(input, axis=0)

显然这与ValueError: Cannot infer num from shape (?, 32)崩溃了。

如何沿轴0取出input

1 个答案:

答案 0 :(得分:0)

你不能沿着未知的维度拆除这样的张量。文档:

  

通过沿轴切削来从值中解包数量张量   尺寸。如果未指定num(默认值),则从中推断出   价值的形状。如果value.shape [axis]未知,则ValueError为   提高。

您认为这是如何工作的?对于不同的维度值,图表将完全不同。

即使像tf.dynamic_partition这样的解决方法在这种情况下也无济于事,因为它的参数必须是一个整数,在图形构建阶段已经知道了。

如果您知道未知维度的上限,可以尝试:

import tensorflow as tf
import numpy as np

data = np.random.randn(10, 8)

tensor = tf.placeholder(tf.float32, [10, 8])
rets = tf.unstack(tensor, axis=0)


tensor2 = tf.placeholder(tf.float32, [None, 8])


dyn_shape = tf.shape(tensor2)
b_op = dyn_shape[0]
partition = tf.range(b_op)
some_large_number = 20
rets2 = tf.dynamic_partition(tensor2, partition, some_large_number)

with tf.Session() as sess:
    print sess.run(rets, {tensor:data})
    print sess.run(rets2, {tensor2:data})

给出输出:

[array([-0.60014623, -0.0812249 ,  1.5079778 , -0.45486602, -1.3389106 ,
       -1.3552084 ,  2.3415568 ,  0.24747756], dtype=float32), array([-0.2079824 ,  0.33814394,  0.8470432 , -1.3832365 , -0.01087348,
       -0.13608357,  0.89929885, -1.2724507 ], dtype=float32), array([ 0.36865985,  0.45177847, -1.1189924 ,  1.2984366 , -0.67447174,
        2.3120618 ,  0.91252357, -0.13333966], dtype=float32), array([ 1.0067816 ,  1.2311213 ,  0.03433327, -0.09440815,  0.01012954,
       -2.0957463 , -0.49972147, -0.30406335], dtype=float32), array([-0.5904513 ,  0.49920034, -1.5793694 ,  1.3227024 , -0.93950355,
       -0.03706869, -0.1222709 ,  2.0227952 ], dtype=float32), array([-0.06153346, -0.7300583 ,  1.7760276 ,  0.13010012, -1.7523713 ,
       -0.52992773,  1.367956  ,  0.48238465], dtype=float32), array([ 1.2311738 , -0.72093534, -0.28476417, -1.1963955 ,  0.60491234,
        0.35766497, -0.4614565 ,  1.0839593 ], dtype=float32), array([ 1.0952466 , -2.5115075 ,  1.6301945 ,  0.20886853,  0.8650316 ,
       -0.56956375,  0.08775095, -1.4105127 ], dtype=float32), array([ 1.3576531 ,  0.5293029 ,  0.60603464, -0.41250053,  1.0304515 ,
        0.71655554, -1.2762316 , -1.1565298 ], dtype=float32), array([-0.26633576,  1.5087231 , -0.0391343 ,  0.40856156, -0.6008501 ,
        0.3730529 ,  0.28835198,  0.20331612], dtype=float32)]
[array([[-0.60014623, -0.0812249 ,  1.5079778 , -0.45486602, -1.3389106 ,
        -1.3552084 ,  2.3415568 ,  0.24747756]], dtype=float32), array([[-0.2079824 ,  0.33814394,  0.8470432 , -1.3832365 , -0.01087348,
        -0.13608357,  0.89929885, -1.2724507 ]], dtype=float32), array([[ 0.36865985,  0.45177847, -1.1189924 ,  1.2984366 , -0.67447174,
         2.3120618 ,  0.91252357, -0.13333966]], dtype=float32), array([[ 1.0067816 ,  1.2311213 ,  0.03433327, -0.09440815,  0.01012954,
        -2.0957463 , -0.49972147, -0.30406335]], dtype=float32), array([[-0.5904513 ,  0.49920034, -1.5793694 ,  1.3227024 , -0.93950355,
        -0.03706869, -0.1222709 ,  2.0227952 ]], dtype=float32), array([[-0.06153346, -0.7300583 ,  1.7760276 ,  0.13010012, -1.7523713 ,
        -0.52992773,  1.367956  ,  0.48238465]], dtype=float32), array([[ 1.2311738 , -0.72093534, -0.28476417, -1.1963955 ,  0.60491234,
         0.35766497, -0.4614565 ,  1.0839593 ]], dtype=float32), array([[ 1.0952466 , -2.5115075 ,  1.6301945 ,  0.20886853,  0.8650316 ,
        -0.56956375,  0.08775095, -1.4105127 ]], dtype=float32), array([[ 1.3576531 ,  0.5293029 ,  0.60603464, -0.41250053,  1.0304515 ,
         0.71655554, -1.2762316 , -1.1565298 ]], dtype=float32), array([[-0.26633576,  1.5087231 , -0.0391343 ,  0.40856156, -0.6008501 ,
         0.3730529 ,  0.28835198,  0.20331612]], dtype=float32), 
         array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32), array([], shape=(0, 8), dtype=float32)]