TF Hub通用语句编码器句子相似度的微调

时间:2020-07-29 13:47:35

标签: python tensorflow keras tf-hub

我正在从tf集线器微调USE v4模型。 使用的数据集是带有目标标签[0,1]的句子对。

以下是我的代码,

model = tf.keras.models.Sequential()
model.add(hub.KerasLayer('https://tfhub.dev/google/universal-sentence-encoder/4', 
                        input_shape=[2,], 
                        dtype=tf.string, 
                        trainable=True))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.summary()

导致错误

ValueError: Shape must be rank 1 but is rank 2 for '{{node text_preprocessor/tokenize/StringSplit/StringSplit}} = StringSplit[skip_empty=true](text_preprocessor/StaticRegexReplace_1, text_preprocessor/tokenize/StringSplit/Const)' with input shapes: [?,2], [].

如果有人可以帮助我了解我哪里出了问题,那将是很好的。

1 个答案:

答案 0 :(得分:0)

如@qmeeus所述,input_shape需要为[],否则您可以一起跳过指定input_shape。因此,如下所示:

use_url = "https://tfhub.dev/google/universal-sentence-encoder-large/4"


feature_extractor_layer = hub.KerasLayer(use_url, input_shape=[], trainable=True)
model = tf.keras.Sequential([
    feature_extractor_layer,
    layers.Dense(1, activation='sigmoid')
])

github issue可能会有所帮助。

为了传递一对句子,您可以在暹罗网络中重用feature_extractor_layer。