如何在Keras / TensorFlow中迭代张量?

时间:2019-08-13 22:39:28

标签: python tensorflow keras

我正在尝试在Tensorflow中重写以下Python代码。但是,我在使用tf.map_fn遍历张量时遇到麻烦。

这里depth是形状[batch_size,256,256]的张量,normal是形状[batch_size,256,256,3]的张量,scale是形状{{1}的张量}:

[batch_size,256,256]

我收到一条错误消息:

for b in range(0,batch_size):
    depth[b,:,:] = [scale[b,0,0] + (scale[b,0,1] - scale[b,0,0])* x for x in depth[b,:,:]]
    normal[b,:,:,:] = [scale[b,0,2] + (scale[b,0,3] - scale[b,0,2])* x for x in normal[b,:,:,:]]

1 个答案:

答案 0 :(得分:1)

您正在做的事情可以简单地写成矩阵元素方式的操作:

depth_new = scale[:,0:1,0:1] + (scale[:,0:1,1:2] - scale[:,0:1,0:1]) * depth
normal_new = scale[:,0:1,2:3] + (scale[:,0:1,3:4] - scale[:,0:1,2:3]) * normal

请注意,我们已将范围用于大小为1的切片(例如,0:1代替0,或3:4代替3)以保留轴,使其可以广播(有关更多信息,请参见TensorFlow broadcasting guide或Numpy广播指南herehere)。