跳转至

示例:入门介绍

本教程用于介绍 torchview 的 API、图中节点的含义,以及几个最常用的参数组合。

0. 安装

你需要 Graphviz(用于渲染最终图),以及 torchview:

pip install graphviz
pip install torchview

如果系统里还没有 dot 命令,请按你的操作系统安装 Graphviz(例如 Windows 可用 choco install graphviz)。

1. 第一个例子:MLP

下面是一个简单的多层感知机(MLP):

import torch
from torch import nn
from torchview import draw_graph


class MLP(nn.Module):
    def __init__(self, inplace: bool = True) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(inplace),
            nn.Linear(128, 128),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

draw_graph 生成可视化图:

model_graph = draw_graph(
    MLP(),
    input_size=(2, 128),
    graph_name="MLP",
    hide_inner_tensors=False,
    hide_module_functions=False,
)
model_graph.visual_graph

这里我们把 hide_inner_tensors=Falsehide_module_functions=False 打开,是为了让图展示更多细节(更适合学习)。

2. 图里有哪些节点?

torchview 的图里主要有三类节点:

  • Tensor Node:表示张量(输入/中间/输出)
  • Function Node:表示算子/函数调用(例如 torch.relutorch.add
  • Module Node:表示模块调用(例如 nn.Linearnn.Sequential

节点标签通常包含:节点名、层级深度、输入/输出 shape 等信息。

3. Rolling:折叠递归模块

当模型中存在“重复调用同一个模块对象”(例如 RNN 的 cell 反复复用)时,图可能会非常长。

这时可以使用 roll=True 将递归结构折叠

import torch
from torch import nn


class SimpleRNN(nn.Module):
    """一个用于演示 rolling 的简化 RNN(使用 LSTMCell 反复迭代)"""

    def __init__(self, inplace: bool = True) -> None:
        super().__init__()
        self.hid_dim = 2
        self.input_dim = 3
        self.max_length = 4
        self.lstm = nn.LSTMCell(self.input_dim, self.hid_dim)
        self.activation = nn.LeakyReLU(inplace=inplace)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        b_size = token_embedding.size()[0]
        hx = torch.randn(b_size, self.hid_dim, device=token_embedding.device)
        cx = torch.randn(b_size, self.hid_dim, device=token_embedding.device)

        for _ in range(self.max_length):
            hx, cx = self.lstm(token_embedding, (hx, cx))
            hx = self.activation(hx)

        return hx

model_graph = draw_graph(
    SimpleRNN(),
    input_size=(2, 3),
    graph_name="RecursiveNet",
    roll=True,
)
model_graph.visual_graph

折叠后,图中边旁边的数字表示该边在 forward 中被“重复使用”的次数;若次数为 1 则通常不会显示。

4. Resize:缩放输出图

如果渲染出来的图过大,可以缩放:

model_graph.resize_graph(scale=0.5)
model_graph.visual_graph

5. 一个常见的小坑(VSCode 渲染裁切)

在 VSCode 的 Notebook 渲染里,较大的 SVG 图可能被裁切。可以改用 PNG:

import graphviz
graphviz.set_jupyter_format("png")