我需要将tf.map_fn之类的smth应用于每个通道以获取一批img (512、32、32、3) 最后的暗淡对应于频道(rgb) 例如
x = tf.map_fn(lambda channel: func(y), x)
其中func(y)
是应用于每个通道矩阵的函数,例如(512,32,32)
有什么办法吗?
或者我可以这样做
for ch in range(3):
cp = tf.copy(x[:,:,:,ch]) #TF does not have copy, it's just pseudo code
cp = tf.reshape(xp, [xp.shape[0], -1])
out = func(cp)
unsq = tf.reshape(out, [x.shape[0], 32, 32])
[:,:,:,ch] = unsq
例如,我需要在每个通道的扁平化图像上应用func。 我完全是tf,所以我不知道该如何完成。
谢谢!
答案 0 :(得分:1)
您可以在调用map_fn
之前转置尺寸:
x = tf.transpose(x, perm=[3,0,1,2]) # shape 3, 512, 32, 32
x = tf.map_fn(lambda channel: func(y), x)