TensorFlow形状和类型推断

时间:2017-02-07 17:36:30

标签: json tensorflow protocol-buffers

我了解到我可以通过修改来自protobuf的JSON来编程TensorFlow。见here

如果我修改了这个JSON,那么我有时会遇到一个问题,我需要手动编辑JSON以传播各种输入和输出的正确形状。有没有办法让TF为我自动执行此操作,以便我可以通过占位符指定输入,然后自动传播形状和类型?

1 个答案:

答案 0 :(得分:3)

如果您知道要进行哪些类型的修改,则可以从占位符中删除该形状信息。不确定性将自动传播。例如:

import tensorflow as tf
placeholder = tf.placeholder(dtype=tf.float32, shape=[None])
derived = (placeholder / 3)[1:, None]
print(placeholder.get_shape(), derived.get_shape())

打印:

(TensorShape([Dimension(None)]), TensorShape([Dimension(None), Dimension(1)]))

因此,placeholder的长度不会保存静态形状信息。你甚至可以拥有未知等级的张量。

重新计算静态形状是一种诱人的思想,但目前尚不支持,因为图形构造可能依赖于静态形状信息。例如:

placeholder = tf.placeholder(dtype=tf.float32, shape=[2])
if placeholder.get_shape()[0].value % 2 == 0:
    derived = placeholder
else:
    derived = tf.concat(0, [placeholder, [0]])

这不是推荐的图形构造技术(使用tf.shapecond更好),但它确实发生了。遗憾的是,这种静态形状条件图结构未在元图中捕获。