如何将通用句子编码器3的嵌入微调到自己的语料库

时间:2019-08-05 15:11:49

标签: tensorflow machine-learning nlp transfer-learning tensorflow-hub

我希望将Google通用句子编码器大3(https://tfhub.dev/google/universal-sentence-encoder-large/3)产生的嵌入微调到我自己的语料库。任何有关如何执行此操作的建议将不胜感激。我目前的想法是将句子对从我的语料库馈送到编码器,然后使用额外的一层对它们在语义上是否相同进行分类。我的麻烦是我不确定如何设置此设置,因为这需要设置两个具有权重的USE模型,我认为它被称为暹罗网络。对此方法的任何帮助将不胜感激

def train_and_evaluate_with_module(hub_module, train_module=False):
    embedded_text_feature_column1 = hub.text_embedding_column(
      key="sentence1", module_spec=hub_module, trainable=train_module)

    embedded_text_feature_column2 = hub.text_embedding_column(
      key="sentence2", module_spec=hub_module, trainable=train_module)


    estimator = tf.estimator.DNNClassifier(
      hidden_units=[500, 100],
      feature_columns=[embedded_text_feature_column1,embedded_text_feature_column2],
      n_classes=2,
      optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))

    estimator.train(input_fn=train_input_fn, steps=1000)

    train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
    test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)

    training_set_accuracy = train_eval_result["accuracy"]
    test_set_accuracy = test_eval_result["accuracy"]

    return {
      "Training accuracy": training_set_accuracy,
      "Test accuracy": test_set_accuracy
    }

1 个答案:

答案 0 :(得分:1)

请参见https://github.com/tensorflow/hub/issues/134:初始化一个hub.Module(..., trainable=True)对象并调用两次。