在需要使用单位L2范数进行输出的回归问题中,如何标准化Keras网络输出?

时间:2018-11-23 06:01:13

标签: python tensorflow machine-learning keras regression

我的回归问题要求网络输出y具有单位范数||y|| = 1.。我想将其强加为线性激活后的Lambda层:

from keras import backend as K  
...  
model.add(Dense(numOutputs, activation='linear'))  
model.add(Lambda(lambda x: K.l2_normalize(x)))  

后端是TensorFlow。代码可以编译,但是网络可以预测具有不同范数(范数不是1且会变化)的输出向量。

关于我做错事情的任何提示吗?

1 个答案:

答案 0 :(得分:0)

问题是您尚未将axis参数传递给K.l2_normalize函数。结果,它将标准化整个批处理中的所有元素,以使它们的范数等于1。要解决此问题,只需传递axis=-1以在最后一个轴上进行标准化:

model.add(Lambda(lambda x: K.l2_normalize(x, axis=-1)))