无法使用tensorflow 2.0保存模型

时间:2019-12-11 11:54:44

标签: tensorflow2.0

在保存模型(例如我的HRNet)时,遇到以下错误: ValueError:无法保存对象ListWrapper([ListWrapper([None])])

我的某些图层看起来像(完整版本:https://github.com/zheLim/auto-face-parsing/blob/master/lib/model/hrnet_blocks.py):

class MultiResolutionLayer(layers.Layer):
def __init__(self, n_channels_list, bn_momentum=0.01, activation='relu'):
    """
    fuse feature from different branch with adding
    :param n_branches:
    :param n_channels:
    :param multi_scale_output:
    """
    super(MultiResolutionLayer, self).__init__()
    self.n_branches = len(n_channels_list)
    self.fuse_layers = [[] for branch_i in range(self.n_branches)]
    for branch_i in range(self.n_branches):
        layer = []
        for branch_j in range(self.n_branches):
            if branch_i < branch_j:
                # resolution of branch i is greater than branch_j
                # branch_j will be upsample with nearest resize
                layer.append(keras.Sequential(
                    [layers.Conv2D(filters=n_channels_list[branch_i], kernel_size=1, strides=1, padding='same',
                                   use_bias=False, activation=activation),
                     layers.BatchNormalization(momentum=bn_momentum)]))
            elif branch_i == branch_j:
                # branch i is branch_j
                layer.append(None)

错误消息建议:“如果不需要此列表经过检查的点,请将其包装在tf.contrib.checkpoint.NoDependency对象中;它将自动展开并随后被忽略。”但是,tf 2.0没有contri模块。 有人知道如何解决这个问题吗?

1 个答案:

答案 0 :(得分:3)

我通过切换到PyTorch解决了这个问题