我正在尝试添加“键”以匹配Google AI平台的批量预测输出,但是我的模型输入仅允许一个输入。
看起来像这样:
input = tf.keras.layers.Input(shape=(max_len,))
x = tf.keras.layers.Embedding(max_words, embed_size, weights=[embedding_matrix], trainable=False)(input)
x = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128, return_sequences=True, dropout=0.3, recurrent_dropout=0.1))(x)
x = tf.keras.layers.Conv1D(128, kernel_size=3, padding="valid", kernel_initializer="glorot_uniform")(x)
avg_pool = tf.keras.layers.GlobalAveragePooling1D()(x)
max_pool = tf.keras.layers.GlobalMaxPooling1D()(x)
x = tf.keras.layers.concatenate([avg_pool, max_pool])
preds = tf.keras.layers.Dense(2, activation="sigmoid")(x)
model = tf.keras.Model(input, preds)
model.summary()
model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(lr=1e-3), metrics=['accuracy','binary_crossentropy'])
我遇到了this article,但不知道如何将其应用于我的代码。
有什么想法吗?谢谢!
答案 0 :(得分:3)
在执行代码之后,您可以执行以下操作:
首先,从输入中获取键值:
input = tf.keras.layers.Input(shape=(max_len,))
key_raw = tf.keras.layers.Input(shape=(), name='key')
Reshape,以便以后用于连接
key = tf.keras.layers.Reshape((1,), input_shape=())(key_raw)
Concatenate键和最终结果
preds = tf.keras.layers.Dense(2, activation="sigmoid")(x)
preds = tf.keras.layers.concatenate([preds, key])
将其添加到模型的输入中
model = tf.keras.Model([input, key_raw], preds)
输入json文件示例:
{"input_1": [1.2,1.1,3.3,4.3], "key":1}
{"input_1": [0.3, 0.4, 1.5, 1], "key":2}
现在,您可以获取密钥作为预测结果的最后一个元素。 输出示例:
[0.48686566948890686, 0.5113844275474548, 1.0]
[0.505149781703949, 0.5156428813934326, 2.0]
答案 1 :(得分:0)
或者,您可以更改序列化为SavedModel的服务功能(或第二个服务功能)。这很方便,因为您可以拥有一个服务基础架构(即TFServing,Google Cloud AI Platform在线/批处理)同时为有关键和无关键的预测提供服务。另外,当您无权访问生成键的基础keras代码时,可以将键添加到SavedModel。
tf.saved_model.save(model, MODEL_EXPORT_PATH)
loaded_model = tf.keras.models.load_model(MODEL_EXPORT_PATH)
@tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.string),tf.TensorSpec([None, 28, 28], dtype=tf.float32)])
def keyed_prediction(key, image):
pred = loaded_model(image, training=False)
return {
'preds': pred,
'key': key
}
loaded_model.save(KEYED_EXPORT_PATH, signatures={'serving_default': keyed_prediction})