从Keras转换为Pytorch-conv2d

时间:2020-10-21 12:42:45

标签: keras pytorch

我正在尝试将以下Keras代码转换为PyTorch。

    tf.keras.Sequential([
          Conv2D(128, 1, activation=tf.nn.relu),
          Conv2D(self.channel_n, 1, activation=None),
    ])

使用self.channels = 16创建模型摘要时,我得到以下摘要。

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (1, 3, 3, 128)            6272      
_________________________________________________________________
conv2d_1 (Conv2D)            (1, 3, 3, 16)             2064      
=================================================================
Total params: 8,336
Trainable params: 8,336
Non-trainable params: 0

一个人如何转换?

我曾经这样尝试过

import torch
from torch import nn

class CellCA(nn.Module):
    def __init__(self, channels, dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=channels,out_channels=dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim, out_channels=channels, kernel_size=1),
        )
    def forward(self, x):
        return self.net(x)

但是,我得到了4240个参数

1 个答案:

答案 0 :(得分:0)

如果正确配置了初始通道(在这种情况下为48个),则上述尝试是正确的。