跳转至

torchview

Python 3.7+ PyPI version Conda version Build Status GitHub license codecov Downloads

banner

Torchview 用于将 PyTorch 模型可视化为计算图(visual graph)。可视化内容包括:tensor、module、torch 函数调用,以及输入/输出 shape 等信息。

它可以理解为 PyTorch 版本的 Keras plot_model(并且支持更多细节)。

支持 PyTorch 版本:(\geq 1.7)。

主要特性

Useful FeaturesUseful Features

安装

首先需要安装 Graphviz:

pip install graphviz

为了让 Graphviz 的 Python 接口正常工作,你的系统中需要能够调用 dot 布局命令。如果尚未安装 Graphviz,建议按操作系统安装:

Debian 系 Linux(如 Ubuntu):

apt-get install graphviz

Windows:

choco install graphviz

macOS:

brew install graphviz

更多细节可参考:Graphviz 文档

然后用 pip 安装 torchview:

pip install torchview

或者用 conda:

conda install -c conda-forge torchview

如果你想安装最新版本,可以直接从仓库安装:

pip install git+https://github.com/mert-kurttutan/torchview.git

快速使用

from torchview import draw_graph

model = MLP()
batch_size = 2
# device='meta' -> 可视化时不会消耗实际显存/内存(只做结构推导)
model_graph = draw_graph(model, input_size=(batch_size, 128), device='meta')
model_graph.visual_graph

output

Notebook 示例

更多示例可参考下面的 Colab:

入门介绍: Introduction

计算机视觉模型: Vision

NLP 模型: NLP

注意: torchview 的 Graphviz 可视化会返回“适配尺寸”的图像。但在 VSCode 中,由于 SVG 渲染与画布尺寸的限制,较大的图可能出现裁切。可通过以下方式改用 PNG 渲染:

import graphviz
graphviz.set_jupyter_format('png')

该问题在 JupyterLab、Google Colab 等平台通常不会出现。

支持的能力

  • 支持多数常见模型:RNN、Sequential、跳连(Skip Connection)、Hugging Face 模型等
  • 支持 Meta Tensor,可在可视化超大模型时做到几乎不消耗内存(PyTorch (\geq 1.13))
  • 除了 module 调用,还能显示 tensor 之间的算子操作
  • 支持 Rolling/Unrolling:可将递归调用的模块在图上“折叠/展开”(见下方示例)
  • 支持多种输入/输出类型:如嵌套结构(dict/list 等)、Hugging Face tokenizer 输出等

API 文档

