在Tensorflow中有效地计算成对排名损失函数

时间:2016-10-30 15:45:37

标签: python machine-learning tensorflow deep-learning

我目前正在tensorflow中实施http://www.aclweb.org/anthology/P15-1061

我已经实现了成对排名亏损函数(论文的第2.5节),如下所示:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index)
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size])
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

我不得不使用tf.gather而不是tf.gather_nd,因为后者尚未实现梯度下降。我还必须使用展平矩阵将所有索引转换为正确的。

如果使用渐变下降实现tf.gather_nd,我的代码将如下:

s_theta_y = tf.gather_nd(s_theta, y_t_index)
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index)
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

s_theta是每个类别标签的计算得分,如文中所述。 y_true_index包含真实类的索引,以便计算s_theta_y。 y_neg_index是所有负类的索引,其维度为#class-1或#class,该关系被归类为其他类。

然而,几个句子被归类为其他,因此,s_theta_y 不存在,我们不应该将其考虑在内。为了处理这种情况,我有一个取消该项的常数因子0,并且为负类具有相同的维度向量,我只是复制索引的随机值,因为最后,我们只对所有负类(而不是指数)中的最大值。

是否有更有效的方法在损失函数中计算这些项?我的印象是,使用tf.gather进行如此多的重塑非常慢

1 个答案:

答案 0 :(得分:1)

当然听起来像gather_nd就是你想要的,但是在那里实现渐变之前,我不会犹豫使用你的reshape()解决方案,因为reshape()实际上是免费的。

C++ implementation of the reshape() op似乎做了很多工作,但这只是对形状信息的快速错误检查。 "工作"发生在第90行的CopyFrom中,这听起来可能很昂贵,但实际上只是一个指针副本(CopyFrom调用CopyFromInternal复制指针)。

这完全合理:底层缓冲区只是row-major order中数字的平面数组,并且该排序不依赖于形状信息。出于同样的原因,像tf.transpose()这样的东西一般都需要复制。