Keras layer.set_weights doesn't modify the layer. Why?

Time: 2017-10-20 05:30:37

Tags: keras

When I initialize the model and load the weights into it, the resulting accuracy is 67%.

model.load_weights(path+'results/finetune_train_last_layer.h5')  
batches = model.get_batches(path, shuffle=False, batch_size=128, class_mode=None)
preds = model.predict_generator(batches, batches.nb_sample)
matches = 0
for guess, ans in zip(np.argmax(preds, axis=1), batches.classes):
    if guess == ans:
        matches += 1
print('%s/%s' % (matches, len(batches.classes)))

532/792

The weights load correctly. This is the same accuracy I got in the last training epoch before saving these weights.

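However, when I create a new model with the same layers as the last layers of model and copy the weights over, the weights come out different. How is this possible?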

no_drop_model = Sequential([
    MaxPooling2D(input_shape=(512, 14, 14)),
    Flatten(),
    Dense(4096, activation='relu'),
    Dropout(0.),
    Dense(4096, activation='relu'),
    Dropout(0.),
    Dense(120, activation='softmax')
])
for ndl, fcl in zip(no_drop_model.layers, model.layers[31:]):
    print(type(ndl), type(fcl))
    ndl.set_weights(fcl.get_weights())
    if ndl.get_weights():
        print(np.array_equiv(ndl.get_weights(), fcl.get_weights()))

Output:

(<class 'keras.layers.pooling.MaxPooling2D'>, <class 'keras.layers.pooling.MaxPooling2D'>)
(<class 'keras.layers.core.Flatten'>, <class 'keras.layers.core.Flatten'>)
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>)
False
(<class 'keras.layers.core.Dropout'>, <class 'keras.layers.core.Dropout'>)
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>)
False
(<class 'keras.layers.core.Dropout'>, <class 'keras.layers.core.Dropout'>)
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>)
False

1 answer:

Answer 0 (score: 0)

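The return value of model.get_weights() is a list of numpy arrays, not a single array, and the same is true of layer.get_weights(). Passing the whole list to np.array_equiv therefore does not compare the arrays one by one. You should compare the weights like this: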

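import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

def create_model():
    # Minimal model: 2 inputs feeding one Dense layer with 3 units
    i = Input((2,))
    o = Dense(3)(i)
    return Model(i, o)

model1 = create_model()
model2 = create_model()

# Before copying: compare each weight array (kernel, then bias) separately
for w1, w2 in zip(model1.get_weights(), model2.get_weights()):
    print(np.array_equiv(w1, w2))

model2.set_weights(model1.get_weights())
print('Weights after:')
# After copying: every pair of arrays should match
for w1, w2 in zip(model1.get_weights(), model2.get_weights()):
    print(np.array_equiv(w1, w2))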

This produces the following output:

False
True
Weights after:
True
True

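The second element of the weight list holds the bias values, which are initialized to zero, so they are already identical before the weights are copied. That is why the first comparison prints False for the kernel and True for the bias.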
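The per-array comparison above can be wrapped in a small helper so the check in the question's loop becomes a single call. This is only a sketch: weights_equal is a hypothetical name, not part of Keras, and plain numpy arrays stand in for what layer.get_weights() returns so the example runs without a model.

```python
import numpy as np

def weights_equal(weights_a, weights_b):
    """Return True if two lists of weight arrays match array by array.

    Hypothetical helper, not a Keras function. Each argument is expected
    to be a list of numpy arrays, such as the output of get_weights().
    """
    if len(weights_a) != len(weights_b):
        return False
    # np.array_equal also returns False when the shapes differ
    return all(np.array_equal(a, b) for a, b in zip(weights_a, weights_b))

# Stand-ins for layer.get_weights(): a kernel and a bias of different shapes
source = [np.random.rand(4, 3), np.zeros(3)]
copied = [w.copy() for w in source]
changed = [source[0] + 1.0, source[1]]

print(weights_equal(source, copied))   # True
print(weights_equal(source, changed))  # False
```

In the question's loop, the check would then read weights_equal(ndl.get_weights(), fcl.get_weights()). Layers without weights, such as Flatten or Dropout, return empty lists and compare as equal.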