在BERT的SavedModel中包含令牌生成器

时间:2019-10-12 05:03:40

标签: tensorflow tensorflow-serving tensorflow-estimator

我正在尝试将经过微调的BERT模型导出到SavedModel中。我发现了这样的东西

def serving_input_fn():
    reciever_tensors = {
        "input_ids": tf.placeholder(dtype=tf.int32,
                                    shape=[1, MAX_SEQ_LENGTH])
    }
    features = {
        "input_ids": reciever_tensors['input_ids'],
        "input_mask": 1 - tf.cast(tf.equal(reciever_tensors['input_ids'], 0), dtype=tf.int32),
        "segment_ids": tf.zeros(dtype=tf.int32, shape=[1, MAX_SEQ_LENGTH]),
        "label_ids": tf.placeholder(tf.int32, [None], name='label_ids')
    }
    return tf.estimator.export.ServingInputReceiver(features, reciever_tensors)

estimator._export_to_tpu = False
estimator.export_saved_model("export", serving_input_fn)

但是我更喜欢直接输入一个字符串(例如通用句子编码器模型)而不是编码。有没有办法使标记化成为SavedModel的一部分?

谢谢!

0 个答案:

没有答案