我很难在Keras中编写自定义损失函数。我有图层权重“ W”和矩阵“ M”。我想执行以下操作trace((W * M)* W')以计算我的损失函数。迹线是对角元素的总和。在numpy中,我将执行以下操作:
np.trace(np.dot(np.dot(W,M),W.T))) or
def custom_regularizer(W,M):
sum_reg = 0
for i in range(W.shape[1]):
for j in range(i,W.shape[1]):
vector = W[:,i] - W[:,j]
sum_reg = sum_reg + M[i,j] * (LA.norm(vector)**2)
return sum_reg
对于喀拉拉邦,我编写了以下损失函数
def custom_loss(W):
def lossFunction(y_true,y_pred):
loss = tf.trace(K.dot(K.dot(W,K.constant(M)),K.transpose(W)))
return loss
return lossFunction
问题在于keras正在计算维度为200000 * 200000的整个外部矩阵,从而导致内存错误。有什么方法可以让我只获取对角线元素的总和而无需进行整个矩阵计算。
怎么做与keras损失函数一样?