来自tf.while_loop输出中的堆叠tensorArray的未知大小

时间:2017-06-02 21:10:30

标签: tensorflow

以下代码使用tf.while_loop(...)来计算动态长度。

    outputs_tensor_array = tf.TensorArray(tf.float32,
                                          size=0,
                                          clear_after_read=False,
                                          infer_shape=False,
                                          dynamic_size = True,
                                          element_shape[self.batch_size, self.size])

    initial_args = [outputs_tensor_array, 0]
    outputs, *_ = tf.while_loop(lambda out, idx, *_ : idx < max_len,
                                func,
                                initial_args + additional_args,
                                parallel_iterations = 32,
                                swap_memory = True)
    outputs = outputs.stack()

我想知道是否可以强制执行大小,或至少使该大小为None以强制执行大小约束并在图表中启用进一步的计算。当前形状为[?, batch, hidden_size]

1 个答案:

答案 0 :(得分:1)

tensor.set_shape将优化静态形状信息并在与当前静态形状信息不兼容时抛出错误(在TensorArray.stack()情况下,它将允许您为第0个维度的静态形状信息设置任何值)。

tf.reshape对于声明/填充形状信息也很有用,尽管它并不完美。如果执行图形时Tensor的大小错误,则只会抛出错误(否则可能会隐藏下游的形状错误)。

更复杂,但您也可以set_shape获取静态形状信息,然后使用tf.Asserttf.shape来检查执行图形时的Tensor形状。