Keras检查自定义层中的model_input / iterate占位符值中的值

时间:2018-06-19 04:15:09

标签: python tensorflow keras

我正在训练keras imdb数据集,所以输入看起来像 [200401500,...],[310551730,...]]。我正在构建我的网络

model_input = Input(shape=(50,), dtype='int32')
lstm_0 = Bidirectional(LSTM(512, return_sequences=True))(model_input)
x = MyLayer(name='mylayer')([lstm_0, model_input])

和我自己的图层

class MyLayer(Layer):
    def __init__(self, **kwargs):
        ...
    def build(self, input_shape):
        ...
    def call(self, x, mask=None):
        lstm_0, model_input = x[0], x[1]

        #pseudo code
        if 551 in model_input:
            print(position_of_the_word_in_sentence)
            a = np.zeros((1,50))
            a[0,position_of_the_word_in_sentence] = 1

所以model_input是一个占位符

我想在call方法中做的是检查model_input中的句子是否包含某个单词。例如,单词的索引" happy"可能是551.一个句子形式model_input"我很高兴",我想检查这句话是否满意,并创建一个热门的numpy arry。

那么可以实现伪代码吗?

写keras和tensorflow是如此困难><

0 个答案:

没有答案