从TF 1.x转换为TF 2.0 keras

时间:2020-04-30 17:22:19

标签: python tensorflow tf.keras

我有一个使用TF-Slim API以TF 1.x代码编写的模型。是否可以按原样将其转换为TF 2.0的tf.keras?例如,转换后模型的参数数量和训练行为是否能与原模型完全一致?

就我而言,我已经尝试过这样做,但是我在tf.keras中的模型实际上比TF 1.x中的模型少了约5%的参数。我还注意到,tf.keras模型的训练过程也不如TF 1.x版本平滑。有什么想法吗?谢谢

也许我正在设置一些参数以不同方式初始化图层?任何其他建议将不胜感激

这不是我的完整模型,但是我使用了以下许多组件:

原始TF.1x模型:

import tensorflow as tf
from tensorflow.contrib import slim

def batch_norm_relu(inputs, is_training):
    """Apply slim batch normalization, then a ReLU activation."""
    normalized = slim.batch_norm(inputs, is_training=is_training)
    return tf.nn.relu(normalized)

def conv2d_transpose(inputs, output_channels, kernel_size):
    """Upsample `inputs` by 2x with a slim transposed convolution."""
    return tf.contrib.slim.conv2d_transpose(
        inputs,
        num_outputs=output_channels,
        kernel_size=kernel_size,
        stride=2,
    )

def conv2d_fixed_padding(inputs, filters, kernel_size, stride, rate):
    """Linear (no activation) slim conv2d; SAME padding only for stride 1."""
    padding = 'SAME' if stride == 1 else 'VALID'
    return slim.conv2d(
        inputs,
        filters,
        kernel_size,
        stride=stride,
        rate=rate,
        padding=padding,
        activation_fn=None,
    )

def block(inputs, filters, is_training, projection_shortcut, stride):
    """Pre-activation residual block with three parallel dilated 3x3 convs.

    The input is passed through BN+ReLU first; the (optionally projected)
    shortcut then feeds an identity path plus three stride-1 3x3 convs at
    dilation rates 1, 3 and 5. Their sum goes through BN+ReLU and a 1x1
    conv, and is finally added back onto the shortcut.
    """
    inputs = batch_norm_relu(inputs, is_training)
    shortcut = inputs

    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    pad = 'SAME' if stride == 1 else 'VALID'

    def dilated_branch(rate):
        # 3x3, stride-1 conv at the given dilation rate, no activation.
        return slim.conv2d(shortcut,
                           filters,
                           kernel_size=3,
                           stride=1,
                           rate=rate,
                           padding=pad,
                           activation_fn=None)

    # Identity path plus the rate-1, rate-3 and rate-5 branches
    # (created in this order, matching the original variable order).
    branch_sum = (shortcut
                  + dilated_branch(1)
                  + dilated_branch(3)
                  + dilated_branch(5))

    net = batch_norm_relu(branch_sum, is_training)
    net = conv2d_fixed_padding(inputs=net, filters=filters,
                               kernel_size=1, stride=1, rate=1)
    return shortcut + net

尝试使用TF 2.x.keras模型的相同组件:

import tensorflow as tf

class BatchNormRelu(tf.keras.layers.Layer):
    """Batch normalization followed by a ReLU activation."""

    def __init__(self, name=None):
        super(BatchNormRelu, self).__init__(name=name)
        # NOTE(review): momentum=0.999, scale=False looks intended to
        # mirror the slim.batch_norm configuration — confirm against the
        # TF 1.x model before trusting parameter-count comparisons.
        self.bnorm = tf.keras.layers.BatchNormalization(momentum=0.999,
                                                        scale=False)
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, is_training):
        return self.relu(self.bnorm(inputs, training=is_training))

class Conv2DTranspose(tf.keras.layers.Layer):
    """2x-upsampling transposed convolution with a ReLU activation."""

    def __init__(self, output_channels, kernel_size, name=None):
        super(Conv2DTranspose, self).__init__(name=name)
        self.tconv1 = tf.keras.layers.Conv2DTranspose(
            filters=output_channels,
            kernel_size=kernel_size,
            strides=2,
            padding='same',
            activation=tf.keras.activations.relu,
        )

    def call(self, inputs):
        return self.tconv1(inputs)

class Conv2DFixedPadding(tf.keras.layers.Layer):
    """Linear (no activation) Conv2D; 'same' padding only when stride is 1."""

    def __init__(self, filters, kernel_size, stride, rate, name=None):
        super(Conv2DFixedPadding, self).__init__(name=name)
        self.conv1 = tf.keras.layers.Conv2D(
            filters,
            kernel_size,
            strides=stride,
            dilation_rate=rate,
            padding='same' if stride == 1 else 'valid',
            activation=None,
        )

    def call(self, inputs):
        return self.conv1(inputs)

class block(tf.keras.layers.Layer):
    """Keras port of the slim residual block.

    An identity path plus three stride-1 3x3 convs at dilation rates
    1, 3 and 5 are summed over the (optionally projected) shortcut,
    passed through BN+ReLU and a 1x1 conv, then added back onto the
    shortcut.
    """

    def __init__(self,
                 filters,
                 stride,
                 projection_shortcut=True,
                 name=None):
        super(block, self).__init__(name=name)
        self.projection_shortcut = projection_shortcut
        self.brelu1 = BatchNormRelu()
        self.brelu2 = BatchNormRelu()

        pad = 'same' if stride == 1 else 'valid'

        def branch(rate):
            # 3x3 stride-1 conv at the given dilation rate, no activation.
            return tf.keras.layers.Conv2D(filters,
                                          kernel_size=3,
                                          strides=1,
                                          dilation_rate=rate,
                                          padding=pad,
                                          activation=None)

        # Sub-layers created in the same order as the original so the
        # weight/variable ordering is unchanged.
        self.conv1 = branch(1)
        self.conv2 = branch(3)
        self.conv3 = branch(5)
        self.conv4 = Conv2DFixedPadding(filters, 1, 1, 1)
        # NOTE(review): conv_sc is constructed even when
        # projection_shortcut is False, so its weights exist either way —
        # confirm this is intended when comparing parameter counts.
        self.conv_sc = Conv2DFixedPadding(filters, 1, stride, 1)

    def call(self, inputs, is_training):
        preact = self.brelu1(inputs, is_training)
        shortcut = self.conv_sc(preact) if self.projection_shortcut else preact

        summed = (shortcut
                  + self.conv1(shortcut)
                  + self.conv2(shortcut)
                  + self.conv3(shortcut))

        out = self.brelu2(summed, is_training)
        out = self.conv4(out)
        return shortcut + out

0 个答案:

没有答案