如何使用嵌套形状的tf.data.Dataset.padded_batch?

时间:2017-11-03 19:33:15

标签: python tensorflow tensorflow-datasets

我正在为每个元素构建一个具有两个形状[张,宽,高,3]和[批次,类]张量的数据集。为简单起见,我们可以说class = 5.

您向dataset.padded_batch(1000,shape)提供哪种形状,以便沿宽度/高度/ 3轴填充图像?

我尝试了以下内容:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])‌​)

每次提出TypeError

The docs州:

  

padded_shapes:tf.TensorShape或tf.int64向量的嵌套结构   类似张量的物体代表各自的形状   每个输入元素的组件应在批处理之前填充。   任何未知的尺寸(例如tf.TensorShape中的tf.Dimension(None))   在一个类似张量的对象中,-1将被填充到每个批次中该维度的最大大小。

相关代码:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)

2 个答案:

答案 0 :(得分:1)

感谢mrry寻找解决方案。事实证明,from_generator中的类型必须与条目中的张量数相匹配。

新代码:

dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)

答案 1 :(得分:0)

TensorShape不接受嵌套列表。 tf.TensorShape([None, None, None, 3, None, 5])TensorShape(None)(请注意没有[])是合法的。

然而,结合这两个张量听起来很奇怪。我不确定你想要完成什么,但我建议你尝试不使用不同尺寸的张量。