How to Implement Center Loss and Other Running Averages of Labeled Embeddings

Time: 2016-10-20 18:53:30

Tags: tensorflow

A recent paper (here) introduced a secondary loss function it calls center loss. It is based on the distance between the embeddings in a batch and the running average embedding for each of the respective classes. There has been some discussion in the TF Google Groups (here) about how such embedding centers can be computed and updated. I've put together some code that generates class-average embeddings in my answer below.
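
For reference, the center-loss term itself is just the summed squared distance between each embedding and its class center, added to the usual softmax loss. A minimal sketch of the combined objective (my own paraphrase; embed_batch and embed_ctrs_batch match the names used in the answers below, while softmax_loss and lambda_ are placeholder names):

center_loss = 0.5 * tf.reduce_sum(tf.square(embed_batch - embed_ctrs_batch))
total_loss = softmax_loss + lambda_ * center_loss  # lambda_ weights the center term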

Is this the best way to do this?

2 answers:

Answer 0 (score: 5)

The approach posted earlier is too simple for cases like center loss, where the expected value of an embedding changes over time as the model is refined. The earlier center-finding routine averages over all instances seen since the start, so it tracks changes in the expected value very slowly. A moving-window average is preferable, and an exponential moving-window variant is given below.
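
In update form, each batch moves the stored center for a class a fraction of the way toward that batch's embeddings:

c_new = c_old - (1 - decay) * (c_old - x_batch)
      = decay * c_old + (1 - decay) * x_batch

so older batches are discounted geometrically, with the effective window length set by decay. The code implements exactly this update via tf.scatter_sub: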

import tensorflow as tf

def get_embed_centers(embed_batch, label_batch):
    '''Exponential moving-window average of the class centers.
    Increase decay for a longer window; decay should lie in [0.0, 1.0].
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")

    label_batch = tf.reshape(label_batch, [-1])
    # move each involved center a fraction (1 - decay) of the way
    # toward the corresponding batch embedding
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch


nclass = 4   # example sizes, assumed here so the snippet runs stand-alone
ndims = 2

with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.initialize_all_variables())
    tf.get_default_graph().finalize()
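    # Hypothetical smoke test (not in the original answer): feed one small batch
    # and fetch the freshly updated centers for those labels.
    ctrs = sess.run(embed_ctrs_batch,
                    feed_dict={embed_batch_ph: [[1.0, 2.0], [3.0, 4.0]],
                               label_batch_ph: [0, 1]})
    print(ctrs)  # each center moves (1 - decay) of the way from 0 toward its embedding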

Answer 1 (score: 2)

The get_new_centers() routine below takes in labelled embeddings and updates the shared variables center/sums and center/cts, then uses the updated values to calculate and return the embedding centers.

The loop just exercises get_new_centers() and shows that it converges to the expected average embeddings for all classes over time.

Note that the alpha term used in the original paper isn't included here, but it should be straightforward to add if needed; one possible way to fold it in is sketched after the results at the end of this answer.

import numpy as np
import tensorflow as tf

ndims = 2
nclass = 4
nbatch = 100

with tf.variable_scope('center'):
    center_sums = tf.get_variable("sums", [nclass, ndims], dtype=tf.float32,
                    initializer=tf.constant_initializer(0), trainable=False)
    center_cts = tf.get_variable("cts", [nclass], dtype=tf.float32,
                    initializer=tf.constant_initializer(0), trainable=False)

def get_new_centers(embeddings, indices):
    '''
    Update the running sums and counts for the selected class indices and
    return the new average embeddings. One row is returned per entry in
    `indices` (duplicates included), gathered from the freshly updated values.
    '''
    with tf.variable_scope('center', reuse=True):
        center_sums = tf.get_variable("sums")
        center_cts = tf.get_variable("cts")

    # update the running sums and counts; scatter_add accumulates correctly
    # when `indices` contains duplicates
    if embeddings is not None:
        ones = tf.ones_like(indices, tf.float32)
        center_sums = tf.scatter_add(center_sums, indices, embeddings, name='sa1')
        center_cts = tf.scatter_add(center_cts, indices, ones, name='sa2')

    # return the updated centers (running sum / running count) for each index
    num = tf.gather(center_sums, indices)
    denom = tf.reshape(tf.gather(center_cts, indices), [-1, 1])
    return tf.div(num, denom)


with tf.Session() as sess:
    labels_ph = tf.placeholder(tf.int32)
    embeddings_ph = tf.placeholder(tf.float32)

    # recover the full label sequence from its unique values; this simply
    # reproduces labels_ph
    unq_labels, ul_idxs = tf.unique(labels_ph)
    indices = tf.gather(unq_labels, ul_idxs)
    new_centers_with_update = get_new_centers(embeddings_ph, indices)
    new_centers = get_new_centers(None, indices)

    sess.run(tf.initialize_all_variables())
    tf.get_default_graph().finalize()

    for i in range(100001):
        embeddings = 100*np.random.randn(nbatch, ndims)
        labels = np.random.randint(0, nclass, nbatch)
        feed_dict = {embeddings_ph:embeddings, labels_ph:labels}
        rval = sess.run([new_centers_with_update], feed_dict)
        if i % 1000 == 0:
            feed_dict = {labels_ph: list(range(nclass))}
            rval = sess.run(new_centers, feed_dict)
            print('\nFor step ', i)
            for iclass in range(nclass):
                print('Class %d, center: %s' % (iclass, str(rval[iclass])))

A typical result at step 0 is:

For step  0
Class 0, center: [-1.7618252  -0.30574229]
Class 1, center: [ -4.50493908  10.12403965]
Class 2, center: [ 3.6156714  -9.94263649]
Class 3, center: [-4.20281982 -8.28845882]

and the output at step 10,000 shows the centers converging toward the true class means (zero, for this centered random data):

For step  10000
Class 0, center: [ 0.00313433 -0.00757505]
Class 1, center: [-0.03476512  0.04682625]
Class 2, center: [-0.03865958  0.06585111]
Class 3, center: [-0.02502561 -0.03370816]
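
For completeness, one possible (untested) way to add the alpha term: the paper updates each center as c_j <- c_j - alpha * delta_c_j, where delta_c_j averages the differences (c_j - x_i) over the batch members of class j, with a +1 in the denominator. The sketch below is my own; it uses tf.unsorted_segment_sum rather than the running sums above, and assumes centers is a [nclass, ndims] variable:

def update_centers_with_alpha(centers, embeddings, labels, alpha=0.5):
    '''Paper-style update: c_j <- c_j - alpha * delta_c_j.'''
    diffs = tf.gather(centers, labels) - embeddings        # (c_{y_i} - x_i)
    ones = tf.ones_like(labels, tf.float32)
    # per-class sums of the differences and the member counts;
    # duplicate labels accumulate as intended
    diff_sums = tf.unsorted_segment_sum(diffs, labels, nclass)
    counts = tf.unsorted_segment_sum(ones, labels, nclass)
    delta = diff_sums / tf.reshape(1.0 + counts, [-1, 1])  # paper's 1 + n_j
    return tf.assign_sub(centers, alpha * delta)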