
pytorch模型加载时,需要先定义模型结构,再加载保存的state_dict参数。这是因为pytorch通常只保存模型参数而非整个模型对象,以避免python对象序列化问题。本文将详细介绍如何分离模型的训练、保存与加载推理过程,并通过示例代码演示这一标准实践,帮助用户高效复用预训练模型。
在PyTorch中,将训练好的模型保存到磁盘并在后续加载进行推理是机器学习工作流中的常见需求。初学者常遇到的一个困惑是:加载模型时是否必须重新定义模型的完整结构?答案是肯定的,且这是PyTorch推荐的标准实践。本教程将深入探讨PyTorch的模型保存与加载机制,并提供清晰的示例代码,指导您如何正确地分离模型的训练、保存与推理过程。
理解PyTorch模型保存机制
PyTorch模型(nn.Module的实例)的保存通常有两种主要方式:
保存整个模型(不推荐):使用 torch.save(model, “model.pth”)。这种方法会保存整个模型对象,包括其结构和所有参数。然而,它依赖于Python的pickle模块进行序列化。当模型定义所在的类、包或文件结构发生变化时,或者在不同Python版本、PyTorch版本之间加载时,可能会遇到兼容性问题和序列化错误。因此,这种方法通常不被推荐用于生产环境或长期存储。保存模型的state_dict(推荐):使用 torch.save(model.state_dict(), “model.pth”)。state_dict是一个Python字典,它存储了模型中所有可学习参数(如权重和偏置)的映射。这种方式只保存参数,而模型的结构定义则需要独立存在。加载时,您需要先实例化一个具有相同结构的模型对象,然后将state_dict加载到这个新创建的对象中。这种方法更加健壮、灵活,且不易受环境变化的影响。
核心思想是: 模型结构(由nn.Module类定义)与模型参数(存储在state_dict中)是分离的。当您保存state_dict时,您只是保存了模型学到的“知识”,而模型的“骨架”——其架构定义——则需要在加载时重新提供。
模型训练与保存示例
为了演示这一过程,我们将使用一个简单的神经网络在FashionMNIST数据集上进行训练,并保存其state_dict。
Stable Diffusion 2.1 Demo
最新体验版 Stable Diffusion 2.1
101 查看详情
首先,我们需要设置环境、定义模型、数据加载器以及训练和测试函数。
# train_model.pyimport torchfrom torch import nnfrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision.transforms import ToTensor# 1. 准备数据training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(),)test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor(),)batch_size = 64train_dataloader = DataLoader(training_data, batch_size=batch_size)test_dataloader = DataLoader(test_data, batch_size=batch_size)# 2. 获取设备device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")print(f"Using {device} device")# 3. 定义模型class NeuralNetwork(nn.Module): def __init__(self): super().__init__() self.flatten = nn.Flatten() self.linear_relu_stack = nn.Sequential( nn.Linear(28*28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10) ) def forward(self, x): x = self.flatten(x) logits = self.linear_relu_stack(x) return logitsmodel = NeuralNetwork().to(device)print(model)# 4. 定义损失函数和优化器loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 5. 训练函数def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) model.train() for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) pred = model(X) loss = loss_fn(pred, y) optimizer.zero_grad() loss.backward() optimizer.step() if batch % 100 == 0: loss, current = loss.item(), (batch + 1) * len(X) print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")# 6. 测试函数def test(dataloader, model, loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test Error: n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} n")# 7. 训练模型并保存epochs = 5for t in range(epochs): print(f"Epoch {t+1}n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model, loss_fn)print("Done training!")# 保存模型的state_dicttorch.save(model.state_dict(), "model.pth")print("Saved PyTorch Model State to model.pth")
运行上述代码后,您将得到一个名为 model.pth 的文件,其中包含了训练好的模型参数。
模型加载与推理示例
现在,假设我们希望在一个完全独立的脚本中加载 model.pth 文件并进行推理。这个脚本不需要知道模型是如何训练的,但它必须知道模型的结构定义。
# inference_model.pyimport torchfrom torch import nnfrom torchvision import datasetsfrom torchvision.transforms import ToTensor# 1. 获取设备 (与训练时保持一致)device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
以上就是掌握PyTorch模型保存与加载:从训练到部署的完整指南的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/918418.html
微信扫一扫
支付宝扫一扫