在保存模型(例如我的HRNet)时,遇到以下错误: ValueError:无法保存对象ListWrapper([ListWrapper([None])])
我的某些图层看起来像(完整版本:https://github.com/zheLim/auto-face-parsing/blob/master/lib/model/hrnet_blocks.py):
class MultiResolutionLayer(layers.Layer):
def __init__(self, n_channels_list, bn_momentum=0.01, activation='relu'):
"""
fuse feature from different branch with adding
:param n_branches:
:param n_channels:
:param multi_scale_output:
"""
super(MultiResolutionLayer, self).__init__()
self.n_branches = len(n_channels_list)
self.fuse_layers = [[] for branch_i in range(self.n_branches)]
for branch_i in range(self.n_branches):
layer = []
for branch_j in range(self.n_branches):
if branch_i < branch_j:
# resolution of branch i is greater than branch_j
# branch_j will be upsample with nearest resize
layer.append(keras.Sequential(
[layers.Conv2D(filters=n_channels_list[branch_i], kernel_size=1, strides=1, padding='same',
use_bias=False, activation=activation),
layers.BatchNormalization(momentum=bn_momentum)]))
elif branch_i == branch_j:
# branch i is branch_j
layer.append(None)
错误消息建议:“如果不需要此列表经过检查的点,请将其包装在tf.contrib.checkpoint.NoDependency对象中;它将自动展开并随后被忽略。”但是,tf 2.0没有contri模块。 有人知道如何解决这个问题吗?
答案 0 :(得分:3)
我通过切换到PyTorch解决了这个问题