层重量形状(1,1)与keras模型提供的重量形状(1,)不兼容

时间:2020-07-20 16:48:19

标签: python keras

我使用Keras训练了一个模型,但忘了保存模型。该模型是开发了许多其他模型的项目的一部分,但是现在我无法继续进行该项目。幸运的是,我节省了初始和最终的训练权重。现在,我正在尝试创建具有相同最终权重的模型以获取预测。我正在编译keras模型,并使用函数model.set_weights将丢失模型的最终训练权重设置为新模型。 这是代码。

model = Sequential()
model.add(Dense(1,input_dim = 1, activation = 'relu'))
model.add(Dense(1, activation = 'relu'))
model.compile(loss = 'mean_squared_error', optimizer = 'Adam', metrics = ['mse'])
listOfNumpyArrays = [np.array([0.2]),np.array([0.2])]
listOfNumpyArrays1 = listOfNumpyArrays
model.layers[0].set_weights(listOfNumpyArrays)
model.layers[1].set_weights(listOfNumpyArrays1)

追踪

ValueError                                Traceback (most recent call last)
<ipython-input-31-e63437554e30> in <module>()
----> 1 model.layers[0].set_weights(listOfNumpyArrays)
      2 model.layers[1].set_weights(listOfNumpyArrays1)
1 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in set_weights(self, weights)
   1124                                  str(pv.shape) +
   1125                                  ' not compatible with '
-> 1126                                  'provided weight shape ' + str(w.shape))
   1127             weight_value_tuples.append((p, w))
   1128         K.batch_set_value(weight_value_tuples)
ValueError: Layer weight shape (1, 1) not compatible with provided weight shape (1,)

1 个答案:

答案 0 :(得分:1)

使用np.array([0.2])创建的numpy数组的形状为(1,),而权重数组的形状为(1,1)。虽然它们存储相同数量的数据,但numpy将它们视为不同的形状。您可以通过以下操作解决此问题:

代替:

listOfNumpyArrays = [np.array([0.2]),np.array([0.2])]

使用:

listOfNumpyArrays = [np.empty(shape = (1,1), dtype = np.float32), np.empty(shape = (1,1), dtype = np.float32)]
listOfNumpyArrays[0][0] = 0.2
listOfNumpyArrays[1][0] = 0.2

无关提示:

在这一行:

listOfNumpyArrays1 = listOfNumpyArrays

似乎您要创建两个初始化为相同值的numpy数组的不同列表。但是,listOfNumpyArrays1实际上将引用与listOfNumpyArrays相同的列表。因此,当您在set_weights上进行listOfNumpyArrays1时,它也会同时修改listOfNumpyArrays。要在创建两个不同的列表时将它们初始化为相同的值,可以使用以下代码:

listOfNumpyArrays1 = [np.copy(listOfNumpyArrays[0]), np.copy(listOfNumpyArrays[1])]

np.copy创建一个新数组,该数组是您传递的数组的副本。可以使用列表理解以更Python的方式编写此代码,如下所示:

listOfNumpyArrays1 = [np.copy(x) for x in listOfNumpyArrays]