Keras将numpy数组加载为权重的过滤器

时间:2016-07-27 18:03:29

标签: python numpy import theano keras

我在keras中实现了一个模型网络,其中theano作为后端,当前使用随机权重为过滤器进行初始化:

# Combine and reshape for convolution
seq = concat(embeddings)
cshape = (config.window_size, sum(f.output_dim for f in features))
seq = Reshape((1,)+cshape)(seq)

# Convolutions
conv_outputs = []
for filter_size, filter_num in zip(config.filter_sizes, config.filter_nums):
    conv = Convolution2D(filter_num, filter_size, cshape[1], activation='relu')(seq)
    cout = Flatten()(conv)
    conv_outputs.append(cout)
seq = concat(conv_outputs)

但是现在我希望能够加载以前生成的权重,这些权重存储为numpy数组。 这是尝试读取数组的改变代码:

# Combine and reshape for convolution
seq = concat(embeddings)
cshape = (config.window_size, sum(f.output_dim for f in features))
seq = Reshape((1,)+cshape)(seq)

# Convolutions
conv_outputs = []
for filter_size, filter_num in zip(config.filter_sizes, config.filter_nums):
    filters = np.load('path/to/filters/size-%d.npy' % filter_size)
    conv = Convolution2D(filter_num, filter_size, cshape[1], weights=filters)(seq)
    cout = Flatten()(conv)
    conv_outputs.append(cout)
seq = concat(conv_outputs)

当我尝试运行增强脚本时,我遇到以下错误:

Traceback (most recent call last):
  File "conv.py", line 78, in <module>
    weights=filters)(seq)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/topology.py", line 458, in __call__
    self.build(input_shapes[0])
  File "/usr/local/lib/python2.7/dist-packages/keras/layers/convolutional.py", line 324, in build
    self.set_weights(self.initial_weights)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/topology.py", line 846, in set_weights
    ' weights. Provided weights: ' + str(weights))
Exception: You called `set_weights(weights)` on layer "convolution2d_1" with a  weight list of length 3, but the layer was expecting 2 weights.

我尝试过在线搜索,但我无法弄清楚为什么keras会抛出这个错误。

0 个答案:

没有答案