包含巨大矩阵乘法的喀拉拉邦自定义损失函数

时间:2019-10-22 01:22:56

标签: tensorflow keras deep-learning keras-layer loss-function

我很难在Keras中编写自定义损失函数。我有图层权重“ W”和矩阵“ M”。我想执行以下操作trace((W * M)* W')以计算我的损失函数。迹线是对角元素的总和。在numpy中,我将执行以下操作:

np.trace(np.dot(np.dot(W,M),W.T))) or 

def custom_regularizer(W,M):
    sum_reg = 0
    for i in range(W.shape[1]):
        for j in range(i,W.shape[1]):
            vector = W[:,i] - W[:,j]
            sum_reg = sum_reg + M[i,j] * (LA.norm(vector)**2)
    return sum_reg

对于喀拉拉邦,我编写了以下损失函数

def custom_loss(W):

  def lossFunction(y_true,y_pred):    
    loss = tf.trace(K.dot(K.dot(W,K.constant(M)),K.transpose(W)))
    return loss

return lossFunction

问题在于keras正在计算维度为200000 * 200000的整个外部矩阵,从而导致内存错误。有什么方法可以让我只获取对角线元素的总和而无需进行整个矩阵计算。

怎么做与keras损失函数一样?

1 个答案:

答案 0 :(得分:1)

如果您遵循一些巧妙的技巧来计算轨迹,则不应耗尽内存。例如,您可以参考this