Keras custom layer with a training flag in the functional API

Posted: 2020-06-05 07:45:49

Tags: python tensorflow image-processing keras nlp

I am trying to build an image captioning model using an encoder-decoder architecture.

Here is my encoder:

import tensorflow as tf
from tensorflow import keras as k

# inc_v3 is assumed to be a pre-trained InceptionV3 base, loaded e.g. as:
inc_v3 = k.applications.InceptionV3(include_top=False, weights='imagenet')

# Image encoder: InceptionV3 truncated at 'mixed5', spatial grid flattened to a sequence
sub_inception = k.Model(inputs=inc_v3.input, outputs=inc_v3.get_layer('mixed5').output)

inp = k.layers.Input([250, 250, 3])
processed = k.applications.inception_v3.preprocess_input(inp)
encoded = sub_inception(processed)
reshaped = k.layers.Reshape((-1, encoded.shape[3]))(encoded)
encoder = k.Model(inputs=inp, outputs=reshaped)
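
For a 250x250 image this yields a flattened grid of 'mixed5' feature vectors that the decoder's attention can attend over. A quick sanity check of the output shape (an illustrative snippet only, assuming the encoder above has been built) looks like this:

import numpy as np

# Illustrative shape check on a dummy batch; the sequence length depends on
# how InceptionV3 downsamples a 250x250 input, and 'mixed5' has 768 channels
dummy_images = np.zeros((2, 250, 250, 3), dtype='float32')
features = encoder(dummy_images)
print(features.shape)  # (2, num_positions, 768)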

Here is my decoder:

import tensorflow_addons as tfa

class Decoder(k.layers.Layer):
  def __init__(self, decoder_rnn_units,
               target_vocab_size,
               input_caption_vocab_size,
               emb_size=100,
               emb_weights=None,
               attention_units=1,
               start_token=459,
               end_token=162,
               *args, **kwargs):
    super(Decoder, self).__init__(*args, **kwargs)

    self.start_token = start_token
    self.end_token = end_token

    self.rnn_cell = tf.keras.layers.LSTMCell(decoder_rnn_units)
    self.out_dec = tf.keras.layers.Dense(target_vocab_size, activation=tf.keras.activations.softmax)

    # Frozen embedding initialised from pre-trained weights;
    # mask_zero=True makes the padded positions produce a mask
    self.emb = tf.keras.layers.Embedding(input_dim=input_caption_vocab_size,
                                         output_dim=emb_size,
                                         mask_zero=True,
                                         trainable=False, weights=[emb_weights])

    # Teacher-forcing sampler for training, greedy sampler for inference
    self.sampler = tfa.seq2seq.sampler.TrainingSampler()
    self.greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler(self.emb)

    self.att_mechanism = tfa.seq2seq.BahdanauAttention(units=attention_units)
    self.decoder_cell = tfa.seq2seq.AttentionWrapper(self.rnn_cell, self.att_mechanism)

    self.decoder = tfa.seq2seq.BasicDecoder(self.decoder_cell, self.sampler, self.out_dec)
    self.inf_decoder = tfa.seq2seq.BasicDecoder(self.decoder_cell, self.greedy_sampler, self.out_dec)

  def build_decoder_initial_state(self, inputs, Dtype):
    decoder_initial_state = self.decoder_cell.get_initial_state(inputs=inputs, dtype=Dtype)
    return decoder_initial_state

  def call(self, inputs, training=None):
    inps, encoder_out = inputs
    inputs_emb = self.emb(inps)
    self.att_mechanism.setup_memory(encoder_out)
    init_state = self.build_decoder_initial_state(inputs_emb, tf.float32)

    if training:
      # Teacher forcing: feed the ground-truth caption embeddings
      decoder_outputs, _, _ = self.decoder(inputs_emb, initial_state=init_state)
      return decoder_outputs.rnn_output
    else:
      # Greedy decoding from the start token until the end token
      start_tokens = tf.zeros_like(inputs_emb[:, 0]) + self.start_token
      decoder_outputs, _, _ = self.inf_decoder(inputs_emb, initial_state=init_state,
                                               start_tokens=start_tokens, end_token=self.end_token)
      return decoder_outputs.rnn_output

In my decoder's call method the training flag defaults to None. However, when I try to build the final model as follows:

k.backend.clear_session()

decoder = Decoder(10, len(vocab) + 2, len(vocab) + 2, emb_weights=emb_mat)

img_captioner_img_inp = k.layers.Input(shape=(250, 250, 3), name='images')
img_captioner_cap_inp = k.layers.Input(shape=(None,), name='captions')

encoded_img = encoder(img_captioner_img_inp)
decoded = decoder([img_captioner_cap_inp, encoded_img])

img_captioner = k.Model(inputs=[img_captioner_img_inp, img_captioner_cap_inp], outputs=[decoded])

it fails with TypeError: tf__initialize() got an unexpected keyword argument 'mask'.
If I hard-code the training flag to True it works, but then the model cannot run in inference mode.
Why does this happen? Does the functional API not support masking?
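
For context, here is a minimal, self-contained sketch (the layer and variable names are purely illustrative and not part of the model above) of the behaviour I expected from the functional API: when training and mask appear in a custom layer's call signature, Keras should fill them in automatically, with the mask coming from an upstream Embedding(mask_zero=True):

from tensorflow import keras as k

class ProbeLayer(k.layers.Layer):
  # Illustrative layer: `training` and `mask` are in the call signature,
  # so the functional API fills both in automatically at call time.
  def call(self, inputs, training=None, mask=None):
    if training:
      return inputs  # training-time branch (e.g. teacher forcing)
    return inputs    # inference branch; `mask` comes from mask_zero=True upstream

tokens = k.layers.Input(shape=(None,), dtype='int32')
embedded = k.layers.Embedding(input_dim=50, output_dim=8, mask_zero=True)(tokens)
probed = ProbeLayer()(embedded)
toy_model = k.Model(tokens, probed)

# toy_model(x, training=True) takes the training branch;
# toy_model.predict(x) runs with training=False, i.e. the inference branch.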

0 Answers:

No answers yet.