I'm trying to understand how custom Keras layers work. I'm attempting to create a multiplication layer that takes a scalar input and multiplies it by a multiplicand. I generate some random data and want to learn the multiplicand. When I try it with 10 numbers, it works fine. However, when I try it with 20 numbers, the loss explodes.
from keras import backend as K
from keras.engine.topology import Layer
from keras import initializers

class MultiplicationLayer(Layer):
    def __init__(self, **kwargs):
        super(MultiplicationLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='multiplicand',
                                      shape=(1,),
                                      initializer='glorot_uniform',
                                      trainable=True)
        self.built = True

    def call(self, x):
        return self.kernel * x

    def compute_output_shape(self, input_shape):
        return input_shape
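Stripped of the Keras plumbing, the layer just computes kernel * x with a single trainable scalar. A plain-NumPy stand-in (my own sketch, not part of the Keras code above) makes the forward pass explicit:

```python
import numpy as np

class NumpyMultiplicationLayer:
    """Plain-NumPy stand-in for the custom layer: one trainable scalar."""
    def __init__(self, kernel=1.0):
        self.kernel = kernel  # plays the role of the add_weight(...) variable

    def call(self, x):
        # Same forward pass as the Keras layer: scalar multiply.
        return self.kernel * x

layer = NumpyMultiplicationLayer(kernel=2.0)
print(layer.call(np.arange(5)))  # each input is doubled
```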
Using TensorFlow backend.
Testing model 1 with 10 numbers
from keras.layers import Input
from keras.models import Model
# input is a single scalar
input = Input(shape=(1,))
multiply = MultiplicationLayer()(input)
model = Model(input, multiply)
model.compile(optimizer='sgd', loss='mse')
import numpy as np
input_data = np.arange(10)
output_data = 2 * input_data
model.fit(input_data, output_data, epochs=10)
#print(model.layers[1].multiplicand.get_value())
print(model.layers[1].get_weights())
Epoch 1/10
10/10 [==============================] - 7s - loss: 257.6145
Epoch 2/10
10/10 [==============================] - 0s - loss: 47.6329
Epoch 3/10
10/10 [==============================] - 0s - loss: 8.8073
Epoch 4/10
10/10 [==============================] - 0s - loss: 1.6285
Epoch 5/10
10/10 [==============================] - 0s - loss: 0.3011
Epoch 6/10
10/10 [==============================] - 0s - loss: 0.0557
Epoch 7/10
10/10 [==============================] - 0s - loss: 0.0103
Epoch 8/10
10/10 [==============================] - 0s - loss: 0.0019
Epoch 9/10
10/10 [==============================] - 0s - loss: 3.5193e-04
Epoch 10/10
10/10 [==============================] - 0s - loss: 6.5076e-05
[array([ 1.99935019], dtype=float32)]
Testing model 2 with 20 numbers
from keras.layers import Input
from keras.models import Model
# input is a single scalar
input = Input(shape=(1,))
multiply = MultiplicationLayer()(input)
model = Model(input, multiply)
model.compile(optimizer='sgd', loss='mse')
import numpy as np
input_data = np.arange(20)
output_data = 2 * input_data
model.fit(input_data, output_data, epochs=10)
#print(model.layers[1].multiplicand.get_value())
print(model.layers[1].get_weights())
Epoch 1/10
20/20 [==============================] - 0s - loss: 278.2014
Epoch 2/10
20/20 [==============================] - 0s - loss: 601.1653
Epoch 3/10
20/20 [==============================] - 0s - loss: 1299.0583
Epoch 4/10
20/20 [==============================] - 0s - loss: 2807.1353
Epoch 5/10
20/20 [==============================] - 0s - loss: 6065.9375
Epoch 6/10
20/20 [==============================] - 0s - loss: 13107.8828
Epoch 7/10
20/20 [==============================] - 0s - loss: 28324.8320
Epoch 8/10
20/20 [==============================] - 0s - loss: 61207.1250
Epoch 9/10
20/20 [==============================] - 0s - loss: 132262.4375
Epoch 10/10
20/20 [==============================] - 0s - loss: 285805.9688
[array([-68.71629333], dtype=float32)]
Any insight into why this happens?
Answer (score: 2)
You can solve this by using a different optimizer, such as Adam(lr=0.1). Unfortunately, that needs 100 epochs.... or by using a smaller learning rate with SGD, such as SGD(lr=0.001).
from keras.optimizers import *

# input is a single scalar
inp = Input(shape=(1,))
multiply = MultiplicationLayer()(inp)
model = Model(inp, multiply)
model.compile(optimizer=Adam(lr=0.1), loss='mse')

import numpy as np
input_data = np.arange(20)
output_data = 2 * input_data
model.fit(input_data, output_data, epochs=100)
#print(model.layers[1].multiplicand.get_value())
print(model.layers[1].get_weights())

Testing further, I noticed that SGD(lr=0.001) also works, while SGD(lr=0.01) explodes.
What I suppose is happening:

If your learning rate is big enough for an update to overshoot the target by a greater distance than it was before, the next step gets an even bigger gradient, making you overshoot the point again by an even greater distance.
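This overshoot argument can be made exact for a one-parameter model. With a single input a and target t, one SGD step on mse = (t - a*x)**2 updates x to x + 2*lr*a*(t - a*x), which multiplies the error (t/a - x) by the constant factor (1 - 2*lr*a**2); training diverges exactly when the magnitude of that factor exceeds 1. A small sketch (the function name is my own, not from the answer):

```python
def step_factor(a, lr):
    # One SGD step on mse = (t - a*x)**2 scales the error (t/a - x)
    # by (1 - 2*lr*a**2); |factor| > 1 means the error grows every step.
    return 1 - 2 * lr * a ** 2

# With input a = 20, as in the single-number example:
print(step_factor(20, 0.01))   # about -7: |factor| > 1, SGD(lr=0.01) diverges
print(step_factor(20, 0.001))  # about 0.2: |factor| < 1, SGD(lr=0.001) converges
```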
An example with only one number, using SGD(lr=0.01):
inputNumber = 20
x = currentMultiplicand = 1
targetValue = 40
lr = 0.01
#first step (x=1):
mse = (40-20x)² = 400
gradient = -2*(40-20x)*20 = -800
update = - lr * gradient = 8
new x = 9
#second step (x=9):
mse = (40-20x)² = 19600 #(!!!!!)
gradient = -2*(40-20x)*20 = 5600
update = - lr * gradient = -56
new x = -47
#you can see from here that this is not going to be contained anymore...
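The hand computation above can be replayed with a short loop; the same loop with lr = 0.001 moves x toward the true multiplicand 2 instead of diverging. A sketch (the function name and structure are mine):

```python
def sgd_trace(x, a=20, t=40, lr=0.01, steps=2):
    """Plain SGD on mse = (t - a*x)**2, returning x after each step."""
    trace = []
    for _ in range(steps):
        gradient = -2 * (t - a * x) * a
        x = x - lr * gradient
        trace.append(x)
    return trace

print(sgd_trace(1, lr=0.01))   # about [9.0, -47.0], matching the steps above
print(sgd_trace(1, lr=0.001))  # about [1.8, 1.96], closing in on x = 2
```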