Consider the following hypothetical but very practical situation:
I have to create a network which uses two pre-trained networks `A`
and `B`. I would like to concatenate the output of `Layer-L` of
`A` with the output of `Layer-M` of `B` and do further computations
on the result.
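For concreteness, that wiring can be sketched in TensorFlow 2.x with the Keras functional API. The two `make_net` models below are tiny hypothetical stand-ins for the pre-trained `A` and `B`, and the layer names are made up:

```python
import tensorflow as tf

def make_net(prefix):
    # Tiny stand-in for a pre-trained network (A or B).
    inp = tf.keras.Input(shape=(8, 8, 3))
    x = tf.keras.layers.Conv2D(4, 3, padding='same', name=prefix + '_l1')(inp)
    x = tf.keras.layers.Conv2D(4, 3, padding='same', name=prefix + '_l2')(x)
    return tf.keras.Model(inp, x)

A = make_net('a')  # plays the role of pre-trained network A
B = make_net('b')  # plays the role of pre-trained network B

# Tap "Layer-L" of A and "Layer-M" of B by layer name and concatenate them.
feat_a = A.get_layer('a_l1').output
feat_b = B.get_layer('b_l2').output
merged = tf.keras.layers.Concatenate()([feat_a, feat_b])

# Further computations can now be stacked on top of `merged`.
combined = tf.keras.Model(inputs=[A.input, B.input], outputs=merged)

out = combined([tf.random.uniform((2, 8, 8, 3)),
                tf.random.uniform((2, 8, 8, 3))])
print(out.shape)  # 4 + 4 channels -> (2, 8, 8, 8)
```

This works when the sub-networks are functional Keras models, because every intermediate layer output is reachable via `get_layer`; the question below is about the subclassing style, where no such handle exists.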
Now, in TensorFlow 1.x, I could use the tf-slim library (I am of course assuming that the pre-trained models and code for `A` and `B` are available in tf-slim). We know that tf-slim provides an `end_points` dictionary. When building my own network, I can use `end_points` to establish arbitrary connections.
Moving to TensorFlow 2.x, there are no `end_points`. My question is: would the following be good implementation practice when porting tf-slim-specific code to TensorFlow 2.x? As an example, to port the tf-slim code of the `VGG16` network to TensorFlow 2.x, I implement it as a subclass of `tensorflow.keras.Model`. Below is an example:
import tensorflow as tf


class RepeatLayer(tf.keras.layers.Layer):
    """
    A basic re-implementation of slim.repeat from TensorFlow-slim.
    """

    def __init__(self, layerobj, count, layernames, **kwargs):
        """
        Class instantiator.

        :param layerobj: a tf.keras.layers class (e.g. tf.keras.layers.Conv2D)
        :param count: (int) number of times layerobj should be repeated
        :param layernames: (list of length count) name of each repetition of layerobj
        :param kwargs: layerobj-specific named arguments; a scalar is broadcast
            to all repetitions, a list must have one entry per repetition
        """
        super(RepeatLayer, self).__init__()
        if not isinstance(count, int):
            raise TypeError('The argument "count" must be a positive integer.')
        if count <= 0:
            raise ValueError('The argument "count" is provided as {}. It must '
                             'be a positive integer.'.format(count))
        if not isinstance(layernames, list):
            raise TypeError('The argument "layernames" must be a list of strings.')
        if len(layernames) != count:
            raise ValueError('The length of "layernames" must equal the value '
                             'of "count".')
        # Broadcast scalar arguments so that every repetition gets a value.
        for name, value in kwargs.items():
            if not isinstance(value, list):
                kwargs[name] = [value] * count
        self._layernames = layernames
        self._count = count
        self._end_points = dict()
        self._layers = []
        self._kwargs = kwargs
        for layernum in range(self._count):
            # Select the layernum-th value of every argument and instantiate
            # one copy of the repeated layer.
            args = {name: values[layernum]
                    for name, values in self._kwargs.items()}
            self._layers.append(layerobj(**args))

    def call(self, input_tensor):
        out = input_tensor
        for layer, layername in zip(self._layers, self._layernames):
            out = layer(out)
            self._end_points[layername] = out
        return out

    @property
    def end_points(self):
        return self._end_points
I then try to test the above code with the following snippet:

from utils import RepeatLayer  # the class above, saved in utils.py
import tensorflow as tf

layer = RepeatLayer(tf.keras.layers.Conv2D, 3, ['c1', 'c2', 'c3'],
                    filters=64, kernel_size=7, strides=1, padding='same',
                    activation='relu', data_format='channels_last',
                    use_bias=True)
input_tensor = tf.random.uniform(shape=(3, 512, 512, 3))
print(layer.end_points)  # should print an empty dictionary
output_tensor = layer(input_tensor)
print(layer.end_points)  # prints a dictionary whose values are the output tensors
Thus, the `end_points` dictionary only becomes available after at least one `__call__` on the model. Given the default eager execution, I believe this is the expected behavior.
The problems I am facing now are as follows:

1. Is this a good practice for porting tf-slim code to TensorFlow 2.x?
2. If I decorate `__call__` with `@tf.function` to get static execution for efficiency, will this work? I have not tried it, but without an actual call, would TensorFlow not find the `end_points` dictionary empty and raise an error?
3. How do I restore the pre-trained models from tf-slim into these new code structures?

I am sure they must have had good reasons for not porting tf-slim and instead abandoning it entirely, so I am confused about whether the above approach is correct.
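Regarding the second question, one cheap way to probe the `@tf.function` behavior before committing to it might be a cut-down experiment like the one below. `TinyRepeat` is a hypothetical two-convolution stand-in for `RepeatLayer`, not part of the original code:

```python
import tensorflow as tf

class TinyRepeat(tf.keras.layers.Layer):
    # Hypothetical cut-down stand-in for RepeatLayer: two stacked
    # convolutions that record their outputs in an end_points dictionary.
    def __init__(self):
        super(TinyRepeat, self).__init__()
        self.end_points = {}
        self._convs = [tf.keras.layers.Conv2D(4, 3, padding='same')
                       for _ in range(2)]

    def call(self, x):
        for i, conv in enumerate(self._convs):
            x = conv(x)
            self.end_points['c%d' % (i + 1)] = x
        return x

layer = TinyRepeat()

@tf.function
def run(x):
    # Wrapping the call in tf.function traces it into a graph.
    return layer(x)

out = run(tf.random.uniform((1, 8, 8, 3)))
# end_points is filled in during tracing, so it is no longer empty here,
# but its values are the tensors captured by the trace rather than
# freshly computed eager values.
print(sorted(layer.end_points.keys()))
```

This only checks that the dictionary gets populated under tracing; whether relying on such Python side effects inside `tf.function` is sound design is exactly what the question is asking.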