如何使用keras进行异或

时间:2015-07-22 07:13:17

标签: python neural-network xor keras

我想通过代码xor练习keras,但结果不对,接下来是我的代码,感谢大家帮助我。

from keras.models import Sequential
from keras.layers.core import Dense,Activation
from keras.optimizers import SGD
import numpy as np

model = Sequential()# two layers
model.add(Dense(input_dim=2,output_dim=4,init="glorot_uniform"))
model.add(Activation("sigmoid"))
model.add(Dense(input_dim=4,output_dim=1,init="glorot_uniform"))
model.add(Activation("sigmoid"))
sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.11, nesterov=True)
model.compile(loss='mean_absolute_error', optimizer=sgd)
print "begin to train"
list1 = [1,1]
label1 = [0]
list2 = [1,0]
label2 = [1]
list3 = [0,0]
label3 = [0]
list4 = [0,1]
label4 = [1] 
train_data = np.array((list1,list2,list3,list4)) #four samples for epoch = 1000
label = np.array((label1,label2,label3,label4))

model.fit(train_data,label,nb_epoch = 1000,batch_size = 4,verbose = 1,shuffle=True,show_accuracy = True)
list_test = [0,1]
test = np.array((list_test,list1))
classes = model.predict(test)
print classes
  

输出

[[ 0.31851079] [ 0.34130159]] [[ 0.49635666] [0.51274764]] 

3 个答案:

答案 0 :(得分:1)

如果我将代码中的纪元数增加到50000,它通常会收敛到正确的答案,只需要一段时间:)

但是,它经常会卡住。如果我将损失函数更改为' mean_squared_error',这会获得更好的收敛属性,这是一个更平滑的函数。

如果我使用Adam或RMSProp优化器,我的收敛速度会更快。我的最终编译行有效:

model.compile(loss='mse', optimizer='adam')
...
model.fit(train_data, label, nb_epoch = 10000,batch_size = 4,verbose = 1,shuffle=True,show_accuracy = True)

答案 1 :(得分:0)

我使用了一个带有4个隐藏节点的隐藏层,它几乎总是在500个时期内收敛到正确的答案。我使用了sigmoid激活。

答案 2 :(得分:0)

与Keras进行XOR培训

下面是学习XOR所需的最小神经元网络体系结构,它应该是(2,2,1)网络。实际上,如果数学表明(2,2,1)网络可以解决XOR问题,但数学并不表明(2,2,1)网络易于训练。有时可能需要花费很多时间(迭代次数)或无法收敛到全局最小值。就是说,使用(2,3,1)或(2,4,1)网络体系结构,我很容易获得良好的结果。

问题似乎与许多局部极小值的存在有关。请看1998年Richard Bland的论文《 Learning XOR: exploring the space of a classic problem》。此外,权重初始化(随机数在0.5到1.0之间)有助于收敛。

它与Keras或TensorFlow一起使用损失函数'mean_squared_error',S型激活和Adam优化器可以正常工作。即使有了非常好的超参数,我也观察到学习到的XOR模型被困在局部最小值中的时间约为15%。

image.png

***培训... ***

***培训完成! ***

*** [[0,0],[0,1],[1,0],[1,1]] ***上的模型预测

[[0.08662204] [0.9235283] [0.92356336] [0.06672956]]