提前谢谢。
下面是用 TensorFlow 编写的编码器网络：
def encoder_net():
    """Wrap a pretrained ALBERT encoder in a Keras functional model.

    Three int32 ``tf.keras.Input`` layers of shape ``(MAX_SEQ_LEN,)`` with a
    fixed batch size of ``BATCH_SIZE`` are created, named ``input_ids``,
    ``attention_mask`` and ``token_type_ids`` (in that order). They are fed
    as a single list to a ``TFAlbertModel`` loaded from ``'albert-base-v2'``.

    Returns:
        A ``tf.keras.Model`` whose inputs are that list of three inputs and
        whose outputs are whatever the ALBERT encoder returns for them.
    """
    def make_input(layer_name):
        # All three inputs share shape, batch size and dtype; only the name differs.
        return tf.keras.Input(
            shape=(MAX_SEQ_LEN,),
            batch_size=BATCH_SIZE,
            name=layer_name,
            dtype=tf.int32,
        )

    input_names = ("input_ids", "attention_mask", "token_type_ids")
    model_inputs = [make_input(input_name) for input_name in input_names]

    albert = TFAlbertModel.from_pretrained('albert-base-v2')
    albert_outputs = albert(model_inputs)
    return tf.keras.Model(inputs=model_inputs, outputs=albert_outputs)
我想把它改写成 PyTorch 版本。
PyTorch 中有没有类似 tf.keras.Model 的模块？
def encoder_net(input_ids=None, attention_mask=None, token_type_ids=None):
    """Build the PyTorch counterpart of the Keras ALBERT encoder network.

    PyTorch has no functional ``tf.keras.Model(inputs=..., outputs=...)``
    wrapper. A network is any ``torch.nn.Module``, and its inputs are supplied
    when the module is *called*, not when it is built. ``AlbertModel`` is
    already an ``nn.Module``, so the pretrained encoder itself plays the role
    of ``encoder_network`` and is returned as-is.

    Args:
        input_ids: Ignored. It is kept, now optional, only so that existing
            calls of the form ``encoder_net(ids, mask, types)`` keep working.
            In PyTorch the tensors are passed when the returned module is
            called, e.g. ``encoder_net()(input_ids=ids,
            attention_mask=mask, token_type_ids=types)``.
        attention_mask: Ignored; see ``input_ids``.
        token_type_ids: Ignored; see ``input_ids``.

    Returns:
        The pretrained ``'albert-base-v2'`` encoder module. Calling it
        returns the model outputs, the equivalent of ``embeddings`` in the
        Keras version.
    """
    # The draft called ``encoder([input_ids, attention_mask, token_type_ids])``.
    # The PyTorch forward takes separate tensors, not a list, so the forward
    # call is left to the caller, who should use keyword arguments.
    del input_ids, attention_mask, token_type_ids  # unused; see docstring
    return AlbertModel.from_pretrained('albert-base-v2')