我也有一个用TF 1.x代码、通过TF-Slim API编写的模型。能否在TF 2.0中将其原样转换为tf.keras模型?例如,参数数量和训练行为是否会完全相同?
就我而言,我已经尝试过这样做,但是我的tf.keras模型实际上比TF 1.x模型少了约5%的参数。我还注意到tf.keras模型的训练过程也不太顺利。有什么想法吗?谢谢
也许是我设置的某些参数导致层的初始化方式不同?任何其他建议都将不胜感激
这不是我的完整模型,但是我使用了以下许多组件:
原始TF 1.x模型:
import tensorflow as tf
from tensorflow.contrib import slim
def batch_norm_relu(inputs, is_training):
    """Apply slim batch normalization followed by a ReLU activation.

    Args:
        inputs: Tensor to normalize.
        is_training: Forwarded to ``slim.batch_norm`` as ``is_training``.

    Returns:
        The result of ``tf.nn.relu`` applied to the normalized tensor.
    """
    # NOTE(review): no other arguments are passed, so slim.batch_norm runs
    # with its library defaults -- check those when porting to another API.
    normalized = slim.batch_norm(inputs, is_training=is_training)
    return tf.nn.relu(normalized)
def conv2d_transpose(inputs, output_channels, kernel_size):
    """Run a stride-2 slim transposed convolution over ``inputs``.

    Args:
        inputs: Tensor to upsample.
        output_channels: Passed to slim as ``num_outputs``.
        kernel_size: Kernel size of the transposed convolution.

    Returns:
        The tensor returned by ``slim.conv2d_transpose``.
    """
    # NOTE(review): padding and activation_fn are not specified here, so
    # slim's defaults apply (presumably 'SAME' and ReLU -- confirm in the
    # TF-Slim docs before mirroring this layer elsewhere).
    return slim.conv2d_transpose(
        inputs,
        num_outputs=output_channels,
        kernel_size=kernel_size,
        stride=2,
    )
def conv2d_fixed_padding(inputs, filters, kernel_size, stride, rate):
    """Apply a linear (no activation) slim convolution.

    The padding mode depends on the stride: 'SAME' when ``stride == 1``,
    'VALID' otherwise.

    Args:
        inputs: Input tensor.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride; also selects the padding mode.
        rate: Dilation rate forwarded to ``slim.conv2d``.

    Returns:
        The output tensor of ``slim.conv2d``.
    """
    padding_mode = 'VALID' if stride != 1 else 'SAME'
    return slim.conv2d(
        inputs,
        filters,
        kernel_size,
        stride=stride,
        rate=rate,
        padding=padding_mode,
        activation_fn=None,
    )
def block(inputs, filters, is_training, projection_shortcut, stride):
    """Pre-activation residual block with parallel dilated 3x3 convolutions.

    Data flow:
      1. ``inputs`` goes through batch norm + ReLU.
      2. The shortcut is that activation, or ``projection_shortcut`` applied
         to it when a projection is given.
      3. Three 3x3 convolutions (dilation rates 1, 3 and 5) run in parallel
         on the shortcut and are summed together with the shortcut itself.
      4. The sum goes through batch norm + ReLU and a 1x1 convolution.
      5. The shortcut is added to that result.

    Args:
        inputs: Input tensor.
        filters: Number of output channels of every convolution. The
            element-wise sums require the shortcut to have this many
            channels as well.
        is_training: Forwarded to the batch-norm layers.
        projection_shortcut: Optional callable applied to the activated
            input to produce the shortcut; ``None`` keeps the identity.
        stride: Kept for interface compatibility. It is no longer read
            here (see the padding note below); presumably any striding is
            done inside ``projection_shortcut`` -- confirm with callers.

    Returns:
        The block output tensor (shortcut + residual branch).
    """
    inputs = batch_norm_relu(inputs, is_training)
    shortcut = inputs
    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    # The dilated convolutions always use stride 1 and their outputs are
    # added element-wise to `shortcut`, so they must preserve its spatial
    # size. The previous code switched to 'VALID' padding when stride != 1,
    # which shrinks each branch by a different amount (2, 6 and 10 pixels
    # for rates 1, 3 and 5) and makes the additions below fail.
    # Layers are created in the order r1, r3, r5 so slim's automatic
    # variable-scope numbering is unchanged.
    dilated = [
        slim.conv2d(
            shortcut,
            filters,
            kernel_size=3,
            stride=1,
            rate=rate,
            padding='SAME',
            activation_fn=None,
        )
        for rate in (1, 3, 5)
    ]
    conv_k3_s1_r1, conv_k3_s1_r3, conv_k3_s1_r5 = dilated

    # The shortcut itself acts as the identity branch.
    net = shortcut + conv_k3_s1_r1 + conv_k3_s1_r3 + conv_k3_s1_r5
    net = batch_norm_relu(net, is_training)
    net = conv2d_fixed_padding(
        inputs=net, filters=filters, kernel_size=1, stride=1, rate=1)
    return shortcut + net
