我正在训练keras imdb数据集,所以输入看起来像 [200401500,...],[310551730,...]]。我正在构建我的网络
model_input = Input(shape=(50,), dtype='int32')
lstm_0 = Bidirectional(LSTM(512, return_sequences=True))(model_input)
x = MyLayer(name='mylayer')([lstm_0, model_input])
和我自己的图层
class MyLayer(Layer):
def __init__(self, **kwargs):
...
def build(self, input_shape):
...
def call(self, x, mask=None):
lstm_0, model_input = x[0], x[1]
#pseudo code
if 551 in model_input:
print(position_of_the_word_in_sentence)
a = np.zeros((1,50))
a[0,position_of_the_word_in_sentence] = 1
所以model_input是一个占位符
我想在call方法中做的是检查model_input中的句子是否包含某个单词。例如,单词的索引" happy"可能是551.一个句子形式model_input"我很高兴",我想检查这句话是否满意,并创建一个热门的numpy arry。
那么可以实现伪代码吗?
写keras和tensorflow是如此困难><