我正在尝试在Tensorflow中重写以下Python代码。但是,我在使用tf.map_fn
遍历张量时遇到麻烦。
这里depth
是形状[batch_size,256,256]
的张量,normal
是形状[batch_size,256,256,3]
的张量,scale
是形状{{1}的张量}:
[batch_size,256,256]
我收到一条错误消息:
for b in range(0,batch_size):
depth[b,:,:] = [scale[b,0,0] + (scale[b,0,1] - scale[b,0,0])* x for x in depth[b,:,:]]
normal[b,:,:,:] = [scale[b,0,2] + (scale[b,0,3] - scale[b,0,2])* x for x in normal[b,:,:,:]]
答案 0 :(得分:1)
您正在做的事情可以简单地写成矩阵元素方式的操作:
depth_new = scale[:,0:1,0:1] + (scale[:,0:1,1:2] - scale[:,0:1,0:1]) * depth
normal_new = scale[:,0:1,2:3] + (scale[:,0:1,3:4] - scale[:,0:1,2:3]) * normal
请注意,我们已将范围用于大小为1的切片(例如,0:1
代替0
,或3:4
代替3
)以保留轴,使其可以广播(有关更多信息,请参见TensorFlow broadcasting guide或Numpy广播指南here或here)。