如何查看初始权重(即培训前)?

时间:2017-10-17 20:31:12

标签: python keras kernel

我正在使用Keras生成一个简单的单层前馈网络。我想通过kernel_initializer参数初始化权重时更好地处理权重值。

有没有办法可以在初始化之后(即在训练完成之前)查看权重值。

也许我应该解释为什么我要查看初始化的权重。在Keras,我对随机正交矩阵的实际外观感到有些困惑。如果我可以打印这些值,它将有助于我更好地理解这个功能。

3 个答案:

答案 0 :(得分:5)

只需在模型上使用get_weights()即可。例如:

i = Input((2,))
x = Dense(5)(i)

model = Model(i, x)

print model.get_weights()

这将打印2x5权重矩阵和1x5偏差矩阵:

[array([[-0.46599612,  0.28759909,  0.48267472,  0.55951393,  0.3887372 ],
   [-0.56448901,  0.76363671,  0.88165808, -0.87762225, -0.2169953 ]], dtype=float32), 
 array([ 0.,  0.,  0.,  0.,  0.], dtype=float32)]

偏差为零,因为默认偏差初始值设定为零。

答案 1 :(得分:2)

您需要指定第一层输入的尺寸,否则将为您提供一个空列表。比较两个打印的两个结果,唯一的区别是输入形状的初始化。

from keras import backend as K
import numpy as np 
from keras.models import Sequential
from keras.layers import Dense
# first model without input_dim prints an empty list   
model = Sequential()
model.add(Dense(5, weights=[np.ones((3,5)),np.zeros(5)], activation='relu'))
print(model.get_weights())


# second model with input_dim prints the assigned weights
model1 = Sequential()
model1.add(Dense(5,  weights=[np.ones((3,5)),np.zeros(5)],input_dim=3, activation='relu'))
model1.add(Dense(1, activation='sigmoid'))

print(model1.get_weights())

答案 2 :(得分:1)

@Chris_K给出的答案应该有效 - CREATE TABLE `Basic Cases` ( `Type of Case` LongText, `Receipt` DateTime, `Address` LongText, `Postcode` LongText, `Ref No` LongText, `Value` Currency) 在调用fit之前打印正确的初始化权重。尝试运行此代码作为完整性检查 - 它应该打印两个非零的矩阵(两个层),然后打印两个零的矩阵:

model.get_weights()

这是我看到的输出:

from keras.models import Sequential
from keras.layers import Dense
import keras
import numpy as np

X = np.random.randn(10,3)
Y = np.random.randn(10,)

# create model
model1 = Sequential()
model1.add(Dense(12, input_dim=3, activation='relu'))
model1.add(Dense(1, activation='sigmoid'))

print(model1.get_weights())

# create model
model2 = Sequential()
model2.add(Dense(12, input_dim=3, kernel_initializer='zero', activation='relu'))
model2.add(Dense(1, kernel_initializer='zero', activation='sigmoid'))

print(model2.get_weights())