掌握PyTorch模型保存与加载:从训练到部署的完整指南

掌握PyTorch模型保存与加载:从训练到部署的完整指南

pytorch模型加载时,需要先定义模型结构,再加载保存的state_dict参数。这是因为pytorch通常只保存模型参数而非整个模型对象,以避免python对象序列化问题。本文将详细介绍如何分离模型的训练、保存与加载推理过程,并通过示例代码演示这一标准实践,帮助用户高效复用预训练模型。

在PyTorch中,将训练好的模型保存到磁盘并在后续加载进行推理是机器学习工作流中的常见需求。初学者常遇到的一个困惑是:加载模型时是否必须重新定义模型的完整结构?答案是肯定的,且这是PyTorch推荐的标准实践。本教程将深入探讨PyTorch的模型保存与加载机制,并提供清晰的示例代码,指导您如何正确地分离模型的训练、保存与推理过程。

理解PyTorch模型保存机制

PyTorch模型(nn.Module的实例)的保存通常有两种主要方式:

保存整个模型(不推荐):使用 torch.save(model, “model.pth”)。这种方法会保存整个模型对象,包括其结构和所有参数。然而,它依赖于Python的pickle模块进行序列化。当模型定义所在的类、包或文件结构发生变化时,或者在不同Python版本、PyTorch版本之间加载时,可能会遇到兼容性问题和序列化错误。因此,这种方法通常不被推荐用于生产环境或长期存储。保存模型的state_dict(推荐):使用 torch.save(model.state_dict(), “model.pth”)。state_dict是一个Python字典,它存储了模型中所有可学习参数(如权重和偏置)的映射。这种方式只保存参数,而模型的结构定义则需要独立存在。加载时,您需要先实例化一个具有相同结构的模型对象,然后将state_dict加载到这个新创建的对象中。这种方法更加健壮、灵活,且不易受环境变化的影响。

核心思想是: 模型结构(由nn.Module类定义)与模型参数(存储在state_dict中)是分离的。当您保存state_dict时,您只是保存了模型学到的“知识”,而模型的“骨架”——其架构定义——则需要在加载时重新提供。

模型训练与保存示例

为了演示这一过程,我们将使用一个简单的神经网络在FashionMNIST数据集上进行训练,并保存其state_dict。

Stable Diffusion 2.1 Demo Stable Diffusion 2.1 Demo

最新体验版 Stable Diffusion 2.1

Stable Diffusion 2.1 Demo 101 查看详情 Stable Diffusion 2.1 Demo

首先,我们需要设置环境、定义模型、数据加载器以及训练和测试函数。

# train_model.pyimport torchfrom torch import nnfrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision.transforms import ToTensor# 1. 准备数据training_data = datasets.FashionMNIST(    root="data",    train=True,    download=True,    transform=ToTensor(),)test_data = datasets.FashionMNIST(    root="data",    train=False,    download=True,    transform=ToTensor(),)batch_size = 64train_dataloader = DataLoader(training_data, batch_size=batch_size)test_dataloader = DataLoader(test_data, batch_size=batch_size)# 2. 获取设备device = (    "cuda"    if torch.cuda.is_available()    else "mps"    if torch.backends.mps.is_available()    else "cpu")print(f"Using {device} device")# 3. 定义模型class NeuralNetwork(nn.Module):    def __init__(self):        super().__init__()        self.flatten = nn.Flatten()        self.linear_relu_stack = nn.Sequential(            nn.Linear(28*28, 512),            nn.ReLU(),            nn.Linear(512, 512),            nn.ReLU(),            nn.Linear(512, 10)        )    def forward(self, x):        x = self.flatten(x)        logits = self.linear_relu_stack(x)        return logitsmodel = NeuralNetwork().to(device)print(model)# 4. 定义损失函数和优化器loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 5. 训练函数def train(dataloader, model, loss_fn, optimizer):    size = len(dataloader.dataset)    model.train()    for batch, (X, y) in enumerate(dataloader):        X, y = X.to(device), y.to(device)        pred = model(X)        loss = loss_fn(pred, y)        optimizer.zero_grad()        loss.backward()        optimizer.step()        if batch % 100 == 0:            loss, current = loss.item(), (batch + 1) * len(X)            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# 6. 测试函数def test(dataloader, model, loss_fn):    size = len(dataloader.dataset)    num_batches = len(dataloader)    model.eval()    test_loss, correct = 0, 0    with torch.no_grad():        for X, y in dataloader:            X, y = X.to(device), y.to(device)            pred = model(X)            test_loss += loss_fn(pred, y).item()            correct += (pred.argmax(1) == y).type(torch.float).sum().item()    test_loss /= num_batches    correct /= size    print(f"Test Error: n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} n")# 7. 训练模型并保存epochs = 5for t in range(epochs):    print(f"Epoch {t+1}n-------------------------------")    train(train_dataloader, model, loss_fn, optimizer)    test(test_dataloader, model, loss_fn)print("Done training!")# 保存模型的state_dicttorch.save(model.state_dict(), "model.pth")print("Saved PyTorch Model State to model.pth")

