我在pytorch中有这个模型,它基本上是一个多尺度CNN模型,它返回最后为该图像连接的特征图。我不确定如何使用keras在tensorflow 2中编写相同的内容。任何指针表示赞赏。
class VGG16ScaledFeatures(object):
def __init__(self, last_layer=22):
self.vgg16_features = torch.nn.ModuleList(
list(models.vgg16(pretrained=True).features)[:last_layer]
).eval()
def __call__(self, org):
x_ = torch.tensor([])
with torch.no_grad():
for s in range(3):
x = F.max_pool2d(org, (2 ** s, 2 ** s))
for i, f in enumerate(self.vgg16_features):
x = f(x)
if (
(s == 0 and i == 21)
or (s == 1 and i == 14)
or (s == 2 and i == 7)
):
x_ = torch.cat([x_, x], dim=1)
break
x_ = (x_ - x_.mean(dim=(2, 3), keepdim=True)) / x_.std(dim=(2, 3), keepdim=True)
return x_