ONNX 模型拆分完整技巧(工具、场景、坑点、示例代码)
拆分 ONNX ≠ 单纯提速,拆分是手段,提速只是其中一种场景,很多时候拆分反而会变慢。
一、需要拆分 ONNX 模型的核心场景
拆分模型不是单纯为了提速,是解决各类工程痛点的手段,分为性能优化类、工程兼容类、业务调试类三大场景:
(一)拆分后可提升推理速度的场景
- 异构硬件流水线并行设备同时具备 GPU/CPU/NPU/DLA/DSP 多种算力单元,将预处理、主干网络、NMS 后处理分配给不同硬件异步并行执行,硬件资源不闲置,整体推理耗时下降。
- 规避算子不兼容,减少 CPU 降级推理TensorRT、RKNN、Tengine 等端侧加速引擎对 Loop、If、复杂索引、动态 NMS 等算子支持差;拆分后把不兼容算子单独切出交由 CPU 运行,主干网络完整走硬件加速,避免整张模型全部软推理。
- 多任务多头按需推理,减少无效计算一体化多任务模型(检测 + 分割 + 分类)每次推理会执行全部分支;拆分后可按需只运行需要的子模型,跳过无关分支,降低计算量。
- 超大模型分片,解决显存 / 内存溢出BEV、大分割模型一次性加载权重会出现 OOM;拆分后分段加载、推理后释放权重,保证模型正常运行,间接避免卡顿、崩溃带来的效率损失。
(二)拆分不提速,仅解决工程 / 调试问题的场景
- 精度问题定位调试单独提取中间特征子图,对比框架导出与 ONNX 中间输出,快速定位误差层。
- 业务模块化复用主干网络通用,检测头、分割头可单独替换、迭代,无需重新导出完整模型。
- 分层交付与权限隔离模型主干、后处理分开交付,区分不同模块使用权限。
- 自定义算子改造单独切出目标算子所在子图,单独替换为 CUDA / 硬件自定义算子。
- 多工具链适配主干用 TensorRT 编译加速,后处理用 ONNXRuntime 运行,两类工具无法同时处理一张完整模型。
(三)不建议拆分的场景(拆分反而减速)
仅单 GPU 硬件、拆分后串行运行多个子模型,会产生多重损耗:多次创建推理会话、GPU/CPU 数据来回拷贝、算子融合优化碎片化、权重重复加载,推理耗时上升 20%~50%。
二、4 种主流拆分工具对比
| 工具 | 优势 | 适用场景 | 局限 |
|---|---|---|---|
onnx.utils.extract_model(onnx 原生) | 零额外依赖、极简 API | 简单直线图、无分支 / 残差 | 残差跳跃连接、循环子图易报错 |
onnx-graphsurgeon(TensorRT 配套) | 图编辑能力最强,可改节点、重连线、处理残差 | 复杂网络、残差、多分支、跨硬件拆分 | 需额外安装,仅 Python |
onnxslim | 轻量化、支持 CLI 命令行、一键提取子图 | 快速裁剪、删除冗余输出 | 复杂图修改能力弱于 graphsurgeon |
| 框架侧预拆分(PyTorch/Paddle 导出时分段) | 无 ONNX 图修复、张量名天然对齐 | 模型还未导出 onnx 阶段拆分 | 已有的 onnx 文件无法使用 |
三、基础拆分方法 1:原生 onnx extract_model(最简单)
原理
指定子图输入张量名、子图输出张量名,自动提取中间子图,自动携带所需权重 initializerONNX。
示例代码
import onnx # 原始模型、输出子模型 src_onnx = "full_model.onnx" sub_onnx = "backbone_sub.onnx" # 1. 用Netron打开onnx,找到分割边界张量名 # 例如:原图输入是["images"], 分割点张量为"feat_out" sub_inputs = ["images"] sub_outputs = ["feat_out"] # 提取子模型 onnx.utils.extract_model( src_onnx, sub_onnx, input_names=sub_inputs, output_names=sub_outputs ) # 校验子图合法性 model = onnx.load(sub_onnx) onnx.checker.check_model(model) print("子模型拆分完成")适用限制
- 不能切割
If/Loop等带子图的控制流算子; - 残差跳跃连接跨分割边界会丢失张量,直接报错;
- 仅适合单向无分支的简单网络(分类主干)。
四、基础拆分方法 2:OnnxSlim(命令行 + Python,快速裁剪)
CLI 一行拆分(推荐快速调试)
# 只保留从images输入到feat_out输出的子图 onnxslim full.onnx sub.onnx --inputs images --outputs feat_out # 只删除多余输出,保留原图输入 onnxslim full.onnx head.onnx --outputs det_out0,det_out1Python API
import onnxslim onnxslim.slim("full.onnx", "sub.onnx", inputs=["images"], outputs=["feat_out"])五、高级拆分方法 3:onnx-graphsurgeon(复杂网络首选,处理残差 / 多分支)
核心优势
手动遍历张量、重定向输入输出、清理孤立节点、完美兼容残差、跳跃连接、多分支网络,解决 extract_model 残差报错问题。
实战代码:切分 Backbone 与检测头
import onnx import onnx_graphsurgeon as gs # 1. 加载图 model = onnx.load("yolov8_full.onnx") graph = gs.import_onnx(model) tensors = graph.tensors() # 2. 定义分割边界张量(Netron查看) split_tensor = tensors["backbone_feat"] # 主干输出,头分支输入 origin_input = tensors["images"] # ========== 拆分1:主干子图 ========== # 新图输入:原图输入;新图输出:分割张量 graph_backbone = graph.copy() graph_backbone.inputs = [origin_input] graph_backbone.outputs = [split_tensor] # 清理无用节点、权重 graph_backbone.cleanup().toposort() onnx.save(gs.export_onnx(graph_backbone), "backbone.onnx") # ========== 拆分2:检测头子图 ========== graph_head = graph.copy() # 头模型输入改为分割张量,输出保留原图所有输出 graph_head.inputs = [split_tensor] graph_head.cleanup().toposort() onnx.save(gs.export_onnx(graph_head), "det_head.onnx")关键技巧:处理跨分支残差
- 拆分前先
graph.cleanup()移除冗余 Identity; - 所有跳跃连接张量不能跨分割边界,分割线必须放在残差 Add 之后;
- 多输出场景:将所有分支末端张量统一设为子图 output。
六、框架侧预拆分(导出前拆分,最优方案)
在 PyTorch 导出 ONNX 前直接拆分子网络,张量名天然对齐,无图修复问题,推荐新项目使用。
import torch from torchvision import models model = models.resnet50(pretrained=True).eval() dummy = torch.randn(1,3,640,640) # 拆分1:主干 backbone = torch.nn.Sequential(*list(model.children())[:-2]) torch.onnx.export( backbone, dummy, "backbone.onnx", input_names=["img"], output_names=["feat"], opset_version=17 ) # 拆分2:分类头 feat_dummy = torch.randn(1,2048,20,20) head = torch.nn.Sequential(model.avgpool, model.fc) torch.onnx.export( head, feat_dummy, "cls_head.onnx", input_names=["feat"], output_names=["pred"], opset_version=17 )七、工程通用拆分技巧
1. 分割线选择黄金规则
- 禁止切残差 / 跳跃连接中间:分割点放在 Add/Concat 之后,不要切 Shortcut 支路;
- 避开控制流算子内部:If、Loop、Scan 的子图不能被分割线截断;
- 边界张量尽量选 Identity 输出:Netron 中插入 Identity 节点作为分割标记,方便定位;
- 多头模型统一在分支起点分割:多检测头、多分割头从主干输出处一刀切。
2. 拆分前预处理(大幅降低报错)
- 简化模型:
onnxsim full.onnx sim.onnx,消除冗余 Reshape、Identity; - 推理 shape 推导:
onnx.shape_inference.infer_shapes(model),子图 shape 校验; - 外部权重分离:超大模型拆分前导出外部数据,避免 onnx 文件过大:
from onnx.external_data_helper import convert_model_to_external_data model = onnx.load("big.onnx") convert_model_to_external_data(model, location="weights.bin") onnx.save(model, "big_split.onnx")
3. 拆分后校验三板斧
# 1. 格式合法性校验 onnx.checker.check_model(sub_model) # 2. 推理输出对齐(原始图vs子图拼接输出) import onnxruntime as ort # 原图推理 ori_sess = ort.InferenceSession("full.onnx") ori_out = ori_sess.run(None, {"images": rand_input}) # 子图串联推理 b_sess = ort.InferenceSession("backbone.onnx") feat = b_sess.run(None, {"images": rand_input})[0] h_sess = ort.InferenceSession("head.onnx") sub_out = h_sess.run(None, {"backbone_feat": feat}) # 输出误差校验 import numpy as np print(np.allclose(ori_out[0], sub_out[0], atol=1e-5))4. 特殊场景拆分方案
场景 A:前后处理拆分(CPU 预处理,GPU 推理)
- 预处理(Resize/Normalize/Transpose)拆为独立子模型,CPU 执行;
- 推理主干 GPU 执行;
- 后处理 NMS/ArgMax 单独拆分,DLA/CPU 执行;
场景 B:多输出多头拆分
# extract_model同时提取多个输出,拆出单头 onnx.utils.extract_model( "full.onnx", "seg_head.onnx", input_names=["feat"], output_names=["seg_out"] # 只保留分割输出,丢弃检测输出 )场景 C:算子不兼容拆分(如 TensorRT 不支持循环)
- Netron 定位不支持算子的输入输出张量;
- 分割线放在该算子前后,拆出不兼容子图;
- 子图用 ONNXRuntime/CPU 自定义算子推理,其余用 TensorRT。
八、高频报错与解决方案
- extract_model 报错:张量不存在原因:分割线跨残差支路,shortcut 张量未包含在子图; 解决:改用 onnx-graphsurgeon 复制完整图后裁剪,或移动分割点到 Add 后。
- 子图推理 shape 不匹配原因:未做 shape 推理、动态维度冲突; 解决:拆分前执行
infer_shapes,固定输入 shape 或统一 dynamic 维度。 - 拆分后权重丢失原因:initializer 仅被分支使用,裁剪时被清理; 解决:使用
graph.copy()完整复制原图再裁剪,不要直接修改原图。 - Loop/If 子图拆分失败官方限制:分割线不能切割控制流子图,需将整个 Loop 作为一个完整子图拆分。
九、工具选型速记
- 简单直线网络、快速调试 →
onnx.utils.extract_model/onnxslim - ResNet、YOLO、U-Net 带残差 / 多分支 →onnx-graphsurgeon
- 模型还未导出、原生 PyTorch/Paddle → 框架内分段导出(最优)
- 命令行批量拆分脚本 → onnxslim CLI
以上内容,要是对您有用,请给个赞和关注,感谢您的支持。