三维时间序列输出的keras均方误差损失函数

时间:2019-03-02 22:58:11

标签: python keras recurrent-neural-network mse

我想验证我的损失函数,因为我已经阅读到喀拉拉邦的mse损失函数存在问题。 考虑喀拉拉邦的lstm模型,该模型预测3d时间序列为多个目标(y1,y2,y3)。假设一批输出序列的形状为(10,31,1) 下面的损失函数会采用预测输出与真实输出之间的平方差,然后取310个样本的平均值,得出一个单一的损失值吗?如果将3个输出串联为(10,31,3)

,该操作将如何发生?
def mse(y_true, y_pred):
            return keras.backend.mean(keras.backend.square(y_pred - y_true), axis=1)

1 个答案:

答案 0 :(得分:2)

如果要获取单个损失值,则无需设置axis

import keras.backend as K

def mse(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true))

y_true = K.random_normal(shape=(10,31,3))
y_pred = K.random_normal(shape=(10,31,3))

loss = mse(y_true, y_pred)
print(K.eval(loss))

# print
2.0196152