运行上述代码后,您将得到一个名为 model.pth 的文件,其中包含了训练好的模型参数。

模型加载与推理示例

现在,假设我们希望在一个完全独立的脚本中加载 model.pth 文件并进行推理。这个脚本不需要知道模型是如何训练的,但它必须知道模型的结构定义。

# inference_model.pyimport torchfrom torch import nnfrom torchvision import datasetsfrom torchvision.transforms import ToTensor# 1. 获取设备 (与训练时保持一致)device = (    "cuda"    if torch.cuda.is_available()    else "mps"    if torch.backends.mps.is_available()    else "cpu"

以上就是掌握PyTorch模型保存与加载:从训练到部署的完整指南的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/918418.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月29日 06:32:53
下一篇 2025年11月29日 06:33:14

相关推荐

  • 怎样用免费工具美化PPT_免费美化PPT的实用方法分享

    利用KIMI智能助手可免费将PPT美化为科技感风格,但需核对文字准确性;2. 天工AI擅长优化内容结构,提升逻辑性,适合高质量内容需求;3. SlidesAI支持语音输入与自动排版,操作便捷,利于紧急场景;4. Prezo提供多种模板,自动生成图文并茂幻灯片,适合学生与初创团队。 如果您有一份内容完…

    2025年12月6日 软件教程
    000
  • Pages怎么协作编辑同一文档 Pages多人实时协作的流程

    首先启用Pages共享功能,点击右上角共享按钮并选择“添加协作者”,设置为可编辑并生成链接;接着复制链接通过邮件或社交软件发送给成员,确保其使用Apple ID登录iCloud后即可加入编辑;也可直接在共享菜单中输入邮箱地址定向邀请,设定编辑权限后发送;最后在共享面板中管理协作者权限,查看实时在线状…

    2025年12月6日 软件教程
    100
  • REDMI K90系列正式发布,售价2599元起!

    10月23日,redmi k90系列正式亮相,推出redmi k90与redmi k90 pro max两款新机。其中,redmi k90搭载骁龙8至尊版处理器、7100mah大电池及100w有线快充等多项旗舰配置,起售价为2599元,官方称其为k系列迄今为止最完整的标准版本。 图源:REDMI红米…

    2025年12月6日 行业动态
    200
  • Linux中如何安装Nginx服务_Linux安装Nginx服务的完整指南

    首先更新系统软件包,然后通过对应包管理器安装Nginx,启动并启用服务,开放防火墙端口,最后验证欢迎页显示以确认安装成功。 在Linux系统中安装Nginx服务是搭建Web服务器的第一步。Nginx以高性能、低资源消耗和良好的并发处理能力著称,广泛用于静态内容服务、反向代理和负载均衡。以下是在主流L…

    2025年12月6日 运维
    000
  • Linux journalctl与systemctl status结合分析

    先看 systemctl status 确认服务状态,再用 journalctl 查看详细日志。例如 nginx 启动失败时,systemctl status 显示 Active: failed,journalctl -u nginx 发现端口 80 被占用,结合两者可快速定位问题根源。 在 Lin…

    2025年12月6日 运维
    100
  • 华为新机发布计划曝光:Pura 90系列或明年4月登场

    近日,有数码博主透露了华为2025年至2026年的新品规划,其中pura 90系列预计在2026年4月发布,有望成为华为新一代影像旗舰。根据路线图,华为将在2025年底至2026年陆续推出mate 80系列、折叠屏新机mate x7系列以及nova 15系列,而pura 90系列则将成为2026年上…

    2025年12月6日 行业动态
    100
  • TikTok视频无法下载怎么办 TikTok视频下载异常修复方法

    先检查链接格式、网络设置及工具版本。复制以https://www.tiktok.com/@或vm.tiktok.com开头的链接,删除?后参数,尝试短链接;确保网络畅通,可切换地区节点或关闭防火墙;更新工具至最新版,优先选用yt-dlp等持续维护的工具。 遇到TikTok视频下载不了的情况,别急着换…

    2025年12月6日 软件教程
    100
  • Linux如何优化系统性能_Linux系统性能优化的实用方法

    优化Linux性能需先监控资源使用,通过top、vmstat等命令分析负载,再调整内核参数如TCP优化与内存交换,结合关闭无用服务、选用合适文件系统与I/O调度器,持续按需调优以提升系统效率。 Linux系统性能优化的核心在于合理配置资源、监控系统状态并及时调整瓶颈环节。通过一系列实用手段,可以显著…

    2025年12月6日 运维
    000
  • Linux命令行中wc命令的实用技巧

    wc命令可统计文件的行数、单词数、字符数和字节数,常用-l统计行数,如wc -l /etc/passwd查看用户数量;结合grep可分析日志,如grep “error” logfile.txt | wc -l统计错误行数;-w统计单词数,-m统计字符数(含空格换行),-c统计…

    2025年12月6日 运维
    000
  • 曝小米17 Air正在筹备 超薄机身+2亿像素+eSIM技术?

    近日,手机行业再度掀起超薄机型热潮,三星与苹果已相继推出s25 edge与iphone air等轻薄旗舰,引发市场高度关注。在此趋势下,多家国产厂商被曝正积极布局相关技术,加速抢占这一细分赛道。据业内人士消息,小米的超薄旗舰机型小米17 air已进入筹备阶段。 小米17 Pro 爆料显示,小米正在评…

    2025年12月6日 行业动态
    000
  • 「世纪传奇刀片新篇」飞利浦影音双11声宴开启

    百年声学基因碰撞前沿科技,一场有关声音美学与设计美学的影音狂欢已悄然引爆2025“双十一”! 当绝大多数影音数码品牌还在价格战中挣扎时,飞利浦影音已然开启了一场跨越百年的“声”活革命。作为拥有深厚技术底蕴的音频巨头,飞利浦影音及配件此次“双十一”精准聚焦“传承经典”与“设计美学”两大核心,为热爱生活…

    2025年12月6日 行业动态
    000
  • 荣耀手表5Pro 10月23日正式开启首销国补优惠价1359.2元起售

    荣耀手表5pro自9月25日开启全渠道预售以来,市场热度持续攀升,上市初期便迎来抢购热潮,一度出现全线售罄、供不应求的局面。10月23日,荣耀手表5pro正式迎来首销,提供蓝牙版与esim版两种选择。其中,蓝牙版本的攀登者(橙色)、开拓者(黑色)和远航者(灰色)首销期间享受国补优惠价,到手价为135…

    2025年12月6日 行业动态
    000
  • Vue.js应用中配置环境变量:灵活管理后端通信地址

    在%ignore_a_1%应用中,灵活配置后端api地址等参数是开发与部署的关键。本文将详细介绍两种主要的环境变量配置方法:推荐使用的`.env`文件,以及通过`cross-env`库在命令行中设置环境变量。通过这些方法,开发者可以轻松实现开发、测试、生产等不同环境下配置的动态切换,提高应用的可维护…

    2025年12月6日 web前端
    000
  • 环境搭建docker环境下如何快速部署mysql集群

    使用Docker Compose部署MySQL主从集群,通过配置文件设置server-id和binlog,编写docker-compose.yml定义主从服务并组网,启动后创建复制用户并配置主从连接,最后验证数据同步是否正常。 在Docker环境下快速部署MySQL集群,关键在于合理使用Docker…

    2025年12月6日 数据库
    000
  • Xbox删忍龙美女角色 斯宾塞致敬板垣伴信被喷太虚伪

    近日,海外游戏推主@HaileyEira公开发表言论,批评Xbox负责人菲尔·斯宾塞不配向已故的《死或生》与《忍者龙剑传》系列之父板垣伴信致敬。她指出,Xbox并未真正尊重这位传奇制作人的创作遗产,反而在宣传相关作品时对内容进行了审查和删减。 所涉游戏为年初推出的《忍者龙剑传2:黑之章》,该作采用虚…

    2025年12月6日 游戏教程
    000
  • 如何在mysql中分析索引未命中问题

    答案是通过EXPLAIN分析执行计划,检查索引使用情况,优化WHERE条件写法,避免索引失效,结合慢查询日志定位问题SQL,并根据查询模式合理设计索引。 当 MySQL 查询性能下降,很可能是索引未命中导致的。要分析这类问题,核心是理解查询执行计划、检查索引设计是否合理,并结合实际数据访问模式进行优…

    2025年12月6日 数据库
    000
  • VSCode入门:基础配置与插件推荐

    刚用VSCode,别急着装一堆东西。先把基础设好,再按需求加插件,效率高还不卡。核心就三步:界面顺手、主题舒服、功能够用。 设置中文和常用界面 打开软件,左边活动栏有五个图标,点最下面那个“扩展”。搜索“Chinese”,装上官方出的“Chinese (Simplified) Language Pa…

    2025年12月6日 开发工具
    000
  • VSCode性能分析与瓶颈诊断技术

    首先通过资源监控定位异常进程,再利用开发者工具分析性能瓶颈,结合禁用扩展、优化语言服务器配置及项目设置,可有效解决VSCode卡顿问题。 VSCode作为主流的代码编辑器,虽然轻量高效,但在处理大型项目或配置复杂扩展时可能出现卡顿、响应延迟等问题。要解决这些性能问题,需要系统性地进行性能分析与瓶颈诊…

    2025年12月6日 开发工具
    000
  • php查询代码怎么写_php数据库查询语句编写技巧与实例

    在PHP中进行数据库查询,最常用的方式是使用MySQLi或PDO扩展连接MySQL数据库。下面介绍基本的查询代码写法、编写技巧以及实用示例,帮助你高效安全地操作数据库。 1. 使用MySQLi进行查询(面向对象方式) 这是较为推荐的方式,适合大多数中小型项目。 // 创建连接$host = ‘loc…

    2025年12月6日 后端开发
    000
  • VSCode的悬浮提示信息可以自定义吗?

    可以通过JSDoc、docstring和扩展插件自定义VSCode悬浮提示内容,如1. 添加JSDoc或Python docstring增强信息;2. 调整hover延迟与粘性等显示行为;3. 使用支持自定义提示的扩展或开发hover provider实现深度定制,但无法直接修改HTML结构或手动编…

    2025年12月6日 开发工具
    000

发表回复

登录后才能评论
关注微信