如何在U-net内的张量(Tensorflow 2.0)中翻转/旋转特定特征图?函数tf.tensor_scatter_nd_update

时间:2019-10-28 19:11:38

标签: python tensorflow keras

我有一个在Tensorflow 2.0(和tf.keras)中实现的简单Dense-U-net来进行语义分割。在网络中的某些点,我想获取张量并仅翻转/旋转一些特征图。具体而言,初始张量的形状为[α,256,256,5](NHWC)。当前,批量大小为N = 8。

  

为了澄清:请注意,我的目的不是进行增强,而是为了   分析中间特征图遭受的网络行为   轻微的方向变化。研究了类似的想法,或者   翻转/旋转内核或特征图(Gao et al.2017,   高效且不变的CNN用于密集预测,arXiv:1711.09064)。

我创建了一个简单的图层来执行此操作(因为尝试通过某个函数执行此操作会带来问题;请参见类似的问题here)。

功能/层应执行以下操作:

  • 输入参数:张量(NHWC),C的值(代码中的 growth_rate )和要翻转/旋转的特征图的百分比( rotperc ) 。 完成!
  • 我随机选择了所需数量的特征图。 完成!
  • 我翻转/旋转那些特征图。 完成!
  • 我使用翻转/旋转的特征图更新初始张量。 解决!

这是代码:

class MyPartialRotationFlipLayer(layers.Layer):
    def __init__(self, growth_rate, rotperc=0.2):
      super(MyPartialRotationFlipLayer, self).__init__()
      self.nCh = int(tf.multiply(rotperc, growth_rate))
      if self.nCh < 1:
        self.nCh = 1

    def call(self, input):
      ixCh = tf.random.shuffle(tf.range(0, input.shape[-1]))
      ixCh = ixCh[0:self.nCh]           # Randomly select some channel indices
      selChn = tf.gather(input, ixCh, axis=-1) # Get the tensors for those channels (4D tensor)
      selChn = tf.image.random_flip_left_right(selChn) 
      selChn = tf.image.rot90(selChn, k=tf.random.uniform(shape=[], minval=1, 
                                                          maxval=4, 
                                                          dtype=tf.int32))

      # From this point onwards, it is wrong:
      indGrid = tf.stack(tf.meshgrid(tf.range(input.shape[0]), 
                                     tf.range(input.shape[1]), 
                                     tf.range(input.shape[2]), 
                                     ixCh,
                                     indexing='ij'))
      return tf.tensor_scatter_nd_update(input, indGrid, selChn)

例如,使用默认的rotperc = 0.2,我仅翻转/旋转一个通道。假设ixCh = 3(提取和修改的通道)。因此, selChn 是形状= [?, 256,256,1]的张量,而 input 具有形状= [?, 256,256,5]。现在,我需要使用 selChn 更新 input 中的通道ixCh = 3。

我的问题来自函数 tf.tensor_scatter_nd_update 。我的第一个想法是创建一个可用作 tf.tensor_scatter_nd_update 中的索引的网状网格,但是显然批处理数量目前尚不清楚,因此我无法创建该网状网格。

如何更新/覆盖初始张量? 预先感谢!

更新:

我提供了一个可行的解决方案,尽管它对所有选定的频道都进行了相同的修改(我也想对每个选定的频道执行不同的翻转/旋转操作……欢迎提出任何建议!)< / p>

因此,解决方案如下:

def build_channel_mask(tensor, channels):
    # Given a 4D tensor (NHWC) and an array of channels (of C), it builds    
    # boolean a 4D mask with True in the indicated channels.
    channels = tf.convert_to_tensor(channels)
    shape = tf.shape(tensor)
    mask = tf.equal(tf.range(shape[3], dtype=channels.dtype), 
                    tf.expand_dims(channels, 1))
    mask = tf.reduce_any(mask, axis=0)
    mask = tf.expand_dims(mask, 0)
    mask = tf.tile(mask, (shape[2], 1))
    mask = tf.expand_dims(mask, 0)
    mask = tf.tile(mask, (shape[1], 1, 1))
    mask = tf.expand_dims(mask, 0)
    mask = tf.tile(mask, (shape[0], 1, 1, 1))
    return mask


class MyPartialRotationFlipLayer(layers.Layer):
    def __init__(self, growth_rate, rotperc=0.2):
      super(MyPartialRotationFlipLayer, self).__init__()
      self.nCh = int(tf.multiply(rotperc, growth_rate))
      if self.nCh < 1:
        self.nCh = 1

    def call(self, x):
      # Select randomly the channels to flip/rotate
      ixCh = tf.random.shuffle(tf.range(0, x.shape[-1]))
      ixCh = ixCh[0:self.nCh]

      # Make a new tensor with the flip/rotation modification
      xNew = tf.image.random_flip_left_right(x)
      xNew = tf.image.rot90(xNew, k=tf.random.uniform(shape=[], minval=1, 
                                                      maxval=4, 
                                                      dtype=tf.int32))
      # Build a mask to use it in tf.where function
      mask = build_channel_mask(x, ixCh)
      return tf.where(mask, xNew, x)

因此,作为最后一个问题,除了如何为每个选定频道进行独立的翻转/旋转外,如何执行此操作的任何想法?

0 个答案:

没有答案