如何将此代码从tensorflow 1升级到tensorflow 2

时间:2020-08-26 03:33:59

标签: python tensorflow

我有以下代码:


    def _separable_conv(features, depth, kernel_size, depth_multiplier,
                        regularize_depthwise, rate, stride, scope):
      if activation_fn_in_separable_conv:
        activation_fn = tf.nn.relu
      else:
        activation_fn = None
        features = tf.nn.relu(features)
      return separable_conv2d_same(features,
                                   depth,
                                   kernel_size,
                                   depth_multiplier=depth_multiplier,
                                   stride=stride,
                                   rate=rate,
                                   activation_fn=activation_fn,
                                   regularize_depthwise=regularize_depthwise,
                                   scope=scope)
    for i in range(3):
      residual = _separable_conv(residual,
                                 depth_list[i],
                                 kernel_size=3,
                                 depth_multiplier=1,
                                 regularize_depthwise=regularize_depthwise,
                                 rate=rate*unit_rate_list[i],
                                 stride=stride if i == 2 else 1,
                                 scope='separable_conv' + str(i+1))
    if skip_connection_type == 'conv':
      shortcut = tf.Conv2D(inputs,
                             depth_list[-1],
                             [1, 1],
                             stride=stride,
                             activation_fn=None,
                             scope='shortcut')
      outputs = residual + shortcut
    elif skip_connection_type == 'sum':
      outputs = residual + inputs
    elif skip_connection_type == 'none':
      outputs = residual
    else:
      raise ValueError('Unsupported skip connection type.')

    return slim.utils.collect_named_outputs(outputs_collections,
                                            sc.name,
                                            outputs)

在最后一行中,我们使用了tf.contrib中的slim模块,该模块在tensorflow 2中已弃用。在tensorflow 2中存在哪些功能,或者其他功能与sli​​m.utils.collect_named_outputs行相同? / p>

1 个答案:

答案 0 :(得分:0)

TF-slim现在可作为Github上的外部程序包使用,并且支持Tensorflow2。该库具有相同的确切功能(包括此方法!),它只有一个新的目录和一个不同的安装方式。

核心库中没有Tensorflow 2代码可以直接替换您的代码。