Why doesn't a neural network reproduce the same output across repeated training runs?

Asked: 2019-10-14 02:14:37

Tags: python numpy tensorflow machine-learning keras

Code first:

import numpy as np
np.random.seed(1231)  # seed NumPy's global RNG before importing Keras

from keras import backend as K
import pickle
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv1D, MaxPooling1D
from keras.utils import np_utils
from sklearn.model_selection import train_test_split
from keras import initializers

data = pickle.load(open('DATA.pkl', 'rb'), encoding='latin1')

X = np.array(data[0])
Y = np.array(data[1])

# min-max normalization to [0, 1]
X -= np.min(X)
X /= np.max(X)

# seeded split, so train/test partition is reproducible
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25, random_state=456)

X_train = X_train.reshape(X_train.shape[0], 1000, 1)
X_test = X_test.reshape(X_test.shape[0], 1000, 1)

model = Sequential()
model.add(Flatten(input_shape=(1000, 1)))
gu = initializers.glorot_uniform(seed=789)  # seeded weight initializer
model.add(Dense(1, activation='sigmoid', kernel_initializer=gu))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train, Y_train, batch_size=1, epochs=1, verbose=2, shuffle=False)

score = model.evaluate(X_test, Y_test, verbose=0)

print(score)
print(model.get_layer(index=2).get_weights())  # note: in newer Keras versions the Dense layer is index 1; model.layers[-1] is safer

The output is then sometimes like this:

Epoch 1/1
 - 1s - loss: 0.8007 - acc: 0.4720
[0.70776916790008548, 0.53200000023841854]
[array([[  2.89239828e-03],
       [ -1.48389703e-02],
       [  7.60693178e-02],
       ...
       [  2.92943567e-02],
       [  1.84460226e-02],
       [  2.38316301e-02]], dtype=float32), array([-0.00218478], dtype=float32)]

but sometimes like this:

Epoch 1/1
 - 1s - loss: 0.8008 - acc: 0.4720
[0.7077700834274292, 0.53200000023841854]
[array([[  2.89072399e-03],
       [ -1.48402918e-02],
       [  7.60683641e-02],
       ...
       [  2.92898733e-02],
       [  1.84418838e-02],
       [  2.38287449e-02]], dtype=float32), array([-0.00218458], dtype=float32)]

Every time, even in this simple network, the output differs slightly.

I assigned random seeds everywhere, but it didn't help.

So, what is wrong with the code? (I want it to produce fully reproducible output.)
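For what it's worth, a minimal check (not from the original post) shows that NumPy's seeded RNG by itself is fully deterministic, so the drift above must come from some other source of randomness, such as TensorFlow's own RNG:

```python
import numpy as np

def seeded_draw(seed):
    # Re-seeding NumPy's global RNG before each draw makes it reproducible.
    np.random.seed(seed)
    return np.random.rand(5)

a = seeded_draw(1231)
b = seeded_draw(1231)
print(np.array_equal(a, b))  # True: NumPy alone is deterministic
```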

1 Answer:

Answer 0 (score: 0)

Try tf.random.set_seed() to set the seed globally, in case some other source of randomness is used in a place you don't expect (e.g. inside the optimizer).

In TensorFlow 1.x, the equivalent call is tf.set_random_seed().
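A minimal sketch of the suggested seeding, placed before any model code (API names as given in the answer; which call applies depends on your TensorFlow version):

```python
import tensorflow as tf

# TF 2.x: seed TensorFlow's global (graph-level) RNG, which covers
# operations such as weight initialization, shuffling, and dropout.
tf.random.set_seed(1231)

# TF 1.x equivalent (older API name):
# tf.set_random_seed(1231)
```

Note that even with every seed fixed, op-level parallelism and GPU kernels can still introduce tiny floating-point differences between runs, which matches the last-digit drift in the weights shown above.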