Keras的多维回归

时间:2017-05-15 17:00:51

标签: tensorflow deep-learning keras theano

我想用Keras训练神经网络进行二维回归。

我的输入是一个数字,我的输出有两个数字:

model = Sequential()
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(2, kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mean_squared_error', optimizer=adam)

然后我创建了一些用于训练的虚拟数据:

inputs = np.zeros((10, 1), dtype=np.float32)
targets = np.zeros((10, 2), dtype=np.float32)

for i in range(10):
    inputs[i] = i / 10.0
    targets[i, 0] = 0.1
    targets[i, 1] = 0.01 * i

最后,我在一个循环中训练了小型飞机,同时测试了训练数据:

while True:

    loss = model.train_on_batch(inputs, targets)

    test_outputs = model.predict(inputs)

    print test_outputs

问题是,打印输出如下:

[0.1,0.045] [0.1,0.045] [0.1,0.045] ..... ..... .....

因此,虽然第一个维度是正确的(0.1),但第二个维度是不正确的。第二个维度应为[0.01,0.02,0.03,.....]。所以实际上,网络的输出(0.45)只是第二维中所有值的平均值。

我做错了什么?

1 个答案:

答案 0 :(得分:7)

问题是,您正在使用零初始化所有权重。问题是,如果所有权重都相同,那么所有梯度都是相同的。所以就好像每个层都有一个神经元网络。删除它,以便使用默认的随机初始化并且它可以工作:

model = Sequential()
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(2))
model.compile(loss='mean_squared_error', optimizer='Adam')

1000个纪元后的结果:

Epoch 1000/1000
10/10 [==============================] - 0s - loss: 5.2522e-08

In [59]: test_outputs
Out[59]:
array([[ 0.09983768,  0.00040025],
       [ 0.09986718,  0.010469  ],
       [ 0.09985521,  0.02051429],
       [ 0.09984323,  0.03055958],
       [ 0.09983127,  0.04060487],
       [ 0.09995781,  0.05083206],
       [ 0.09995599,  0.06089856],
       [ 0.09995417,  0.07096504],
       [ 0.09995237,  0.08103154],
       [ 0.09995055,  0.09109804]], dtype=float32)