使用不同的嵌入来计算三元组损失(triplet loss)

时间:2020-02-01 18:53:55

标签: tensorflow triplet

我尝试使用两组不同的嵌入来计算三元组损失(triplet loss),但损失值始终被限制在 margin(间隔)附近,无法继续下降。

此外,我尝试查找有效的Triplet数量,但结果停留在实例总数的一半。

这是我的代码:

# Sample Data
import numpy as np
import tensorflow as tf

# Problem size: `batch` samples, each embedded into `emb_dim` dimensions.
batch, emb_dim = 64, 1024

# Two reproducible random embedding matrices, one per embedding source.
# The global NumPy RNG is re-seeded before each draw so the two matrices
# come from independent, fixed streams.
np.random.seed(1234)
emb1 = np.random.rand(batch, emb_dim)
np.random.seed(2345)
emb2 = np.random.rand(batch, emb_dim)

# Triplet-loss margin.
margin = 0.3

# Every sample gets its own class id, stored as a (batch, 1) column vector.
labels = np.arange(batch).reshape(-1, 1)

def _distance_metric(x, y):
    """
    Pairwise squared Euclidean distances between the rows of x and y,
    computed with the expansion |x-y|^2 = x^2 - 2xy + y^2.

    Args:
        x: tensor, with shape [m, d], (batch_size, d)
        y: tensor, with shape [n, d], (batch_size, d)
    Returns:
        dist: tensor, with shape [m, n], (batch_size, batch_size);
              dist[i, j] = |x[i] - y[j]|^2, clamped to be >= 0
    """
    # xy: dot product of every row of x with every row of y, shape [m, n]
    xy = tf.matmul(x, y, transpose_b=True)
    # x^2 and y^2: squared row norms, shapes [m] and [n].
    # Summing squares avoids building a full Gram matrix just to read
    # its diagonal.
    xx = tf.reduce_sum(tf.square(x), axis=1)
    yy = tf.reduce_sum(tf.square(y), axis=1)
    # (m, 1) - (m, n): xx is broadcast along each column
    # (m, n) + (1, n): yy is broadcast along each row
    distances = tf.expand_dims(xx, 1) - 2.0 * xy + tf.expand_dims(yy, 0)
    # Floating-point cancellation can push near-zero distances slightly
    # below zero; a squared distance is never negative, so clamp.
    return tf.maximum(distances, 0.0)

def _label_mask(labels):
    '''
    Pairwise "same label" mask: entry [i, j] is True when sample i and
    sample j carry the same label.
    ------------------------------------
    Args:
        labels:     Label Data, shape = (batch_size,1)
                    (a flat (batch_size,) vector is accepted as well)
    Returns:
        label_mask: bool tensor, with shape [m, n], (batch_size, batch_size)
        ex.
            labels = [1,0,1]
            label_mask = [[1, 0, 1],
                          [0, 1, 0],
                          [1, 0, 1]]
    '''
    # Normalize to a column vector so the comparison below broadcasts
    # (batch_size, 1) against (1, batch_size).
    labels = tf.reshape(labels, [-1, 1])
    # Compare the labels directly instead of going through a float
    # distance: this stays exact for arbitrarily large integer ids, and
    # tf.math.equal is element-wise in both graph and eager execution
    # (the Python `==` operator on tensors is not in TF1 graph mode).
    label_mask = tf.math.equal(labels, tf.transpose(labels))
    return label_mask

def batch_all(labels, emb1, emb2, margin):
    '''
    Batch-all triplet loss of a batch, with anchors taken from emb1 and
    positives/negatives taken from emb2.

    For each anchor i the positive distance is the mean squared distance
    to all same-label rows of emb2; every different-label row j of emb2
    yields one hinge term  max(d_ap[i] - d[i, j] + margin, 0).  The loss
    is the mean of the strictly positive hinge terms.
    ------------------------------------
    Args:
        labels:     Label Data, shape = (batch_size,1)
        emb1, emb2: Embedding Feature, shape = (batch_size, vector_size)
        margin:     margin, scalar
    Returns:
        triplet_loss:       scalar, for one batch
        num_valid_triplets: scalar, number of hinge terms that are > 0,
                            clamped to a minimum of 1 (it is also the
                            divisor of the loss)
    '''
    dist_mat = _distance_metric(emb1, emb2)
    # Build the masks in the dtype of the distances so float32 and
    # float64 embeddings both work.
    dtype = dist_mat.dtype
    # an and ap mask
    same_label = _label_mask(labels)
    ap_mask = tf.dtypes.cast(same_label, dtype=dtype)
    an_mask = tf.dtypes.cast(tf.math.logical_not(same_label), dtype=dtype)
    # Mean distance between each anchor and its positives.  The diagonal
    # of the label mask is always True, so the divisor is never zero.
    dist_ap = tf.reduce_sum(dist_mat * ap_mask, axis=1) / tf.reduce_sum(ap_mask, axis=1)
    # ap - an + margin for every (anchor, candidate) pair
    mat = tf.expand_dims(dist_ap, 1) - dist_mat + margin
    # only need ap-an: zero out the same-label columns
    mat = mat * an_mask
    # A triplet is valid (still violates the margin) when its hinge term
    # is positive.  The threshold must be 0, not `margin`: thresholding
    # at `margin` keeps only triplets with d_ap > d_an, which pins the
    # averaged loss above `margin` and discards the easy-to-fix triplets
    # with 0 < d_an - d_ap < margin.
    mask = tf.dtypes.cast(tf.math.greater(mat, 0.0), dtype=dtype)
    num_valid_triplets = tf.reduce_sum(mask)
    triplet_loss = mat * mask
    # Avoid dividing by zero when no triplet violates the margin.
    num_valid_triplets = tf.maximum(num_valid_triplets, 1.0)
    # Average the loss over the valid triplets only.
    triplet_loss = tf.reduce_sum(triplet_loss) / num_valid_triplets
    return triplet_loss, num_valid_triplets

我犯了愚蠢的错误吗?

这是学习曲线和有效 Triplet 的数量。

emb1 to emb2 triplet loss

the number of emb1 to emb2 valid triplet

emb2 to emb1 triplet loss

the number of emb2 to emb1 valid triplet

0 个答案:

没有答案