def draw_graph(
    model: nn.Module,
    input_data: INPUT_DATA_TYPE | None = None,
    input_size: INPUT_SIZE_TYPE | None = None,
    graph_name: str = 'model',
    depth: int | float = 3,
    device: torch.device | str | None = None,
    dtypes: list[torch.dtype] | None = None,
    mode: str | None = None,
    strict: bool = True,
    expand_nested: bool = False,
    graph_dir: str | None = None,
    hide_module_functions: bool = True,
    hide_inner_tensors: bool = True,
    roll: bool = False,
    show_shapes: bool = True,
    save_graph: bool = False,
    filename: str | None = None,
    directory: str = '.',
    **kwargs: Any,
) -> ComputationGraph:
    '''返回输入 PyTorch Module 的可视化表示(ComputationGraph)。
    ComputationGraph 包含:

    1) 根节点(通常是输入 tensor 节点),它连接到 forward 过程中记录的所有其它节点

    2) `graphviz.Digraph` 对象,用于承载计算图的可视化表示。图中会展示 module/module 层级、
    torch 函数、shape,以及 forward 中记录到的 tensor。相关示例可见文档与 Colab notebook。

    Args:
        model (nn.Module):
            需要可视化的 PyTorch 模型。

        input_data (包含 torch.Tensor 的数据结构):
            作为模型 forward 的输入。多个位置参数可放入 list;
            或用 dict / kwargs 形式传入。

        input_size (shape 序列):
            输入数据的 shape(list/tuple/torch.Size)。如果给了 `input_size`,
            那么 `dtypes` 需要与模型输入一致(默认使用 FloatTensor)。
            Default: None

        graph_name (str):
            Graphviz `Digraph` 的名称,也会作为默认输出文件名。
            Default: 'model'

        depth (int):
            可视化中节点展示的最大深度。深度定义为:节点在模块层级中的“嵌套层数”。
            例如主模块 depth=0,主模块的子模块 depth=1,以此类推。
            Default: 3

        device (str or torch.device):
            放置输入 tensor 的 device。若未指定:
            - PyTorch 检测到 CUDA 则使用 GPU
            - 否则使用 CPU
            Default: None

        dtypes (list[torch.dtype]):
            当提供 `input_size` 时,用 `dtypes` 设置输入 tensor 的 dtype。

        mode (str):
            forward 传播时使用的模型模式;未指定则默认用 eval。
            Default: None

        strict (bool):
            如果为 true,则 Graphviz 可视化不允许同一对节点之间出现多条边。
            多条边可能发生在:module 节点之间同时存在 tensor 边,但你又选择隐藏这些 tensor 时。
            Default: True

        expand_nested(bool):
            如果为 true,则用虚线边框展示嵌套模块。

        graph_dir (str):
            设置图的方向:
            'TB' -> 从上到下
            'LR' -> 从左到右
            'BT' -> 从下到上
            'RL' -> 从右到左
            Default: None -> TB

        hide_module_function (bool):
            是否隐藏 module 内部的 torch function。部分模块只由 torch function 构成(无子模块),
            例如 `nn.Conv2d`。
            True => 不在图中展示 module functions
            False => 在图中展示 module functions
            Default: True

        hide_inner_tensors (bool):
            inner tensor 指除输入/输出外,在计算图内部流转的 tensor。
            True => 不展示 inner tensors
            False => 展示 inner tensors
            Default: True

        roll (bool):
            若为 true,则折叠递归模块(Rolling)。
            Default: False

        show_shapes (bool):
            True => 展示 tensor 的 shape(含输入/输出)
            False => 不展示 shape
            Default: True

        save_graph (bool):
            True => 保存 Graphviz 输出文件
            False => 不保存
            Default: False

        filename (str):
            保存 dot 语法与图像文件时使用的文件名;默认等于 graph_name。

        directory (str):
            保存 Graphviz 输出文件的目录。
            Default: .

    Returns:
        ComputationGraph:包含 Graphviz `Digraph` 的计算图对象。
    '''

示例

递归网络的折叠(Rolled Version)

from torchview import draw_graph

model_graph = draw_graph(
    SimpleRNN(), input_size=(2, 3),
    graph_name='RecursiveNet',
    roll=True
)
model_graph.visual_graph

rnns

显示/隐藏中间(hidden)tensor 与 functionals

# Show inner tensors and Functionals
model_graph = draw_graph(
    MLP(), input_size=(2, 128),
    graph_name='MLP',
    hide_inner_tensors=False,
    hide_module_functions=False,
)

model_graph.visual_graph

download

ResNet / 跳连 / 支持 torch 运算 / 嵌套模块展示

import torchvision

model_graph = draw_graph(resnet18(), input_size=(1,3,32,32), expand_nested=True)
model_graph.visual_graph

expand_nested_resnet_model gv

TODO

  • [ ] 展示 Module 的参数信息(parameter info)
  • [ ] 支持图神经网络(GNN)
  • [ ] 为 GNN 支持无向边
  • [ ] 支持 torch-based functions[^1]

[^1]: 这里的 torch-based functions 指“只使用 torch 函数和模块实现的函数”。该概念比 module 更泛化。

贡献指南

中文版本由 @1985312383(GitHub)友情提供。

我们非常欢迎 issue 与 PR!如果你想了解如何构建本项目:

  • torchview 使用最新 Python 版本进行活跃开发。
  • 改动需要保持对 Python 3.7 的向后兼容,并遵循 Python 的旧版本生命周期策略。
  • 运行 pip install -r requirements-dev.txt 安装开发依赖(我们使用最新的 dev 包版本)。
  • 单测:运行 pytest
  • 更新期望输出:运行 pytest --overwrite
  • 跳过输出文件测试:运行 pytest --no-output

参考

  • 输入处理与校验相关部分借鉴/参考了 torchinfo 仓库
  • 软件工程相关部分(如测试)也借鉴了 torchinfo(感谢 @TylerYep)
  • 计算图构建算法得益于 __torch_function__torch.Tensor 的 subclass 机制(感谢相关贡献者)