给出两个2d蒙版m1,m2(形状均为[m,m]),获得3d蒙版m3(形状[m,m,m]):
如果m1 [i] [j] == True和m2 [i] [k] == True且i!= j和i!= k且j!= k,则m3 [i] [j] [ k] =真
请注意,m1和m2是对角线,m1 [i] [j] = m1 [j] [i],m2 [i] [k] = m2 [k] [i]。但是m3 [i] [k] [j]不一定为True。
例如:
m1=[[0,1,0],[1,0,0],[0,0,0]]
m2=[[0,0,1],[0,0,0],[1,0,0]]
m3(形状(3,3,3))唯一的True值是m3 [0] [1] [2]
答案 0 :(得分:0)
def _get_triplet_mask(mask1, mask2):
indices_equal = tf.cast(tf.eye(tf.shape(mask1)[0]), tf.bool)
indices_not_equal = ~indices_equal
i_not_equal_j = tf.expand_dims(indices_not_equal, 2)
i_not_equal_k = tf.expand_dims(indices_not_equal, 1)
j_not_equal_k = tf.expand_dims(indices_not_equal, 0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
i_j = tf.expand_dims(mask1, 2)
i_k = tf.expand_dims(mask2, 1)
valid_labels = i_j & i_k
return valid_labels & distinct_indices
由sentence-transformers/sentence_transformers/losses/BatchHardTripletLoss.py修改而来。在张量流中