如何在Pytorch中可视化网络?

时间:2018-09-23 18:15:05

标签: python pytorch

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从pytorch模型中看到resnet。我该怎么做?我尝试使用torchviz,但出现错误:

'ResNet' object has no attribute 'grad_fn'

5 个答案:

答案 0 :(得分:11)

以下是使用不同工具的三种不同图形可视化效果。

为了生成可视化示例,我将使用一个简单的RNN来执行对online tutorial的情感分析:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

如果您print()模型,则为输出。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

下面是来自三种不同的可视化工具的结果。

对于所有这些对象,您都需要具有可通过模型的forward()方法传递的虚拟输入。获取此输入的一种简单方法是从您的Dataloader中检索批次,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torchviz

https://github.com/szagoruyko/pytorchviz

我相信此工具使用向后传递来生成其图形,因此所有盒子都使用PyTorch组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

此工具产生以下输出文件:

torchviz output

这是唯一清楚提及我的模型中的三层embeddingrnnfc的输出。操作员名称是从后向传递过来的,因此其中一些难以理解。

HiddenLayer

https://github.com/waleedka/hiddenlayer

我相信,该工具使用前向通过。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这是输出。我喜欢蓝色的阴影。

hiddenlayer output

我发现输出内容太多,使我的体系结构变得模糊。例如,为什么unsqueeze被提及很多次?

Netron

https://github.com/lutzroeder/netron

此工具是Mac,Windows和Linux的桌面应用程序。它依赖于首先导出到ONNX format中的模型。然后,应用程序读取ONNX文件并进行渲染。然后可以选择将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

这是应用程序中模型的外观。我认为该工具非常漂亮:您可以缩放和平移,还可以深入研究图层和运算符。我发现的唯一缺点是它只能进行垂直布局。

Netron screenshot

答案 1 :(得分:7)

如果要保存图像,请按照以下步骤用torchviz进行操作:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

您获得的图像的屏幕截图:

enter image description here

来源:http://www.bnikolic.co.uk/blog/pytorch-detach.html

答案 2 :(得分:1)

您可以使用TensorBoard进行可视化。 现在,PyTorch 1.2.0版完全支持TensorBoard。 更多信息: https://pytorch.org/docs/stable/tensorboard.html

答案 3 :(得分:0)

make_dot需要一个变量(即带有grad_fn的张量),而不是模型本身。
尝试:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

答案 4 :(得分:0)

您可以查看PyTorchViz(https://github.com/szagoruyko/pytorchviz),“一个用于创建PyTorch执行图和轨迹可视化的小程序包。”

Example PyTorchViz visualization