如何知道火炬模型的输入形状?

时间:2019-10-10 12:11:30

标签: machine-learning deep-learning conv-neural-network pytorch

我有一个列表,列表中包含100个形状为(20,48)的矩阵,我想在pytorch中传递该矩阵。

这是示例代码

import torch.nn.functional as F
import torch.nn as nn
import torch

sample = torch.randn(100,20,48)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(20, 48, kernel_size=2)
    def forward(self, x):
        return self.conv(x)

net = Net()

for i in net.state_dict().keys():
    print(i)

for i in list(net.parameters()):
    print(i.shape)

#output

conv.weight
conv.bias

torch.Size([48, 20, 2])
torch.Size([48])

如何检查我的模型是否采用了特定形状的输入?就我而言,如何确定我的模型输入转换层是否采用大小为(bs,20,48)的矩阵?

0 个答案:

没有答案