Implementing a tf-slim-like interface in TensorFlow 2.x

Time: 2019-06-02 12:37:49

Tags: python tensorflow tf-slim

Please consider the following hypothetical but very practical situation:

I have to create a network which uses two pre-trained networks `A`
and `B`. I would like to concatenate the output of `Layer-L` of
`A` with the output of `Layer-M` of `B` and do further computations
on the result.

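In TensorFlow 1.x, I could use the tf-slim library (assuming, of course, that the pre-trained models and code for `A` and `B` are in tf-slim). As we know, tf-slim provides an `end_points` dictionary. When creating my own network, I can use `end_points` to make arbitrary connections.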

Moving to TensorFlow 2.x, there is no tf-slim. My question is whether the following is good implementation practice when porting tf-slim-specific code to TensorFlow 2.x.
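For comparison, here is a minimal sketch of how the concatenation scenario could be written with the Keras functional API in TensorFlow 2.x, where `get_layer(name).output` plays a role similar to an `end_points` lookup. The `make_toy_net` helper, the layer names, and the shapes are hypothetical stand-ins for the pre-trained networks `A` and `B`, and the sketch assumes both networks are functional (not subclassed) Keras models.

```python
import tensorflow as tf


def make_toy_net(prefix):
    """Hypothetical stand-in for a pre-trained network; names are made up."""
    inputs = tf.keras.Input(shape=(32, 32, 3), name=prefix + "_input")
    x = tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu",
                               name=prefix + "_conv1")(inputs)
    x = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu",
                               name=prefix + "_conv2")(x)
    x = tf.keras.layers.GlobalAveragePooling2D(name=prefix + "_pool")(x)
    return tf.keras.Model(inputs, x, name=prefix)


net_a = make_toy_net("a")
net_b = make_toy_net("b")

# Expose one intermediate activation of each network as a model output.
feat_a = tf.keras.Model(net_a.input, net_a.get_layer("a_conv1").output)
feat_b = tf.keras.Model(net_b.input, net_b.get_layer("b_conv2").output)

# Concatenate the two activations along the channel axis and keep computing.
image = tf.keras.Input(shape=(32, 32, 3), name="image")
merged = tf.keras.layers.Concatenate(axis=-1)([feat_a(image), feat_b(image)])
head = tf.keras.layers.Conv2D(4, 1, name="head")(merged)
combined = tf.keras.Model(image, head, name="combined")

out = combined(tf.random.uniform((2, 32, 32, 3)))
print(out.shape)  # (2, 32, 32, 4)
```

The concatenated tensor has 8 + 16 = 24 channels, and the 1x1 convolution `head` stands in for the "further computations" on the result.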

As an example, if I port the tf-slim code for the VGG16 network to TensorFlow 2.x, I implement it as a subclass of the `tensorflow.keras.Model` class. Below is an example:

import tensorflow as tf

# RepeatLayer is a basic implementation of the slim.repeat
# function from TensorFlow-Slim.

class RepeatLayer(tf.keras.layers.Layer):
    def __init__(self, layerobj, count, layernames, **kwargs):
        """
        Create `count` instances of `layerobj` to be applied in sequence.
        :param layerobj: A tf.keras.layers class (e.g. tf.keras.layers.Conv2D)
        :param count: (int) Number of times layerobj should be repeated.
        :param layernames: (List of length count) Name of each repetition of layerobj
        :param kwargs: layerobj specific named arguments. A non-list value is
            shared by all repetitions; a list supplies one value per repetition.
        """
        super(RepeatLayer, self).__init__()
        if not isinstance(count, int):
            raise TypeError('The argument "count" must be a positive integer.')

        if count <= 0:
            raise ValueError('The argument "count" is provided as {}. It must be a '
                             'positive integer.'.format(count))

        if not isinstance(layernames, list):
            raise TypeError('The argument "layernames" must be a list of strings.')

        if not len(layernames) == count:
            raise ValueError('The length of "layernames" must be the value of "count".')

        for name, value in kwargs.items():
            if not isinstance(value, list):
                value = [value] * count
                kwargs[name] = value

        self._layernames = layernames
        self._count = count
        self._end_points = dict()
        self._outputs = []
        self._kwargs = kwargs
        for layernum in range(self._count):
            # Pick the layernum-th value of every keyword argument.
            args = {name: values[layernum]
                    for name, values in self._kwargs.items()}
            layer = layerobj(**args)
            self._outputs.append(layer)

    def call(self, input_tensor):
        out = input_tensor
        for node, layername in zip(self._outputs, self._layernames):
            out = node(out)
            self._end_points[layername] = out
        return out

    @property
    def end_points(self):
        return self._end_points

I then tested the above code with the following snippet:

from utils import RepeatLayer
import tensorflow as tf

layer = RepeatLayer(tf.keras.layers.Conv2D, 3, ['c1', 'c2', 'c3'],
                    filters=64, kernel_size=7, strides=1, padding='same',
                    activation='relu', data_format='channels_last', use_bias=True)
input_tensor = tf.keras.backend.random_uniform(shape=(3, 512, 512, 3))
print(layer.end_points)  # should print an empty dictionary
output_tensor = layer(input_tensor)
print(layer.end_points)  # prints a dictionary in which values are output tensors

So the `end_points` dictionary can be used only after at least one `__call__` has been made on the model. I think this is the expected behavior, given the default eager execution.
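Because `end_points` is filled as a side effect of `__call__`, one alternative worth considering is to return the intermediate tensors from `call` instead of storing them on the layer. The sketch below uses a hypothetical `ConvStack` layer (a simplified variant of `RepeatLayer`, not the code above) and assumes that returning a `(tensor, dict)` pair is acceptable to the caller. Returned values are ordinary outputs of a `tf.function`, so nothing has to be read back from layer state after tracing.

```python
import tensorflow as tf


class ConvStack(tf.keras.layers.Layer):
    """Hypothetical variant of RepeatLayer that returns its end points."""

    def __init__(self, names, filters, **kwargs):
        super().__init__(**kwargs)
        self._names = list(names)
        self._convs = [
            tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")
            for _ in self._names
        ]

    def call(self, inputs):
        end_points = {}
        out = inputs
        for name, conv in zip(self._names, self._convs):
            out = conv(out)
            end_points[name] = out
        # Return the intermediates instead of storing them on self.
        return out, end_points


stack = ConvStack(["c1", "c2", "c3"], filters=4)


@tf.function
def forward(x):
    # The dictionary is part of the function's return value.
    return stack(x)


out, end_points = forward(tf.random.uniform((1, 16, 16, 3)))
print(sorted(end_points))
print(end_points["c2"].shape)
```

With this design the caller receives concrete tensors for every named end point on each call, in both eager and `tf.function` execution.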

The questions I am now facing are the following:

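  1. Is this the right approach from an implementation point of view? That is, should I try to port the model structure code from tf-slim to TensorFlow 2.x?
  2. If I use this approach and want to wrap everything in `@tf.function` for static execution and efficiency, will it work? I have not tried it, but without a `__call__`, wouldn't TensorFlow find the `end_points` dictionary empty and raise an error?
  3. How do I restore the pre-trained models from tf-slim into these new code structures?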
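On the third question, one possible route is to read the stored arrays out of the checkpoint by name and copy them into the new Keras layers. The sketch below is only an illustration: it writes a tiny stand-in checkpoint with `tf.train.Checkpoint` so that it is self-contained, whereas a real tf-slim checkpoint is assumed to use its own variable names (something like `vgg_16/conv1/conv1_1/weights`), so the key lookup would have to be adapted to the names actually listed by `get_variable_to_shape_map()`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Write a tiny stand-in checkpoint (hypothetical; a real tf-slim checkpoint
# would already exist on disk and use tf-slim variable names).
kernel = tf.Variable(np.full((3, 3, 3, 4), 0.5, dtype=np.float32))
bias = tf.Variable(np.arange(4, dtype=np.float32))
prefix = tf.train.Checkpoint(kernel=kernel, bias=bias).write(
    os.path.join(tempfile.mkdtemp(), "toy"))

# Inspect the checkpoint: it maps stored variable names to shapes.
reader = tf.train.load_checkpoint(prefix)
shape_map = reader.get_variable_to_shape_map()
kernel_key = next(k for k in shape_map if k.startswith("kernel"))
bias_key = next(k for k in shape_map if k.startswith("bias"))

# Build the new Keras layer so its variables exist, then copy the arrays in.
conv = tf.keras.layers.Conv2D(4, 3, padding="same")
conv.build((None, 8, 8, 3))
conv.set_weights([reader.get_tensor(kernel_key), reader.get_tensor(bias_key)])

print(conv.get_weights()[1])
```

`set_weights` expects the arrays in the layer's own variable order (kernel first, then bias for `Conv2D`), so the mapping from checkpoint names to layers has to be written out explicitly for each layer.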

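I am sure they must have had a good reason to abandon tf-slim altogether rather than port it. Hence I am confused about whether the above approach is right.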

0 answers:

No answers