保存和读取自定义Tensorflow模型

时间:2020-09-18 21:46:35

标签: tensorflow keras model save layer

我正在使用tensorflow 2.3生成自定义的编码器/解码器网络。我可以生成,训练和评估网络;但是,当我尝试将其保存到文件中并随后进行加载时,根据是否以急切的执行模式运行,会出现不同的错误:

  1. 在不使用急切模式的情况下,代码可以正常运行并显示结果。但是,显示警告:

tensorflow / core / common_runtime / graph_constructor.cc:808]节点“ max_pool_layer_with_indices / PartitionedCall”具有2个输出,但 _output_shapes属性为4个输出指定形状。输出形状可能不正确。

我不明白这4个输出来自何处。我的图层定义有问题吗?

  1. 在热切的执行模式下,当我尝试使用训练好的模型进行预测时,tensorflow返回一个错误:

找不到匹配的函数来调用从SavedModel加载的函数。获得:位置参数(共1个): tf.Tensor(...,shape =(1,256,256,1),dtype = float32)关键字 参数:{'training':False}期望这些参数匹配一个 以下2个选项中的一个:选项1:位置自变量(2 总): * TensorSpec(形状=(无,256,256,1),dtype = tf.float32,name ='input_1') *错误的关键字参数:{}选项2:位置参数(共2个): * TensorSpec(形状=(无,256,256,1),dtype = tf.float32,name ='input_1') *真实的关键字参数:{}

我尝试生成一个最小的示例来重现该问题:

import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow import keras
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
import matplotlib.pyplot as plt

class MaxPoolLayerWithIndices(Layer):
    
  def __init__(self, in_shape, pool_size=(2, 2), strides=None, padding='valid', 
                 data_format=None, name=None, **kwargs):
        super(MaxPoolLayerWithIndices, self).__init__(name=name, trainable=False, **kwargs)
        if data_format is None:
            data_format = backend.image_data_format()
        if strides is None:
            strides = pool_size
        self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)
        #self.in_shape = in_shape
  
  @tf.function
  def call(self, inputs):
        if self.data_format == 'channels_last':
            pool_shape = (1,) + self.pool_size + (1,)
            strides = (1,) + self.strides + (1,)
        else:
            pool_shape = (1, 1) + self.pool_size
            strides = (1, 1) + self.strides
           
        res = tf.nn.max_pool_with_argmax(
            inputs, ksize=pool_shape, strides=strides, 
            padding=self.padding.upper(), output_dtype=tf.int64,
            data_format=conv_utils.convert_data_format(self.data_format, 4))

        return res

  def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == 'channels_first':
            rows = input_shape[2]
            cols = input_shape[3]
        else:
            rows = input_shape[1]
            cols = input_shape[2]
        rows = conv_utils.conv_output_length(rows, self.pool_size[0], 
                                             self.padding, self.strides[0])
        cols = conv_utils.conv_output_length(cols, self.pool_size[1], 
                                             self.padding, self.strides[1])
        if self.data_format == 'channels_first':
            shape = tensor_shape.TensorShape(
                [input_shape[0], input_shape[1], rows, cols])
        else:
            shape = tensor_shape.TensorShape(
                [input_shape[0], rows, cols, input_shape[3]])

        return [shape, shape]

  def get_config(self):
        config = {
            'pool_size': self.pool_size,
            'padding': self.padding,
            'strides': self.strides,
            'data_format': self.data_format
            }
        base_config = super(MaxPoolLayerWithIndices, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SimpleModel(tf.keras.Model):

    def __init__(self, in_shape, filter_size=3):
        super(SimpleModel, self).__init__()

        # Set Parameters
        self.filter_size = filter_size

        # Create Layers       
        self.input_layer = tf.keras.layers.InputLayer(input_shape=in_shape, batch_size=1)
        self.conv_layer = tf.keras.layers.Conv2D(1, [filter_size, filter_size], strides=(1,1), padding='same', use_bias=False)
        self.pooling_layer = MaxPoolLayerWithIndices(in_shape, pool_size=(2,2), strides=(2,2), padding = 'valid')


    @tf.function
    def call(self, inputs):
        x = self.input_layer(inputs)
        x, ind = self.pooling_layer(self.conv_layer(x))

        return x


tf.config.run_functions_eagerly(True)

# Set Input Path
model_filename = 'E:\\simple_test_model'

# Initialize GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)

# Create Random Dataset
ds = tf.data.Dataset.from_tensor_slices( (tf.random.uniform([1,256,256,1]), tf.random.uniform([1,128,128,1])) ).batch(1).repeat()
img, lbl = next(iter(ds))    

# Generate Model
model = SimpleModel([256,256])

# Compile Model
model.compile(optimizer='adam',
                loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),  #SparseCategoricalCrossentropy
                metrics=['accuracy'])

# Train Model
model_history = model.fit(ds, epochs=1, steps_per_epoch=10)

# Save Model
model.save(model_filename)

# Load Model
rec_model = tf.keras.models.load_model(model_filename)
 
# Run Loaded Model
pred_org = model(img)
pred_rec = rec_model.predict(img)

# Display Results
plt.imshow(tf.concat([pred_org[0,:,:,0], pred_rec[0,:,:,0]], axis=1), cmap=plt.get_cmap('gray'), aspect=1.0)
plt.show()

0 个答案:

没有答案