How to avoid specifying the batch size in a TF 2.0 Keras custom layer for a self-attention or self-matching network

Time: 2020-03-23 09:26:52

Tags: python tensorflow keras tf.keras tensorflow2.x

Layer definition

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.initializers import Constant


class SelfMatchingLayer_(keras.layers.Layer):

  def __init__(self, batch_size):
    super().__init__()
    # The batch size must be passed in explicitly because the ei/ej/attn
    # variables below are created with a fixed leading (batch) dimension.
    self.batch_size = batch_size

  def build(self, input_shape):
    # Pairwise bilinear weight: (seq_len, seq_len, embed_dim, embed_dim).
    self.M = self.add_weight(
        shape=(input_shape[-2], input_shape[-2], input_shape[-1], input_shape[-1]),
        initializer='glorot_uniform',
        trainable=True)
    b_init = tf.zeros_initializer()
    # Scratch tensors that hold copies of the input along a diagonal; their
    # shapes depend on self.batch_size, which is what fixes the batch size.
    self.ei = tf.Variable(
        initial_value=b_init(
            shape=(self.batch_size, input_shape[-2], input_shape[-2], 1, input_shape[-1]),
            dtype='float32'),
        trainable=False)
    self.ej = tf.Variable(
        initial_value=b_init(
            shape=(self.batch_size, input_shape[-2], input_shape[-2], input_shape[-1], 1),
            dtype='float32'),
        trainable=False)
    self.attn = tf.Variable(
        initial_value=b_init(
            shape=(self.batch_size, input_shape[-2]),
            dtype='float32'),
        trainable=False)
    super().build(input_shape)

  def call(self, x):
    # Scatter each token vector onto the diagonal of ei (as a row vector)
    # and ej (as a column vector).
    for j in tf.range(self.batch_size):
      for i in tf.range(x.shape[-2]):
        self.ei[j, i, i].assign(x[j, i])
        self.ej[j, i, i].assign(tf.transpose([x[j, i]]))

    # Bilinear matching score between every pair of positions.
    w = tf.tanh(tf.einsum('xlkij,xktjf->xlt',
                          tf.einsum('xlkij,ktjf->xltif', self.ei, self.M),
                          self.ej))

    attn = tf.reduce_max(w, axis=2)
    sent_attn = tf.nn.softmax(attn)
    return sent_attn
# maxlen, vocab_size, vocab_dim and embedding_matrix come from the
# preprocessing step (not shown).
text_input = keras.Input(shape=(maxlen,), name='comment')
text_features = keras.layers.Embedding(vocab_size, vocab_dim,
                                       embeddings_initializer=Constant(embedding_matrix),
                                       trainable=False, name='LexVec_Embeddings')(text_input)
output = SelfMatchingLayer_(100)(text_features)

model = tf.keras.Model(text_input, output)

tf.keras.utils.plot_model(model, 'model.jpg', show_shapes=True)

[model plot]

The output shape is (batch_size, embed_dimension), i.e. (100, 50). Is there a way to avoid fixing the batch size when the layer is initialized, so that the output shape is (?, 50)? In other words, can the layer work with any batch size?
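For illustration, one possible direction (a sketch, not a verified solution): reading the einsum chain above, ei and ej only hold copies of x placed along a diagonal, so the score appears to reduce to w[b, l, t] = tanh(x_l^T M[l, t] x_t). That can be computed with a single einsum directly over x, with no variables whose shape depends on the batch size. A minimal sketch under that assumption (the class name SelfMatchingLayerDynamic is made up here):

import tensorflow as tf
from tensorflow import keras

class SelfMatchingLayerDynamic(keras.layers.Layer):
  # Hypothetical batch-size-agnostic variant of SelfMatchingLayer_;
  # a sketch, not a tested fix.

  def build(self, input_shape):
    # Same pairwise bilinear weight as above: (seq_len, seq_len, embed_dim, embed_dim).
    self.M = self.add_weight(
        shape=(input_shape[-2], input_shape[-2], input_shape[-1], input_shape[-1]),
        initializer='glorot_uniform',
        trainable=True)
    super().build(input_shape)

  def call(self, x):
    # w[b, l, t] = tanh(x_l^T M[l, t] x_t); nothing here depends on a static
    # batch size, so the leading dimension can stay symbolic (None).
    w = tf.tanh(tf.einsum('xld,ltde,xte->xlt', x, self.M, x))
    attn = tf.reduce_max(w, axis=2)
    return tf.nn.softmax(attn)

If that reading is correct, the layer could be attached without a batch_size argument, e.g. output = SelfMatchingLayerDynamic()(text_features), and the leading batch dimension would stay unspecified.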

0 Answers:

There are no answers yet.