如何在Tensorflow中为每个通道应用tf.map_fn

时间:2019-01-20 17:38:54

标签: tensorflow

我需要将tf.map_fn之类的smth应用于每个通道以获取一批img (512、32、32、3) 最后的暗淡对应于频道(rgb) 例如

x = tf.map_fn(lambda channel: func(y), x)

其中func(y)是应用于每个通道矩阵的函数,例如(512,32,32)

有什么办法吗?

或者我可以这样做

for ch in range(3):
   cp = tf.copy(x[:,:,:,ch])   #TF does not have copy, it's just pseudo code
   cp = tf.reshape(xp, [xp.shape[0], -1])
   out = func(cp)
   unsq = tf.reshape(out, [x.shape[0], 32, 32])
   [:,:,:,ch] = unsq

例如,我需要在每个通道的扁平化图像上应用func。 我完全是tf,所以我不知道该如何完成。

谢谢!

1 个答案:

答案 0 :(得分:1)

您可以在调用map_fn之前转置尺寸:

x = tf.transpose(x, perm=[3,0,1,2]) # shape 3, 512, 32, 32
x = tf.map_fn(lambda channel: func(y), x)