Here is part of my code:
```python
model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=FLAGS.init_checkpoint,
    layer_indexes=layer_indexes,
    use_tpu=FLAGS.use_tpu,
    use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
estimator = Estimator(
    model_fn=model_fn,
)
for id in xls_dict:
    features = convert_examples_to_features(
        examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature
    input_fn = input_fn_builder(
        features=features, seq_length=FLAGS.max_seq_length)
    embs = []
    # Every call to estimator.predict() here builds a new graph,
    # calls model_fn, and restores the checkpoint again.
    for result in estimator.predict(input_fn, yield_single_examples=True):
        ...  # loop body not included in the original post
```
Every prediction made with `estimator.predict` restores the model again, which is inefficient. `model_fn` is also called again on every `estimator.predict` call. How can I avoid this?
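One common way around this, sketched here under assumptions, is to stop calling `estimator.predict` inside the loop: build the features for every `id` in `xls_dict` first, run a single `predict` pass (so the graph is built, `model_fn` is called, and the checkpoint is restored only once), then route each result back to its `id` through `unique_id`. The sketch is plain Python so it runs without TensorFlow; `predict_all_at_once`, `build_features`, and the `fake_*` stand-ins are hypothetical names. With the real BERT code, `build_features` would presumably create `InputExample` objects with the given `unique_id` and pass them to `convert_examples_to_features`, and `predict` would wrap `estimator.predict(input_fn_builder(features=..., seq_length=FLAGS.max_seq_length), yield_single_examples=True)`. The `unique_id` values must be unique across all ids, which is why the sketch numbers them globally instead of restarting at 0 for each `id`.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple


def predict_all_at_once(
    texts_by_key: Dict[str, List[str]],
    build_features: Callable[[List[Tuple[int, str]]], list],
    predict: Callable[[list], Iterable[dict]],
) -> Dict[str, List[dict]]:
    """Run ONE predict pass over every key's texts and regroup results by key."""
    tagged: List[Tuple[int, str]] = []  # (globally unique id, text)
    owner: Dict[int, str] = {}          # unique id -> key it came from
    for key, texts in texts_by_key.items():
        for text in texts:
            uid = len(tagged)           # 0, 1, 2, ... across ALL keys
            tagged.append((uid, text))
            owner[uid] = key

    features = build_features(tagged)   # hypothetical wrapper around feature conversion
    grouped: Dict[str, List[dict]] = defaultdict(list)
    # A single predict() call: with a real Estimator the graph is built and
    # the checkpoint restored once, no matter how many keys there are.
    for result in predict(features):
        grouped[owner[int(result["unique_id"])]].append(result)
    return dict(grouped)


# --- Hypothetical stand-ins so the sketch runs without TensorFlow ---
predict_calls: List[int] = []


def fake_build_features(tagged):
    return [{"unique_id": uid, "text": text} for uid, text in tagged]


def fake_predict(features):
    predict_calls.append(len(features))  # one entry per simulated model restore
    for f in features:
        yield {"unique_id": f["unique_id"], "emb": len(f["text"])}


embs_by_id = predict_all_at_once(
    {"row1": ["a", "bb"], "row2": ["ccc"]}, fake_build_features, fake_predict)
print(predict_calls)  # [3]: one predict pass covering all three texts
```

If the inputs cannot all be known up front (for example, they arrive one request at a time), the usual alternatives are to keep a single `predict` generator alive and feed it from a Python generator through `tf.data.Dataset.from_generator`, or to export the model as a SavedModel and load it once; both keep the restore cost to a single occurrence.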