Advertisement

pytorch 绘制Depth Anything网络结构

阅读量:

通过 pytorch 可以绘制模型网络结构的方法有很多种。对于个人而言,我更倾向于使用 torchview 软件所呈现出来的 Graphviz 风格图像。

Graphviz 介绍

该软件名为Graphviz,并源自"Graph Visualization Software"这一缩写形式;采用基于DOT语言的描述机制来构建图结构;具备生成复杂网络图的能力,并提供多样化的功能模块及核心组件;其自动布局技术显著提升了用户的绘图效率与个性化定制能力;广泛应用于多个技术领域如软件工程、数据可视化及学术研究等;官方文档和技术资源均可在官网获取进一步支持;

安装也比较简单,参考:

[mert-kurttutan/torchview: torchview: visualize pytorch models (github.com)

icon-default.png?t=N7T8

v github.com/mert-kurttutan/torchview](mert-kurttutan开发的 torchview 是一个用于可视化 PyTorch 模型的工具(参考 GitHub))

为了使graphviz的Python接口正常运行,在你的系统中需要配置dot layout命令。如果尚未安装该接口,请使用相应的command行指令进行设置。

安装 graphviz

以Windows系统为例,以管理员权限打开 PowerShell ,然后该命令****

choco install graphviz

安装pip包

接下来可以切到你的pytorch conda环境下,安装pip包

复制代码
 pip install -i https://pypi.tuna.tsinghua.edu.cn/simple graphviz

    
 pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchview
    
    
    
    
    bash

绘制 resent18 网络结构

以resent18来实验绘制网络结构

复制代码
 from torchvision.models import resnet18

    
 from torchview import draw_graph
    
  
    
 # 将图表保存到本地
    
 model_graph = draw_graph(resnet18(), input_size=(1,3,32,32), expand_nested=True, save_graph = True, filename = "resnet18"))
    
 # 可视化显示
    
 model_graph.visual_graph
    
    
    
    
    python

将展示出resent18的网络结构


绘制 depth_anything_v2 vits 网络结构

depth_anything_v2 的安装这里就不讲了,重点看 网络结构的绘制代码

复制代码
 import torch

    
 from torchview import draw_graph
    
 from depth_anything_v2.dpt import DepthAnythingV2
    
  
    
 DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    
  
    
 model_configs = {
    
     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    
     'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    
     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    
     'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    
 }
    
  
    
 encoder = 'vits' # or 'vitl', 'vitb', 'vitg'
    
  
    
 model = DepthAnythingV2(**model_configs[encoder])
    
 model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))
    
  
    
 # 绘制并保存模型架构图
    
 model_graph = draw_graph(model, input_size=(1,3,518,518), expand_nested=True, save_graph = True, filename = "vits")
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-18/9J0Iaz5hZEgrpdOcSixRK6qVCl4n.png)

展示的这个便是 vits 的网络结构图

全部评论 (0)

还没有任何评论哟~