I came across this page.
1) After fine-tuning, I want to get sentence-level embeddings (the embedding produced for the [CLS] token). How can I do that?
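2) I also noticed that the code on that page takes a very long time to return results on the test data. Why is that? Training the model takes less time than getting predictions on the test set. Of the code on that page, I did not use the following block: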
test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None,
                                                                           text_a=x[DATA_COLUMN],
                                                                           text_b=None,
                                                                           label=x[LABEL_COLUMN]), axis=1)
test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)
test_input_fn = run_classifier.input_fn_builder(
    features=test_features,
    seq_length=MAX_SEQ_LENGTH,
    is_training=False,
    drop_remainder=False)
estimator.evaluate(input_fn=test_input_fn, steps=None)
Instead, I just ran the following function on the entire test set:
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # guid="" and label=0 are placeholders: the true label is unknown at prediction time
    input_examples = [run_classifier.InputExample(guid="", text_a=x, text_b=None, label=0)
                      for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(
        input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(
        features=input_features, seq_length=MAX_SEQ_LENGTH,
        is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    return [(sentence, prediction['probabilities'], labels[prediction['labels']])
            for sentence, prediction in zip(in_sentences, predictions)]
3) How do I get the prediction probabilities? Is there a way to use the Keras predict method?
Update to question 2:
Could you test the getPrediction function on 20,000 training examples? For me it takes much longer, even longer than training the model on those 20,000 examples.
Answer 0 (score: 3)
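The output dictionary contains:
pooled_output: the pooled output of the entire sequence, with shape [batch_size, hidden_size].
sequence_output: the representation of every token in the input sequence, with shape [batch_size, max_sequence_length, hidden_size].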
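I added the pooled_output vector, which corresponds to the [CLS] vector (strictly speaking, it is the final hidden state of [CLS] passed through BERT's dense + tanh pooler layer).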
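To make the relationship between the two outputs concrete, here is a NumPy sketch of the shapes involved. It assumes BERT-Base sizes (hidden_size = 768) and uses random arrays: `sequence_output` stands in for the module's token-level output, and `W_pool`/`b_pool` are hypothetical placeholders, not the trained pooler weights. It only illustrates that [CLS] sits at position 0 and that the pooler is a dense layer with tanh on top of it.

```python
import numpy as np

# Assumed sizes for illustration (BERT-Base uses hidden_size = 768).
batch_size, max_seq_length, hidden_size = 2, 128, 768

rng = np.random.default_rng(0)
# Random stand-in for bert_outputs["sequence_output"]: one vector per token.
sequence_output = rng.standard_normal(
    (batch_size, max_seq_length, hidden_size)).astype(np.float32)

# [CLS] is always the first token, so its raw final hidden state is position 0.
cls_hidden = sequence_output[:, 0, :]  # (batch_size, hidden_size)

# BERT's pooler: dense layer + tanh on the [CLS] state.
# W_pool / b_pool are hypothetical random placeholders, NOT trained weights.
W_pool = (rng.standard_normal((hidden_size, hidden_size)) * 0.02).astype(np.float32)
b_pool = np.zeros(hidden_size, dtype=np.float32)
pooled_output = np.tanh(cls_hidden @ W_pool + b_pool)  # (batch_size, hidden_size)

print(sequence_output.shape, pooled_output.shape)
```

The tanh is why every value in the pooled_output arrays printed later in this thread lies in [-1, 1].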
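3) You are getting log probabilities. Just apply softmax to them (or, equivalently, exponentiate them) to get regular probabilities.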
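A minimal NumPy sketch of that conversion, using the log-probability values from the example output in this answer. Because log_softmax output is already normalized, `softmax(log_probs)` and `exp(log_probs)` agree; the softmax form is shown with the usual max-subtraction for numerical stability.

```python
import numpy as np

def probs_from_log_probs(log_probs):
    """Convert log-probabilities (log_softmax output) back to probabilities."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    # Softmax: subtract the row max for stability, exponentiate, normalize.
    shifted = log_probs - log_probs.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

# Log probabilities for 'That movie was absolutely awful' from the output below.
log_probs = np.array([-4.0148855e-03, -5.5197663e+00])

print(probs_from_log_probs(log_probs))
print(np.exp(log_probs))  # same result, since log_softmax is already normalized
```

Both lines should reproduce the probability array reported for that sentence.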
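Now all that is left is to make the model report them. I have kept the log probabilities, although they are no longer needed.
See the code changes: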
def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
    """Creates a classification model."""
    bert_module = hub.Module(
        BERT_MODEL_HUB,
        trainable=True)
    bert_inputs = dict(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids)
    bert_outputs = bert_module(
        inputs=bert_inputs,
        signature="tokens",
        as_dict=True)

    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
    output_layer = bert_outputs["pooled_output"]
    # Keep the [CLS]-based sentence embedding so it can be returned as well.
    pooled_output = output_layer

    hidden_size = output_layer.shape[-1].value

    # Create our own layer to tune for politeness data.
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        # Dropout helps prevent overfitting.
        # Note: as written, this dropout is also applied in PREDICT mode.
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        probs = tf.nn.softmax(logits, axis=-1)

        # Convert labels into one-hot encoding
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

        predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
        # If we're predicting, we want predicted labels, the probabilities and the embedding.
        if is_predicting:
            return (predicted_labels, log_probs, probs, pooled_output)

        # If we're train/eval, compute loss between predicted and actual label
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, predicted_labels, log_probs, probs, pooled_output)
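For readers who want to check what the classification head in create_model computes, here is a NumPy mirror of it. This is a sketch, not the TensorFlow code: dropout is omitted, and the inputs in the usage part are random dummy arrays with hypothetical sizes (4 sentences, 768 dims, 2 labels). It follows the same steps: logits = x·Wᵀ + b, log_softmax, argmax, then mean cross-entropy against one-hot labels.

```python
import numpy as np

def log_softmax(x):
    """Numerically stable log_softmax along the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def classification_head(pooled_output, output_weights, output_bias, labels=None):
    """NumPy mirror of the head in create_model (dropout omitted)."""
    # output_weights has shape [num_labels, hidden_size], hence the transpose.
    logits = pooled_output @ output_weights.T + output_bias
    log_probs = log_softmax(logits)
    probs = np.exp(log_probs)
    predicted_labels = log_probs.argmax(axis=-1)
    if labels is None:  # prediction: no loss
        return predicted_labels, log_probs, probs, pooled_output
    one_hot = np.eye(output_weights.shape[0])[labels]
    per_example_loss = -(one_hot * log_probs).sum(axis=-1)
    loss = per_example_loss.mean()
    return loss, predicted_labels, log_probs, probs, pooled_output

# Dummy inputs, for illustration only.
rng = np.random.default_rng(0)
pooled = np.tanh(rng.standard_normal((4, 768)))
W = rng.standard_normal((2, 768)) * 0.02
b = np.zeros(2)
labels = np.array([0, 1, 1, 0])
loss, pred, log_probs, probs, emb = classification_head(pooled, W, b, labels)
print(loss, pred)
```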
Now add support for these values in model_fn_builder():
# Change the unpacking in both places: the PREDICT branch gets these 4 values,
# the TRAIN/EVAL branch gets 5 (with loss first).
(predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

# Return a dictionary with all the values you want
predictions = {
    'log_probabilities': log_probs,
    'probabilities': probs,
    'labels': predicted_labels,
    'pooled_output': pooled_output
}
Adjust getPrediction() accordingly, and your predictions will end up looking like this:
('That movie was absolutely awful',
array([0.99599314, 0.00400678], dtype=float32), <= Probability
array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
'Negative', <= Label
array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
-0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
...
And this is the 768-d [CLS] vector (sentence embedding).
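Once getPrediction returns tuples in the order shown above (sentence, probabilities, log probabilities, label, pooled_output), the sentence embeddings can be collected into a single matrix. A minimal sketch, assuming exactly that tuple order; `split_predictions` is a hypothetical helper name, and the usage below uses dummy tuples rather than real model output.

```python
import numpy as np

def split_predictions(results):
    """Split (sentence, probs, log_probs, label, pooled_output) tuples
    into parallel lists/arrays. Assumes the tuple order shown above."""
    sentences = [r[0] for r in results]
    probs = np.stack([r[1] for r in results])       # (N, num_labels)
    labels = [r[3] for r in results]
    embeddings = np.stack([r[4] for r in results])  # (N, hidden_size), e.g. (N, 768)
    return sentences, probs, labels, embeddings

# Dummy tuples, for illustration only.
dummy = [
    ("s1", np.array([0.9, 0.1]), np.log([0.9, 0.1]), "Negative", np.zeros(768, dtype=np.float32)),
    ("s2", np.array([0.2, 0.8]), np.log([0.2, 0.8]), "Positive", np.ones(768, dtype=np.float32)),
]
sentences, probs, labels, embeddings = split_predictions(dummy)
print(embeddings.shape)
```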
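Regarding 2): in my run, training took about 5 minutes and testing about 40 seconds. Very reasonable.
Update
For 20k samples, training took 12:48 and testing 2:07 (min:s).
For 10k samples, the timings were 8:40 and 1:07 respectively.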
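Turning those reported timings into ratios makes the point explicit: in these runs, testing takes a small fraction of the training time. This sketch only does arithmetic on the numbers above; the samples-per-second figure assumes the stated sample count also applies to the test set, which the answer does not spell out.

```python
def to_seconds(mm_ss):
    """Convert an 'M:SS' duration string to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)

# (samples, training time, test time) as reported in the answer.
REPORTED = [(20000, "12:48", "2:07"), (10000, "8:40", "1:07")]

def summarize(n, train, test):
    t_train, t_test = to_seconds(train), to_seconds(test)
    return {
        "samples": n,
        "train_s": t_train,
        "test_s": t_test,
        "test_to_train": t_test / t_train,
        # Assumption: n test examples were scored in t_test seconds.
        "test_samples_per_s": n / t_test,
    }

for row in REPORTED:
    s = summarize(*row)
    print(f"{s['samples']}: train {s['train_s']}s, test {s['test_s']}s, "
          f"test/train {s['test_to_train']:.2f}, ~{s['test_samples_per_s']:.0f} samples/s")
```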
Answer 1 (score: 3)
For completeness, here are the remaining changes:
# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
    """Returns `model_fn` closure for TPUEstimator."""
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

        # TRAIN and EVAL
        if not is_predicting:
            (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

            train_op = bert.optimization.create_optimizer(
                loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

            # Calculate evaluation metrics.
            def metric_fn(label_ids, predicted_labels):
                accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
                f1_score = tf.contrib.metrics.f1_score(
                    label_ids,
                    predicted_labels)
                auc = tf.metrics.auc(
                    label_ids,
                    predicted_labels)
                recall = tf.metrics.recall(
                    label_ids,
                    predicted_labels)
                precision = tf.metrics.precision(
                    label_ids,
                    predicted_labels)
                true_pos = tf.metrics.true_positives(
                    label_ids,
                    predicted_labels)
                true_neg = tf.metrics.true_negatives(
                    label_ids,
                    predicted_labels)
                false_pos = tf.metrics.false_positives(
                    label_ids,
                    predicted_labels)
                false_neg = tf.metrics.false_negatives(
                    label_ids,
                    predicted_labels)
                return {
                    "eval_accuracy": accuracy,
                    "f1_score": f1_score,
                    "auc": auc,
                    "precision": precision,
                    "recall": recall,
                    "true_positives": true_pos,
                    "true_negatives": true_neg,
                    "false_positives": false_pos,
                    "false_negatives": false_neg
                }

            eval_metrics = metric_fn(label_ids, predicted_labels)

            if mode == tf.estimator.ModeKeys.TRAIN:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op)
            else:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  eval_metric_ops=eval_metrics)
        else:
            (predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

            predictions = {
                'log_probabilities': log_probs,
                'probabilities': probs,
                'labels': predicted_labels,
                'pooled_output': pooled_output
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Return the actual model function in the closure
    return model_fn
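The evaluation metrics reported by metric_fn are all derived from the four confusion counts it also returns. Below is a plain-Python sketch of the standard formulas, for checking the reported numbers by hand; the counts in the usage line are made up for illustration. Note that metric_fn passes hard predicted labels (not probabilities) to tf.metrics.auc, so the reported AUC reflects a single decision threshold rather than the full ranking of scores.

```python
def classification_metrics(tp, tn, fp, fn):
    """Standard binary metrics from confusion-matrix counts."""
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1_score": f1}

# Hypothetical counts, for illustration only.
print(classification_metrics(tp=40, tn=45, fp=5, fn=10))
```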
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # guid="" and label=0 are placeholders: the true label is unknown at prediction time
    input_examples = [run_classifier.InputExample(guid="", text_a=x, text_b=None, label=0)
                      for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(
        input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(
        features=input_features, seq_length=MAX_SEQ_LENGTH,
        is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    return [(sentence,
             prediction['probabilities'],
             prediction['log_probabilities'],
             labels[prediction['labels']],
             prediction['pooled_output'])
            for sentence, prediction in zip(in_sentences, predictions)]
And the first output (the remaining outputs were cut to keep the answer under the 30K-character limit):
[('That movie was absolutely awful',
array([0.99599314, 0.00400678], dtype=float32),
array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
'Negative',
array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
-0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
-0.9999866 , 0.5820049 , 0.3257555 , -0.81900954, -0.8326617 ,
0.87788117, -0.7791749 , 0.11098853, 0.67873836, 0.9999771 ,
0.9833652 , -0.8420576 , 0.83076835, 0.37272754, 0.8667175 ,
0.792386 , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
0.55752313, 1. , -0.72632766, -0.4752956 , -0.9999852 ,
-0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
0.92974335, -0.8601105 , -0.8113003 , 0.7660112 , 0.9313508 ,
0.21427669, -0.45660907, 0.99970686, 0.56852764, -0.9997675 ,
-0.9999096 , 0.8247045 , 0.7205424 , 0.47192624, -0.7523966 ,
-0.9588541 , -0.48866934, 0.9809366 , -0.07110611, -0.99886 ,
-0.63922834, -0.68144 , -1. , 0.8531816 , 0.26078308,
-0.99898577, -0.99968046, 0.6711601 , 0.99857473, -0.99990964,
1. , -0.97127694, -0.10644457, 0.46306637, -0.32486317,
-0.68167734, 0.43291137, -0.996574 , 0.05164305, 0.9897354 ,
0.93853104, 0.94800174, 0.9995697 , 0.6532897 , 0.93846226,
-0.6281378 , 0.5574107 , 0.725278 , 0.74160355, -0.6486919 ,
0.88869256, 0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
0.17409436, 0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
0.5546853 , 0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
0.94367695, 0.6834396 , -0.72266734, 0.99376386, -0.76821744,
0.4485644 , 0.99982166, 1. , 0.9260674 , 0.9759094 ,
0.9397613 , 0.8128903 , -0.7918152 , 0.30299878, -0.95160294,
0.25385544, -0.57780135, -0.9999994 , 0.9168113 , -0.36585295,
0.9798102 , 0.95976156, -0.99428 , 0.6471789 , -0.9948078 ,
-0.9686591 , 0.93615085, -0.11481134, 0.87566274, -0.91601896,
0.9952683 , 0.26532048, 0.99861896, 0.79298306, 0.5872364 ,
-0.56314534, 0.96794534, 0.9999797 , 0.9879324 , 0.5003342 ,
0.9516269 , -0.8878316 , -0.9665091 , -0.88037425, 0.8356687 ,
-0.71543014, -0.99985015, -0.9414574 , 0.8681497 , 0.950698 ,
-0.8007153 , 0.78748596, 0.9999305 , 0.40210736, 0.4856055 ,
-0.9390776 , 0.63564163, -0.85989815, -0.8421344 , -0.99436 ,
0.78081733, -0.97038007, 0.39290914, 0.7834218 , 0.88715357,
-0.03653741, 0.99126273, -0.96559966, 0.11924513, -0.99363935,
-0.9901692 , 0.963858 , 0.5713922 , 0.5676979 , 0.69982123,
0.858003 , 0.9983819 , -0.87965024, 0.46213093, -0.3256273 ,
0.77337253, 0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
-0.93148243, 0.09674019, 0.09448949, -0.7453027 , -0.78955775,
-0.6304773 , -0.5597632 , 0.992308 , 0.7769483 , 0.04146893,
-0.15876745, -0.7682887 , -0.5231416 , 0.7871302 , 0.9503481 ,
-0.9607153 , 0.99047405, -0.9948017 , -0.82257754, 0.9990552 ,
0.79346406, -0.78624016, 0.8760266 , -0.7855991 , 0.13444276,
-0.7183107 , -0.9999819 , 0.7019429 , -0.918913 , -0.6569654 ,
0.9998794 , -0.33805153, -0.9427715 , 0.10419375, -0.94257164,
0.9187495 , -0.9994855 , -0.99979955, -0.9277688 , 0.6353426 ,
0.9994905 , 0.90688777, 0.9992008 , 0.7817533 , -0.9996674 ,
-0.999962 , -0.13310781, -0.82505953, 0.9997485 , 0.82616794,
-0.999998 , 0.45386457, 0.6069964 , 0.52272975, 0.8811922 ,
0.52668494, -0.9994814 , -0.21601789, -0.99882716, 0.90246916,
0.94196504, 0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
0.7727592 , 0.9936947 , 0.98021245, -0.77723926, -0.785372 ,
0.5150317 , 0.9983137 , -0.7461883 , 0.3311537 , -0.63709795,
-0.6487831 , -0.9173727 , 0.9997706 , -0.9999893 , -1. ,
0.60389155, -0.6516268 , -0.95422006, 1. , 0.09109057,
-0.99999994, 0.99998957, 1. , -0.19451752, 0.94624877,
-0.2761865 , 1. , 0.52399474, 0.70230734, 0.5218801 ,
-0.99716544, -0.70075685, -0.99992603, 1. , -0.9785006 ,
0.22457084, -0.5356722 , -0.9991887 , 0.7062409 , 0.66816545,
-0.90308225, -0.8084922 , 0.50301254, -0.7062079 , 0.9998321 ,
0.9823206 , 0.9984027 , 0.9948857 , -1. , -0.7067878 ,
0.975454 , 0.87161005, -0.9882297 , 0.8296374 , -0.88615334,
0.4316883 , 0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
-0.84212875, 0.78632677, -0.5131366 , -0.996949 , -0.75479275,
-0.06342169, 0.92238575, 0.66769385, 0.9926053 , -0.78391105,
0.9976865 , 0.07086544, 0.34079495, 0.69730175, -0.99970955,
-1. , -0.9860551 , 0.89584446, -0.96889114, -0.90435815,
0.944296 , -1. , -0.9931756 , -0.7014334 , -0.6742562 ,
-0.96786517, 0.848328 , 0.8903087 , -0.9998633 , 0.73993397,
0.99345684, 0.9691821 , 0.87563246, -0.6073146 , -0.9999999 ,
0.90763575, 0.30225936, -0.47824544, 0.7179979 , 0.9450465 ,
0.9715953 , -0.5422173 , 0.99995065, -0.5920663 , 0.92390317,
-0.9670669 , -0.3623574 , 0.74825 , -0.7817521 , 0.9888685 ,
-0.7653631 , -0.8933355 , 0.9481424 , 0.97803396, -0.9999731 ,
-0.89597356, 0.35502487, -0.7190486 , 0.30777818, 0.55025375,
0.6365793 , -0.99094397, -1. , 0.93482614, -0.99970514,
0.98721176, 0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
0.57238674, 0.97475344, -0.9963499 , 0.98476464, 0.40495875,
-0.7001948 , -0.40898973, 0.61900675, -1. , -0.9371812 ,
-0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245 ,
-0.99991447, -0.5872595 , 0.5835767 , 0.7003338 , -0.9761974 ,
0.99995846, 0.33676207, 0.9079994 , -0.76412004, -0.7648706 ,
0.68863285, 0.43983305, 0.74911463, -0.99995685, -0.6692586 ,
-0.45761266, -0.9980771 , -1. , 0.31244457, -0.8834693 ,
0.9388263 , -0.987405 , 1. , 0.9512058 , 0.23448633,
0.37940192, 0.99989796, 0.8402514 , -0.84526414, 0.7378776 ,
-0.9996204 , -0.99434114, 0.9987527 , 0.5569713 , 0.99648696,
-0.9933159 , -0.13116199, 0.9999992 , 0.9642579 , -0.48285434,
-0.97517425, 0.7185596 , 0.5286405 , 0.9902838 , 0.7796022 ,
-0.80703837, 0.2376029 , 0.534117 , -0.9999413 , 0.99828076,
0.9998345 , 0.93249476, 0.3620626 , 0.7567034 , -0.9222681 ,
0.97832036, 0.9999682 , 0.6433209 , -1. , 0.9268615 ,
-0.9999511 , -0.9145363 , -0.9213852 , 0.7606066 , -0.5501025 ,
-0.99999434, -0.7783993 , 0.9999771 , 0.99980384, 0.987094 ,
0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
-0.9919206 , -0.49190572, 0.70230234, -0.31277484, -0.99999964,
0.828591 , 0.6363776 , 0.86796165, 0.81575817, 0.7782955 ,
0.9436437 , -1. , -0.7509046 , -0.9946139 , -0.6647415 ,
0.999543 , 0.9312092 , -1. , 0.5639159 , 0.9482462 ,
-0.9289936 , -0.9678435 , 0.60937124, -0.987818 , 0.5511619 ,
0.75886583, -0.48466644, -0.71833754, 0.8042149 , 0.9154103 ,
-0.8177468 , 0.7195895 , -0.82283056, 0.24990956, -1. ,
0.7729634 , 0.84048635, 0.7989596 , 0.9469012 , -0.9898951 ,
-0.92565274, 0.74726975, 0.78213847, -0.672894 , -0.58831286,
-0.8039038 , -0.72197783, 0.5289216 , -0.9998796 , -0.9904479 ,
0.9996592 , -0.28984115, 0.23964961, -0.7427149 , -0.662416 ,
-1. , -0.5538268 , -0.9945287 , -0.63471127, 0.5896127 ,
-0.48429146, 0.9976076 , -0.94329506, -0.49143887, 0.7695602 ,
0.8638134 , -0.82130384, 0.50105464, 0.9336961 , -0.24716294,
-0.6922282 , -0.02228704, 0.75649065, 0.82303154, -0.30867255,
-0.9602714 , 0.64568967, 0.314201 , -0.4811752 , 0.27952817,
0.9227022 , 0.88095886, 0.89470226, 1. , -0.19237158,
1. , -0.991253 , -0.9991121 , 0.5637482 , -0.75780976,
-0.3904836 , -0.9881965 , -0.2912058 , 0.9998215 , 0.9869475 ,
-0.12784953, 0.81566185, 0.9787118 , -0.17835459, -0.7027824 ,
0.72269535, -0.18194303, 0.9968796 , 0.03490257, 0.7751488 ,
-1. , -0.7761089 , 0.85105944, 0.9968074 , -0.8156342 ,
0.5300792 , -1. , 0.99626255, -0.7515625 , -0.6672005 ,
0.9792111 , 0.8660997 , -0.69161206, 0.32184905, 0.9071073 ,
0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
0.70707524, 0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
-0.94375473, -0.91838425, 0.64272994, 0.9375524 , 0.6609169 ,
-0.88743365, -0.9534722 , -0.47888806, -1. , -0.5251781 ,
0.8274516 , 0.9326824 , 0.8961964 , 0.5295862 , 0.43714878,
-0.7488347 , -0.75295556, -0.5187054 , 0.75924635, -0.7862662 ,
0.99981725, -0.80290836, 0.97651815, 0.99763787, -0.29619345,
-0.1252967 , 0.33606276, -0.65137684, -0.9680231 , 0.77586985,
0.22347753, 0.27245504, -0.07826214, -0.8383849 , -0.85373163,
1. , -0.4563588 , -0.91339815, -0.9999861 , 0.66063935,
-0.985843 , -0.7818757 , -0.7000497 , -0.6840764 , 0.9995542 ,
0.60819125, 0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
-0.8771755 , 0.71411085, 0.8113569 , 0.9974196 , -0.75211936,
0.63400257, -0.8272833 , 0.99780786, 0.9965285 , 0.59551436,
-0.9876875 , -0.04439292, 0.9939223 , 0.9993717 , -0.9965501 ,
-0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
-0.95355797, -0.67561924, 0.9997761 , -0.85473967, 0.998495 ,
-0.95756954, 0.633171 , 0.4570475 , -0.5316367 , -0.9663824 ,
0.9567106 , -0.45497724, 0.12964879, 0.9964744 , -0.9711668 ,
0.69636106, -0.9178346 , 0.8313186 , 0.69686604, 0.8141587 ,
-0.33600506, 0.94798595, 0.8800869 , 0.15029034, -0.91185665,
0.6322724 , -0.9971475 , 0.71948224, 0.9695236 , 0.84242374,
0.99995124, 0.5982563 , -0.98341423, 0.61301434, 0.9997318 ,
-0.9981808 , -0.65651804, -0.8484874 , -0.9961815 , 0.9030814 ,
0.87141925, 0.8897381 , -0.92870414, 0.07134341, 0.8739935 ,
0.91630197, -0.9465984 , -0.59741104, -1. , 0.9989559 ,
0.99991184, 0.67439264, 0.92025673, -0.60730827, 0.8362061 ,
1. , -0.70801497, 0.9883806 , -0.9984141 , 0.9919259 ,
-0.998869 , 0.9976203 , 0.9888036 , 0.8556838 , -0.9722744 ,
-0.99810714, 0.8182833 , 0.98808485, 0.6643728 , 0.99212515,
-0.99988 , 0.26405996, 0.93139845, 0.99021816, 0.6846886 ,
0.9986462 , 0.92254627, -0.6406982 ], dtype=float32)),
('The acting was a bit lacking',
array([0.9921152 , 0.00788479], dtype=float32),
array([-0.00791603, -4.842819 ], dtype=float32),
'Negative',
array([ 0.67417824, 0.8235167 , 0.99999565, -0.8565971 , -0.99499583,
0.8219966 , -0.9185583 , -0.5234593 , 0.99962074, 0.99999714,
0.9507927 , -0.9996754 , 0.22211392, -0.99826247, 0.7562492 ,
0.93803996, 0.82738185, 0.4773049 , -0.73478544, 0.85207295,