我正在尝试将以下Keras代码转换为PyTorch。
tf.keras.Sequential([
Conv2D(128, 1, activation=tf.nn.relu),
Conv2D(self.channel_n, 1, activation=None),
])
使用self.channels = 16创建模型摘要时,我得到以下摘要。
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (1, 3, 3, 128) 6272
_________________________________________________________________
conv2d_1 (Conv2D) (1, 3, 3, 16) 2064
=================================================================
Total params: 8,336
Trainable params: 8,336
Non-trainable params: 0
一个人如何转换?
我曾经这样尝试过
import torch
from torch import nn
class CellCA(nn.Module):
def __init__(self, channels, dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=channels,out_channels=dim, kernel_size=1),
nn.ReLU(),
nn.Conv2d(in_channels=dim, out_channels=channels, kernel_size=1),
)
def forward(self, x):
return self.net(x)
但是,我得到了4240个参数
答案 0 :(得分:0)
如果正确配置了初始通道(在这种情况下为48个),则上述尝试是正确的。