预测器预测功能,可使用保存的BERT模型对文本进行分类

时间:2019-07-03 10:13:24

标签: python tensorflow deep-learning text-classification

我创建了一个BERT模型,用于将用户生成的文本字符串分类为FAQ或不是FAQ。我已经使用export_savedmodel()函数保存了模型。我希望编写一个函数来预测一组新字符串的输出,该函数将字符串列表作为输入。

我尝试使用预报器.from_saved_model()方法,但是该方法要求传递键值对来输入ID,段ID,标签ID和输入掩码。 我是一个初学者,我不完全明白该怎么做。

导出或保存模型

def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

export_dir = "..."
estimator._export_to_tpu = False
estimator.export_savedmodel(export_dir, serving_input_fn)

#Predicting
with tf.Session() as sess:   
    predict_fn = predictor.from_saved_model(...')

#Data description
My data is a simple table having a column for input string and another for output label.

# Error.
ValueError: Got unexpected keys in input_dict: {'pred'}
expected: {'label_ids', 'input_mask', 'segment_ids', 'input_ids'}

#Thank you for any help!

0 个答案:

没有答案