如何从两个给定的2d遮罩张量中获得3d遮罩张量?

时间:2020-07-29 03:50:47

标签: tensorflow

给出两个2d蒙版m1,m2(形状均为[m,m]),获得3d蒙版m3(形状[m,m,m]):

如果m1 [i] [j] == True和m2 [i] [k] == True且i!= j和i!= k且j!= k,则m3 [i] [j] [ k] =真

请注意,m1和m2是对角线,m1 [i] [j] = m1 [j] [i],m2 [i] [k] = m2 [k] [i]。但是m3 [i] [k] [j]不一定为True。

例如:

m1=[[0,1,0],[1,0,0],[0,0,0]]
m2=[[0,0,1],[0,0,0],[1,0,0]]

m3(形状(3,3,3))唯一的True值是m3 [0] [1] [2]

1 个答案:

答案 0 :(得分:0)

def _get_triplet_mask(mask1, mask2):
    indices_equal = tf.cast(tf.eye(tf.shape(mask1)[0]), tf.bool)
    indices_not_equal = ~indices_equal
    i_not_equal_j = tf.expand_dims(indices_not_equal, 2)
    i_not_equal_k = tf.expand_dims(indices_not_equal, 1)
    j_not_equal_k = tf.expand_dims(indices_not_equal, 0)
    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
    i_j = tf.expand_dims(mask1, 2)
    i_k = tf.expand_dims(mask2, 1)
    valid_labels = i_j & i_k
    return valid_labels & distinct_indices

sentence-transformers/sentence_transformers/losses/BatchHardTripletLoss.py修改而来。在张量流中