我的张量形状(batch_size,14、14,num_of_filters)当前值(32,14,14,512)。
我需要执行以下步骤:
for every image in this batch:
for every filter in 512 filters:
i,j = index of maximum value in 14x14 map
mu = maximum value in 14x14 map
mask = pos[i,j] // here pos is a pre-computed set of masks
if mu = 0:
mask = neg // here neg is a pre-computed mask for max value 0
这些蒙版的大小为14x14。我需要返回这些掩码的(32,14,14,512)张量。我为执行此任务而编写的代码如下:
def true_fn(self):
return T_neg
def false_fn(self,TMU):
return TMU
def func2(self,feature_map):
i = tf.math.argmax(feature_map,0)
i_v = tf.math.reduce_max(feature_map,0)
mu = tf.math.reduce_max(i_v,0)
j = tf.math.argmax(i_v,0)
i = tf.gather(i,j)
Tmu = tf.gather(T_pos,i,axis=0)
Tmu = tf.gather(Tmu,j,axis=0)
Tmu = tf.cond( tf.math.equal( mu , tf.fill(tf.shape(mu), 0.0) ) , self.true_fn , lambda: self.false_fn(Tmu) )
return Tmu
def func(self,x1):
x2 = tf.transpose(x1 , perm=[2,0,1])
// iterate over num of filters
x2 = tf.map_fn(self.func2,x2)
return tf.transpose(x2 , perm=[1,2,0])
def applyMask(self,x0):
// iterate over batch size
masks = tf.map_fn(self.func,x0)
return masks
我在调用函数中编写
masks = applyMask(x)
但是,这条特殊的线运行太慢,将我每个时期的时间增加到几个小时。有什么方法可以加快计算速度吗?
****编辑**** 在上面的代码T_pos和T_neg中,指的是预先计算的掩码。 T_pos的维数为(14,14,14,14),其中前两个维用于索引,后两个维是实际地图。 T_neg是最大值为0时使用的(14,14)映射