发布时间:2026/7/3 4:09:40
bash# 创建虚拟环境,安装必要依赖conda create -n ai_stack python=3.9conda activate ai_stackpip install torch==2.1.0 torchvision==0.16.0 onnx==1.15.0 onnxruntime==1.17.0 numpy==1.26.0### 2. 训练一个简单的分类模型我们使用经典的 MNIST 数据集的子集,训练一个两层卷积网络。注意:这里故意简化数据量(只取 1000 张图),方便快速运行。python# train_model.pyimport torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoader, Subset# 0. 定义网络(兼容 ONNX 导出:避免使用动态控制流)class SimpleCNN(nn.Module): def __init__(self, num_classes=10): super().__init__() self.features = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), ) self.classifier = nn.Sequential( nn.Flatten(), nn.Linear(32 * 7 * 7, 64), nn.ReLU(), nn.Linear(64, num_classes), ) def forward(self, x): x = self.features(x) x = self.classifier(x) return x# 1. 加载数据(仅取前1000张训练,200张测试)transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])full_train = datasets.MNIST('./data', train=True, download=True, transform=transform)full_test = datasets.MNIST('./data', train=False, download=True, transform=transform)train_loader = DataLoader(Subset(full_train, indices=range(1000)), batch_size=32, shuffle=True)test_loader = DataLoader(Subset(full_test, indices=range(200)), batch_size=32, shuffle=False)# 2. 训练device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = SimpleCNN().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)print("开始训练...")for epoch in range(3): model.train() for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 简单验证 model.eval() correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"Epoch {epoch+1}, Test Acc: {100 * correct / total:.2f}%")# 保存 PyTorch 权重torch.save(model.state_dict(), 'model.pth')print("模型权重保存至 model.pth")### 3. 导出为 ONNXONNX 导出时需要指定输入张量的形状(动态批处理使用dynamic_axes)。注意:导出前必须将模型设置为 eval 模式。python# export_onnx.pyimport torchfrom train_model import SimpleCNN, devicemodel = SimpleCNN()model.load_state_dict(torch.load('model.pth', map_location='cpu'))model.eval()# 构造 dummy 输入:batch_size=1, channel=1, height=28, width=28dummy_input = torch.randn(1, 1, 28, 28)# 定义动态轴:batch 维度可变dynamic_axes = { 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}torch.onnx.export( model, dummy_input, 'model.onnx', input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes, opset_version=17, do_constant_folding=True)print("ONNX 模型导出成功:model.onnx")### 4. 使用 ONNX Runtime 进行推理ONNX Runtime 支持 CPU/GPU,且无需安装 PyTorch。下面代码展示加载 ONNX 模型并对单张图片进行推理。python# inference_onnx.pyimport onnxruntime as ortimport numpy as npfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoader, Subsetimport time# 加载测试数据中的一张图(仅用于演示)transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])full_test = datasets.MNIST('./data', train=False, download=True, transform=transform)test_loader = DataLoader(Subset(full_test, indices=range(1)), batch_size=1)image, label = next(iter(test_loader))# 转成 numpy,并确保维度顺序为 NCHWinput_np = image.numpy().astype(np.float32)# 创建 ONNX Runtime 会话session = ort.InferenceSession('model.onnx')input_name = session.get_inputs()[0].nameoutput_name = session.get_outputs()[0].name# 推理(预热)_ = session.run([output_name], {input_name: input_np})# 正式计时(模拟批处理)batch_np = np.repeat(input_np, 10, axis=0) # batch_size=10start = time.time()for _ in range(100): outputs = session.run([output_name], {input_name: batch_np})end = time.time()print(f"ONNX Runtime 推理 100 批 (batch=10) 耗时: {(end-start)*1000:.2f} ms")pred = np.argmax(outputs[0], axis=1)print(f"预测标签: {pred[:5]}, 真实标签: {label.item()}")### 5. 对比 PyTorch 直接推理(可选)作为对照,我们也可以将导出的 ONNX 模型与 PyTorch 推理速度对比。但注意:ONNX Runtime 通常比 PyTorch eager 模式快 1.5~3 倍(尤其是 CPU 环境)。python# compare.py(附加代码,不单独运行)import torchfrom train_model import SimpleCNN, device# PyTorch 推理model_pt = SimpleCNN()model_pt.load_state_dict(torch.load('model.pth', map_location='cpu'))model_pt.eval()with torch.no_grad(): start = time.time() for _ in range(100): _ = model_pt(torch.from_numpy(batch_np)) end = time.time()print(f"PyTorch 推理 100 批耗时: {(end-start)*1000:.2f} ms")>说明:上述代码直接粘贴即可运行(需先运行train_model.py和export_onnx.py)。如果 GPU 可用,可安装onnxruntime-gpu获得加速。—## 四、总结与最佳实践### 1. 技术栈选型建议-个人/小团队:PyTorch + ONNX Runtime 组合最灵活,学习成本低。-企业生产:训练用 PyTorch,部署用 NVIDIA Triton Inference Server(支持多框架多模型并行)。-边缘端:考虑 TensorRT(NVIDIA 设备)或 OpenVINO(Intel 设备),或者直接使用 TFLite(移动端/嵌入式)。-量化加速:ONNX Runtime 内置动态、静态量化工具,FP16/INT8 推理可提升 2~4 倍速度。### 2. 导出 ONNX 的常见陷阱-动态控制流(如if语句依赖输入)会导致导出失败,需提前改写成静态形式(如使用torch.where)。-自定义算子:ONNX 不支持 PyTorch 所有算子,需要注册或 fallback。建议先在torch.onnx.export中设置verify=True检查。-输入输出动态形状:务必设置dynamic_axes,否则导出的模型只能固定 batch 大小。### 3. 可操作建议1.从简单模型开始:不要一上来就尝试 YOLO / GPT,先跑通上面的 MNIST 案例,理解整个链路。2.使用版本锁定:ONNX 生态与框架版本强相关,建议在requirements.txt中固定 torch、onnx、onnxruntime 版本。3.压测找瓶颈:用 ONNX Runtime 的session.get_providers()检查是否真的使用了 GPU(CUDAExecutionProvider)。CPU 环境下可以尝试IExecutionProvider的enable_cpu_mem_arena=False以降低内存碎片。4.拥抱容器化:用 Docker 打包运行时环境,避免不同机器 CUDA/系统库版本冲突。推荐用nvidia/cuda:11.8-runtime-ubuntu22.04作为基础镜像。### 4. 未来的扩展当你掌握了基础链路,可以进一步学习:-模型服务化:使用FastAPI 包装 ONNX Runtime 推理,配合 Docker 对外提供 REST API。-自动化量化:ONNX Runtime 的 QAT(量化感知训练) 在保持精度的同时大幅提速。-多 GPU 推理:用 Triton Inference Server 管理多进程推理,吞吐量翻倍。AI 技术栈的深度远不止上述内容,但“训练→导出→部署”这一闭环是所有 AI 工程师必须掌握的硬技能。不要被碎片化的工具吓倒,用最小的闭环启动,在实践中扩展认知——这就是最优的成长路径。