我尝试用TF 2.x的tf.keras实现相同的组件:
import tensorflow as tf
class BatchNormRelu(tf.keras.layers.Layer):
    """Keras layer applying batch normalization and then ReLU."""

    def __init__(self, name=None):
        """Create the batch-norm and ReLU sub-layers.

        Args:
            name: Optional layer name forwarded to the Keras base class.
        """
        super(BatchNormRelu, self).__init__(name=name)
        # NOTE(review): momentum=0.999 / scale=False presumably mirror the
        # slim.batch_norm defaults of the TF 1.x model -- confirm.
        bn_options = dict(momentum=0.999, scale=False)
        self.bnorm = tf.keras.layers.BatchNormalization(**bn_options)
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, is_training):
        """Normalize ``inputs`` and apply ReLU.

        Args:
            inputs: Input tensor.
            is_training: Passed to batch normalization as ``training``.

        Returns:
            The activated, normalized tensor.
        """
        return self.relu(self.bnorm(inputs, training=is_training))
class Conv2DTranspose(tf.keras.layers.Layer):
    """Stride-2 transposed convolution with 'same' padding and ReLU."""

    def __init__(self, output_channels, kernel_size, name=None):
        """Create the wrapped ``tf.keras.layers.Conv2DTranspose``.

        Args:
            output_channels: Number of filters of the transposed convolution.
            kernel_size: Kernel size of the transposed convolution.
            name: Optional layer name forwarded to the Keras base class.
        """
        super(Conv2DTranspose, self).__init__(name=name)
        relu = tf.keras.activations.relu
        self.tconv1 = tf.keras.layers.Conv2DTranspose(
            filters=output_channels,
            kernel_size=kernel_size,
            strides=2,
            padding='same',
            activation=relu,
        )

    def call(self, inputs):
        """Return the transposed convolution of ``inputs``."""
        return self.tconv1(inputs)
class Conv2DFixedPadding(tf.keras.layers.Layer):
    """Linear Conv2D whose padding mode is chosen from the stride.

    Uses 'same' padding when ``stride == 1`` and 'valid' padding otherwise.
    """

    def __init__(self, filters, kernel_size, stride, rate, name=None):
        """Create the wrapped ``tf.keras.layers.Conv2D``.

        Args:
            filters: Number of output channels.
            kernel_size: Convolution kernel size.
            stride: Convolution stride; also selects the padding mode.
            rate: Dilation rate of the convolution.
            name: Optional layer name forwarded to the Keras base class.
        """
        super(Conv2DFixedPadding, self).__init__(name=name)
        padding_mode = 'valid' if stride != 1 else 'same'
        self.conv1 = tf.keras.layers.Conv2D(
            filters,
            kernel_size,
            strides=stride,
            dilation_rate=rate,
            padding=padding_mode,
            activation=None,
        )

    def call(self, inputs):
        """Return the convolution of ``inputs``."""
        return self.conv1(inputs)
class block(tf.keras.layers.Layer):
    """Pre-activation residual block with parallel dilated 3x3 convolutions.

    Data flow in ``call``:
      1. The input goes through batch norm + ReLU.
      2. The shortcut is that activation, or a 1x1 strided projection of it
         when ``projection_shortcut`` is truthy.
      3. Three 3x3 convolutions (dilation rates 1, 3 and 5) run in parallel
         on the shortcut and are summed together with the shortcut itself.
      4. The sum goes through batch norm + ReLU and a 1x1 convolution.
      5. The shortcut is added to that result.
    """

    def __init__(self,
                 filters,
                 stride,
                 projection_shortcut=True,
                 name=None):
        """Create the sub-layers of the block.

        Args:
            filters: Number of output channels of every convolution. The
                element-wise sums in ``call`` require the shortcut to have
                this many channels as well.
            stride: Stride of the 1x1 projection shortcut.
            projection_shortcut: Whether ``call`` projects the shortcut
                through ``conv_sc``.
            name: Optional layer name forwarded to the Keras base class.
        """
        super(block, self).__init__(name=name)
        self.projection_shortcut = projection_shortcut
        self.brelu1 = BatchNormRelu()
        self.brelu2 = BatchNormRelu()

        def dilated_conv(rate):
            # These convolutions always use stride 1 and their outputs are
            # added element-wise to the shortcut, so they must preserve its
            # spatial size. The previous code switched to 'valid' padding
            # when stride != 1, which shrinks each branch by a different
            # amount (2, 6 and 10 pixels for rates 1, 3 and 5) and makes
            # the additions in call() fail.
            return tf.keras.layers.Conv2D(
                filters,
                kernel_size=3,
                strides=1,
                dilation_rate=rate,
                padding='same',
                activation=None,
            )

        # Sub-layers are created in the same order as before so that
        # automatically generated layer names stay unchanged.
        self.conv1 = dilated_conv(1)
        self.conv2 = dilated_conv(3)
        self.conv3 = dilated_conv(5)
        self.conv4 = Conv2DFixedPadding(filters, 1, 1, 1)
        # Only used by call() when projection_shortcut is truthy.
        self.conv_sc = Conv2DFixedPadding(filters, 1, stride, 1)

    def call(self, inputs, is_training):
        """Run the block on ``inputs``.

        Args:
            inputs: Input tensor.
            is_training: Forwarded to both batch-norm sub-layers.

        Returns:
            The block output tensor (shortcut + residual branch).
        """
        x = self.brelu1(inputs, is_training)
        shortcut = x
        if self.projection_shortcut:
            shortcut = self.conv_sc(x)

        conv_k3_s1_r1 = self.conv1(shortcut)
        conv_k3_s1_r3 = self.conv2(shortcut)
        conv_k3_s1_r5 = self.conv3(shortcut)

        # The shortcut itself acts as the identity branch.
        x = shortcut + conv_k3_s1_r1 + conv_k3_s1_r3 + conv_k3_s1_r5
        x = self.brelu2(x, is_training)
        x = self.conv4(x)
        return shortcut + x