如何加快Tensorflow中的屏蔽代码?

时间:2019-11-26 04:42:05

标签: python tensorflow keras

我的张量形状(batch_size,14、14,num_of_filters)当前值(32,14,14,512)。

我需要执行以下步骤:

for every image in this batch:
     for every filter in 512 filters:
          i,j = index of maximum value in 14x14 map
          mu = maximum value in 14x14 map
          mask = pos[i,j] // here pos is a pre-computed set of masks
          if mu = 0:
              mask = neg // here neg is a pre-computed mask for max value 0

这些蒙版的大小为14x14。我需要返回这些掩码的(32,14,14,512)张量。我为执行此任务而编写的代码如下:


  def true_fn(self):
    return T_neg

  def false_fn(self,TMU):
    return TMU

  def func2(self,feature_map):
    i = tf.math.argmax(feature_map,0)
    i_v = tf.math.reduce_max(feature_map,0)
    mu = tf.math.reduce_max(i_v,0)
    j = tf.math.argmax(i_v,0)
    i = tf.gather(i,j)
    Tmu = tf.gather(T_pos,i,axis=0)
    Tmu = tf.gather(Tmu,j,axis=0)
    Tmu = tf.cond( tf.math.equal( mu , tf.fill(tf.shape(mu), 0.0) ) , self.true_fn , lambda: self.false_fn(Tmu) )
    return Tmu

  def func(self,x1):
    x2 = tf.transpose(x1 , perm=[2,0,1])
    // iterate over num of filters
    x2 = tf.map_fn(self.func2,x2)
    return tf.transpose(x2 , perm=[1,2,0])

  def applyMask(self,x0):
    // iterate over batch size
    masks = tf.map_fn(self.func,x0)
    return masks

我在调用函数中编写

masks = applyMask(x)

但是,这条特殊的线运行太慢,将我每个时期的时间增加到几个小时。有什么方法可以加快计算速度吗?

****编辑**** 在上面的代码T_pos和T_neg中,指的是预先计算的掩码。 T_pos的维数为(14,14,14,14),其中前两个维用于索引,后两个维是实际地图。 T_neg是最大值为0时使用的(14,14)映射

0 个答案:

没有答案