Embedding a preprocessing function in a tf.keras model

Date: 2020-04-12 15:46:34

Tags: python tensorflow keras

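I am trying to embed a simple image preprocessing function inside an already trained tf.keras model. This is a useful feature because it reduces the boilerplate code that serving a model usually requires. It also makes the model more flexible and modular.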

So, after training the model, I first define a preprocessing function:

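import cv2
import numpy as np
import tensorflow as tf

def preprocess_image_cv2(image_path):
    img = cv2.imread(image_path)                   # BGR uint8 array, shape (H, W, 3)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)    # grayscale, shape (H, W)
    img = cv2.resize(img, (28, 28)).astype("float32")
    img = img / 255                                # scale pixels to [0, 1]
    img = np.expand_dims(img, 0)                   # add batch axis -> (1, 28, 28)
    img = tf.convert_to_tensor(img)
    return img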
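The cv2 calls above operate on NumPy arrays in Python, so they only work when `image_path` is an ordinary Python string. Inside a `tf.function` traced with a fixed input signature, the path arrives as a symbolic `tf.string` tensor instead. Below is a minimal sketch of a graph-compatible equivalent built only from TensorFlow ops. It assumes PNG input and a model that expects `(1, 28, 28)` float32 batches, and `preprocess_image_tf` is a hypothetical name. The grayscale weights and bilinear resize are close to what OpenCV does, but the pixel values are not guaranteed to match exactly.

```python
import tensorflow as tf


def preprocess_image_tf(image_path):
    """Hypothetical graph-compatible counterpart of preprocess_image_cv2 (assumes PNG)."""
    raw = tf.io.read_file(image_path)          # raw file bytes as a tf.string scalar
    img = tf.io.decode_png(raw, channels=3)    # uint8, shape (H, W, 3), RGB order
    img = tf.image.rgb_to_grayscale(img)       # uint8, shape (H, W, 1)
    img = tf.image.resize(img, (28, 28))       # float32, shape (28, 28, 1), bilinear
    img = img / 255.0                          # scale pixels to [0, 1]
    img = tf.squeeze(img, axis=-1)             # (28, 28), matching the cv2 output
    return tf.expand_dims(img, 0)              # (1, 28, 28), add the batch axis


# Every step is a TF op, so the function can be traced with a string-tensor input.
traced = tf.function(preprocess_image_tf,
                     input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
```

The cv2 version cannot be traced this way, because `cv2.imread` expects a real file path, not a tensor.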

I then use it together with the trained model to create another model class:

# Define the model for prediction purposes
class ExportModel(tf.keras.Model):
    def __init__(self, preproc_func, model):
        super().__init__()
        self.preproc_func = preproc_func
        self.model = model

    @tf.function
    def my_serve(self, image_path):
        print("Inside")
        preprocessed_image = self.preproc_func(image_path) # Preprocessing
        probabilities = self.model(preprocessed_image, training=False) # Model prediction
        class_id = tf.argmax(probabilities[0], axis=-1) # Postprocessing
        return {"class_index": class_id}
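The error message names two remedies. One is to give `my_serve` an `input_signature`, so `tf.function` knows the exact input type up front. Below is a minimal sketch of that option. It assumes the preprocessing is rewritten in TensorFlow ops (the cv2 version cannot accept a symbolic tensor), and it uses a hypothetical toy `Sequential` in place of the trained classifier. It also swaps `tf.keras.Model` for `tf.Module` as the wrapper's base class. That is a design choice of this sketch, meant to keep the export independent of how Keras serializes a subclassed model that was never built, not a confirmed requirement.

```python
import tensorflow as tf


def preprocess_image_tf(image_path):
    """Hypothetical TF-only preprocessing (assumes PNG input), traceable in a graph."""
    img = tf.io.decode_png(tf.io.read_file(image_path), channels=3)
    img = tf.image.rgb_to_grayscale(img)                 # (H, W, 1)
    img = tf.image.resize(img, (28, 28)) / 255.0         # (28, 28, 1), float32 in [0, 1]
    return tf.expand_dims(tf.squeeze(img, axis=-1), 0)   # (1, 28, 28)


class ExportModel(tf.Module):
    def __init__(self, preproc_func, model):
        super().__init__()
        self.preproc_func = preproc_func
        self.model = model  # tracked, so its weights are saved with the module

    # Fixed input signature: a single scalar string holding the image path.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def my_serve(self, image_path):
        preprocessed_image = self.preproc_func(image_path)              # Preprocessing
        probabilities = self.model(preprocessed_image, training=False)  # Model prediction
        class_id = tf.argmax(probabilities[0], axis=-1)                 # Postprocessing
        return {"class_index": class_id}


# Hypothetical stand-in for the trained 28x28, 10-class classifier.
toy_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

serving_model = ExportModel(preprocess_image_tf, toy_model)
```

With a signature in place, `tf.saved_model.save(serving_model, export_path, signatures={"serving_default": serving_model.my_serve})` should accept `my_serve`. The loaded signature would then be called as `loaded.signatures["serving_default"](image_path=tf.constant(path))`.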

I can then run inference on a sample image with the following setup:

# Now initialize a dummy model and fill its parameters with those of
# the model we trained
restored_model = get_training_model()
restored_model.set_weights(apparel_model.get_weights())

# Now use this model, preprocessing function, and the same image
# for checking if everything is working
serving_model = ExportModel(preprocess_image_cv2, restored_model)
class_index = serving_model.my_serve("sample_image.png")
CLASSES[class_index["class_index"].numpy()] # prints Dress
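A likely reason inference works while export fails: when `my_serve("sample_image.png")` is called with a plain Python string, `tf.function` treats the string as a constant. The Python body, including `cv2.imread` and `print("Inside")`, runs only while tracing, with the real path in hand. A serving signature instead has to receive the path as a `tf.string` tensor. The small sketch below uses a hypothetical `inspect_arg` function to record what the traced body actually receives in each case.

```python
import tensorflow as tf

seen_types = []  # records the type of the argument each trace receives


@tf.function
def inspect_arg(image_path):
    # Python side effects such as this append (or a print) run only while tracing.
    seen_types.append(type(image_path))
    return tf.constant(0)


inspect_arg("sample_image.png")               # Python str: baked into the trace as a constant
inspect_arg(tf.constant("sample_image.png"))  # tf.string tensor: traced as a symbolic placeholder
print(seen_types)
```

In the second case `image_path` is a symbolic graph tensor, which `cv2.imread` cannot open as a file path.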

However, I cannot export this model for serving. I export it as follows:

# Make sure we are *not* letting the model train
tf.keras.backend.set_learning_phase(0)

# Serialize model
export_path = "model_preprocessing_func"
tf.saved_model.save(serving_model, export_path, signatures={"serving_default": serving_model.my_serve})

This produces:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-97-9e2616e04da9> in <module>()
      1 export_path = "model_preprocessing_func"
----> 2 tf.saved_model.save(serving_model, export_path, signatures={"serving_default": serving_model.my_serve})

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    949 
    950   _, exported_graph, object_saver, asset_info = _build_meta_graph(
--> 951       obj, export_dir, signatures, options, meta_graph_def)
    952   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
    953 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, export_dir, signatures, options, meta_graph_def)
   1009 
   1010   signatures, wrapped_functions = (
-> 1011       signature_serialization.canonicalize_signatures(signatures))
   1012   signature_serialization.validate_saveable_view(checkpoint_graph_view)
   1013   signature_map = signature_serialization.create_signature_map(signatures)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py in canonicalize_signatures(signatures)
    110           ("Expected a TensorFlow function to generate a signature for, but "
    111            "got {}. Only `tf.functions` with an input signature or "
--> 112            "concrete functions can be used as a signature.").format(function))
    113 
    114     wrapped_functions[original_function] = signature_function = (

ValueError: Expected a TensorFlow function to generate a signature for, but got <tensorflow.python.eager.def_function.Function object at 0x7fd5b646ea58>. Only `tf.functions` with an input signature or concrete functions can be used as a signature.
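The other remedy the error message names is passing a concrete function, which is a `tf.function` already traced for a specific input spec, obtained with `get_concrete_function`. The sketch below shows the mechanism on a hypothetical toy module (`Doubler`) rather than the image pipeline. For the question's model, the analogous call would presumably be `serving_model.my_serve.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))`. That still requires the preprocessing to be built from TensorFlow ops rather than cv2.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module standing in for ExportModel."""

    @tf.function  # no input_signature, like my_serve in the question
    def serve(self, x):
        return {"y": 2.0 * x}


module = Doubler()

# Trace once for a concrete input spec: a float32 vector of any length.
concrete = module.serve.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))

# A concrete function is accepted as a signature even without an input_signature.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir, signatures={"serving_default": concrete})

loaded = tf.saved_model.load(export_dir)
result = loaded.signatures["serving_default"](x=tf.constant([1.0, 2.0]))
print(result["y"].numpy())
```

Passing `module.serve` itself as the signature would hit the same `ValueError` as above, because it has neither an input signature nor a fixed concrete trace.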

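I can interpret the last part of the error, but I cannot figure out what steps I should take to fix it. The issue can be reproduced with this Colab Notebook. Any help is appreciated.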

0 answers:

No answers