I am trying to write a custom layer in Keras, but I get this error:

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

The backend functions may be the problem. However, I have searched for hours, and most of the answers I found say "wrap it in a Lambda"; but when I wrap all the backend functions in Lambda (I use the TensorFlow backend), I still get the error. I even opened an issue on the Keras GitHub, #12672, but no one has replied to help. :-( https://github.com/keras-team/keras/issues/12672
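As far as I understand it, "wrap it in a Lambda" means that any backend op applied directly to a Keras tensor in the model graph has to go through a Lambda layer, something like this generic sketch (x and y are just placeholder names of mine):

    # direct backend call -- returns a raw tensor that Keras cannot trace:
    # y = K.sum(x, axis=1)
    # wrapped version -- returns a proper Keras layer output:
    y = Lambda(lambda t: K.sum(t, axis=1))(x)

Here is my layer: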
from keras.engine.topology import Layer
import keras.backend as K
from keras.layers import Lambda
import numpy as np
class Attention(Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # three weights:
        # Wh: (att_size, att_size)  --> projects the previous hidden states
        # Wq: (query_dim, att_size) --> projects the target hidden state
        # V:  (att_size, 1)         --> applied after the tanh
        # score(previous, target) = Vt * tanh(Wh * previous + target * Wq + b???) --> (1, 1)
        # the dimension of the previous hidden states
        self.att_size = input_shape[0][-1]
        # the dimension of the target hidden state
        self.query_dim = input_shape[1][-1]
        self.Wq = self.add_weight(name='kernal_query_features',
                                  shape=(self.query_dim, self.att_size),
                                  initializer='glorot_normal', trainable=True)
        self.Wh = self.add_weight(name='kernal_hidden_features',
                                  shape=(self.att_size, self.att_size),
                                  initializer='glorot_normal', trainable=True)
        self.v = self.add_weight(name='query_vector',
                                 shape=(self.att_size, 1),
                                 initializer='zeros', trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, inputs, mask=None):
        # score(previous, target) = Vt * tanh(Wh * memory + target * Wq)
        memory, query = inputs[0], inputs[1]
        hidden = K.dot(memory, self.Wh) + K.expand_dims(K.dot(query, self.Wq), 1)
        hidden = K.tanh(hidden)
        # drop the trailing dimension of size 1
        s = K.squeeze(K.dot(hidden, self.v), -1)
        # s = K.reshape(K.dot(hidden, self.v), (-1, self.att_size))
        # turn the scores into attention weights with a softmax
        s = K.softmax(s)
        return K.sum(memory * K.expand_dims(s), axis=1)

    def compute_output_shape(self, input_shape):
        att_size = input_shape[0][-1]
        batch = input_shape[0][0]
        return batch, att_size
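For what it's worth, the layer seems fine when I connect it to two Input tensors on its own (a minimal standalone check with made-up shapes: 10 time steps, att_size = 64), which makes me suspect the problem is in how I feed it:

    from keras.layers import Input
    from keras.models import Model
    import numpy as np

    # made-up shapes: a (10, 64) memory sequence and a 64-dim query
    memory_in = Input(shape=(10, 64))
    query_in = Input(shape=(64,))
    att_out = Attention()([memory_in, query_in])
    test_model = Model([memory_in, query_in], att_out)
    # expected output shape: (batch, 64)
    print(test_model.predict([np.zeros((2, 10, 64)), np.zeros((2, 64))]).shape)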
In my main code:
lstm_global_seq = LSTM(units=2 * neighbor_slide_len * neighbor_slide_len,
                       return_sequences=True, dropout=0.1,
                       recurrent_dropout=0.1, name="att_global")(global_data)
att_global = Attention()([lstm_global_seq[:-1], lstm_global_seq[-1]])
lstm_global = merge.Concatenate(axis=-1)([att_global, lstm_global_seq[-1]])
I just want this custom layer to run correctly.
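One thing I now suspect: lstm_global_seq[:-1] and lstm_global_seq[-1] are plain tensor slices, not Keras layer outputs, so they have no _inbound_nodes. If that is the culprit, then the slicing itself is what needs the Lambda wrapping, roughly like this (a sketch; I am also unsure about the axes, since [:-1] slices the batch dimension while I actually want the time dimension):

    # wrap the slicing in Lambda so Keras can trace it as layer outputs
    seq_prev = Lambda(lambda t: t[:, :-1, :])(lstm_global_seq)  # all but the last time step
    seq_last = Lambda(lambda t: t[:, -1, :])(lstm_global_seq)   # the last time step
    att_global = Attention()([seq_prev, seq_last])
    lstm_global = merge.Concatenate(axis=-1)([att_global, seq_last])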