Tensorflow连接未指定形状的张量

时间:2018-04-29 16:47:42

标签: python tensorflow concatenation shape

我试图连接两个张量。不幸的是,一些形状尺寸信息似乎在这个过程中丢失了。

我从一个Tensor(在我的例子中是一个翻译姿势)开始,形状为[Batch,3],例如:pose_t

然后我将这个张量分成两个形状的张量[Batch,2]和[Batch]

centroid = pose_t[:,:2]
tz = pose_t [:,2]

然后我对尺寸变为[Batch,28,28,2]

的质心进行一些处理

最后我想连接处理过的质心张量和tz张量来得到一个pose_t张量的形状[Batch,28,28,3]

因此我将expand_dims()应用于tz三次并将tile轴1和& 2之后我得到一个形状的张量[1,28,28,?]虽然我想要/需要的是形状[?,28,28,1]

不幸的是我认为在质心和tz的pose_t开始分裂期间,一些形状信息丢失了:

第一个维度应该仍然是批量维度但是如果我输出形状,则在未定义之前,批量维度[?,...]设置为1,其中前面定义的最后一个维度现在是未定义的。 / p>

不,我有连接形状张量[?,28,28,2]和[1,28,28,?]的问题,这给我一个错误。

以下完整代码:

# Process Centroid
cent_deltas = utils.compute_cent_deltas_graph(positive_rois, pose_t[:,0:2], config.MASK_SHAPE[0])
# Append tz from pose_t to cent_deltas in correct dimension
# Expand Dimension 3 times and scale each dimension to propper size
tz = pose_t[:,2]
tz = tf.tile(tf.expand_dims(tf.expand_dims(tf.expand_dims(tz,axis=0),axis=1),axis=2),multiples=[1,config.MASK_SHAPE[0], config.MASK_SHAPE[0],1])
pose_t = tf.concat([cent_deltas, tz],axis=3)

非常感谢所有帮助! 感谢

1 个答案:

答案 0 :(得分:0)

毕竟expand_dims,你拥有的形状是[1, 1, 1, batch_size]。平铺此形状会将batch_size保留为最后一个维度。您最后需要expand_dims,首先保持批量维度:

tz = tf.expand_dims(tz, axis=-1)
tz = tf.expand_dims(tz, axis=-1)
tz = tf.expand_dims(tz, axis=-1)
tz = tf.tile(tz, ...)