PyTorch中冻结中间层参数的深度解析与实践

PyTorch中冻结中间层参数的深度解析与实践

本教程深入探讨了在PyTorch中冻结神经网络特定中间层参数的两种常见方法:torch.no_grad()上下文管理器和设置参数的requires_grad = False属性。文章通过代码示例详细阐述了两种方法的原理、效果及适用场景,并明确指出requires_grad = False是实现精确中间层冻结的推荐方案,同时提供了验证层是否被冻结的技巧,旨在帮助开发者准确控制模型训练过程中的参数更新。

在深度学习模型训练过程中,我们经常会遇到需要冻结模型中某些层(即不更新这些层的参数)而只训练其他层的场景,例如在迁移学习中冻结预训练模型的特征提取层,或者在多任务学习中只更新特定任务相关的层。本文将详细探讨pytorch中实现这一目标的方法。

理解参数冻结的原理

在PyTorch中,参数更新是通过反向传播计算梯度并由优化器应用到参数上的。冻结一个层意味着阻止其参数参与梯度计算和随后的更新。这通常通过控制参数的requires_grad属性来实现。当requires_grad为False时,PyTorch的自动求导引擎将不会为该参数计算梯度,从而阻止其被优化器更新。

方法一:使用 torch.no_grad() 上下文管理器

torch.no_grad()是一个上下文管理器,它会禁用在其作用域内所有操作的梯度计算。这意味着,任何在with torch.no_grad():块中执行的操作,都不会构建计算图,也不会跟踪梯度。

让我们通过一个简单的三层线性网络为例来演示:

import torchimport torch.nn as nnimport torch.optim as optim# 定义一个简单的模型class SimpleModel(nn.Module):    def __init__(self):        super(SimpleModel, self).__init__()        self.lin0 = nn.Linear(1, 2)        self.lin1 = nn.Linear(2, 2)        self.lin2 = nn.Linear(2, 10)    def forward_with_no_grad(self, x):        x = self.lin0(x)        with torch.no_grad():            x = self.lin1(x) # 尝试冻结lin1        x = self.lin2(x)        return x# 实例化模型model_no_grad = SimpleModel()# 记录初始参数initial_lin0_weight = model_no_grad.lin0.weight.clone()initial_lin1_weight = model_no_grad.lin1.weight.clone()initial_lin2_weight = model_no_grad.lin2.weight.clone()# 模拟训练步骤input_data = torch.randn(1, 1)target = torch.randint(0, 10, (1,))criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model_no_grad.parameters(), lr=0.01)print("--- 使用 torch.no_grad() 冻结中间层 ---")print("初始 lin0 权重:n", initial_lin0_weight)print("初始 lin1 权重:n", initial_lin1_weight)print("初始 lin2 权重:n", initial_lin2_weight)# 前向传播与反向传播output = model_no_grad.forward_with_no_grad(input_data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()# 检查参数变化print("n训练后 lin0 权重:n", model_no_grad.lin0.weight)print("训练后 lin1 权重:n", model_no_grad.lin1.weight)print("训练后 lin2 权重:n", model_no_grad.lin2.weight)# 验证是否冻结print("nlin0 权重是否变化:", not torch.equal(initial_lin0_weight, model_no_grad.lin0.weight))print("lin1 权重是否变化:", not torch.equal(initial_lin1_weight, model_no_grad.lin1.weight))print("lin2 权重是否变化:", not torch.equal(initial_lin2_weight, model_no_grad.lin2.weight))

分析 torch.no_grad() 的效果:上述代码运行后会发现,lin0和lin1的参数都没有更新,而只有lin2的参数发生了变化。这是因为当lin1的操作在torch.no_grad()块中执行时,其输出张量x(来自lin1)的grad_fn属性将为None,这意味着从lin1往前的计算图被截断了。因此,尽管lin2的梯度可以正常计算并回传到lin1的输出,但由于lin1的操作没有梯度跟踪,导致无法计算lin1自身的梯度,也无法将梯度继续回传到lin0。最终结果是,lin0和lin1的参数都不会得到更新。

结论: torch.no_grad() 适用于冻结整个模型或模型的一部分,使其在推理阶段不消耗内存来存储梯度信息,或者在训练时完全禁用某些部分的梯度更新。但它不适合精确地冻结中间层而允许其上游层更新的场景。

方法二:设置 requires_grad = False

这是在PyTorch中实现精确层冻结的推荐方法。通过将特定层的参数的requires_grad属性设置为False,我们可以明确告诉PyTorch的自动求导引擎不需要为这些参数计算梯度。

import torchimport torch.nn as nnimport torch.optim as optim# 定义一个简单的模型class SimpleModel(nn.Module):    def __init__(self):        super(SimpleModel, self).__init__()        self.lin0 = nn.Linear(1, 2)        self.lin1 = nn.Linear(2, 2)        self.lin2 = nn.Linear(2, 10)    def forward(self, x):        x = self.lin0(x)        x = self.lin1(x)        x = self.lin2(x)        return x# 实例化模型model_requires_grad = SimpleModel()# 冻结lin1层的参数model_requires_grad.lin1.weight.requires_grad = Falsemodel_requires_grad.lin1.bias.requires_grad = False# 记录初始参数initial_lin0_weight_rg = model_requires_grad.lin0.weight.clone()initial_lin1_weight_rg = model_requires_grad.lin1.weight.clone()initial_lin2_weight_rg = model_requires_grad.lin2.weight.clone()# 注意:优化器只应传入 requires_grad 为 True 的参数optimizer_rg = optim.SGD(filter(lambda p: p.requires_grad, model_requires_grad.parameters()), lr=0.01)# 模拟训练步骤input_data = torch.randn(1, 1)target = torch.randint(0, 10, (1,))criterion = nn.CrossEntropyLoss()print("n--- 使用 requires_grad = False 冻结中间层 ---")print("初始 lin0 权重:n", initial_lin0_weight_rg)print("初始 lin1 权重:n", initial_lin1_weight_rg)print("初始 lin2 权重:n", initial_lin2_weight_rg)# 前向传播与反向传播output = model_requires_grad(input_data)loss = criterion(output, target)optimizer_rg.zero_grad()loss.backward()optimizer_rg.step()# 检查参数变化print("n训练后 lin0 权重:n", model_requires_grad.lin0.weight)print("训练后 lin1 权重:n", model_requires_grad.lin1.weight)print("训练后 lin2 权重:n", model_requires_grad.lin2.weight)# 验证是否冻结print("nlin0 权重是否变化:", not torch.equal(initial_lin0_weight_rg, model_requires_grad.lin0.weight))print("lin1 权重是否变化:", not torch.equal(initial_lin1_weight_rg, model_requires_grad.lin1.weight))print("lin2 权重是否变化:", not torch.equal(initial_lin2_weight_rg, model_requires_grad.lin2.weight))

分析 requires_grad = False 的效果:运行上述代码后,你会发现lin0和lin2的参数都得到了更新,而只有lin1的参数保持不变。这是因为:

lin1.weight.requires_grad = False和lin1.bias.requires_grad = False明确地告诉PyTorch不要为这些参数计算梯度。在反向传播时,尽管梯度会流经lin1,但由于lin1的参数被标记为不需要梯度,PyTorch会跳过其梯度计算,并继续将梯度回传到lin0。优化器在初始化时,通过filter(lambda p: p.requires_grad, model_requires_grad.parameters())确保它只接收那些requires_grad=True的参数进行更新。

结论: requires_grad = False 是实现精确冻结模型中特定层(包括中间层)的正确且推荐的方法。它允许梯度流经被冻结的层,但不会更新该层自身的参数,同时能将梯度正确地传递给更上游的层。

验证层是否被冻结

在实际操作中,可以通过以下几种方式来验证层是否成功被冻结:

检查 param.requires_grad 属性:在设置后,可以打印出model.lin1.weight.requires_grad来确认其是否为False。

检查 param.grad 属性:在执行loss.backward()之后,检查被冻结层的参数(例如model.lin1.weight.grad)是否为None。如果为None,则表示没有为该参数计算梯度。

检查参数值是否变化:在训练循环开始前记录参数的初始值,经过一个或多个训练步骤后,再次检查这些参数的值。如果参数值未发生变化,则说明该层已被冻结。这正是本文示例代码中采用的方法。

总结与最佳实践

精确冻结中间层: 始终使用设置参数的requires_grad = False属性来冻结模型中的特定层。优化器初始化: 当冻结部分层时,务必在初始化优化器时,只将那些requires_grad = True的参数传递给优化器。例如:optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)。torch.no_grad() 的适用场景: torch.no_grad() 主要用于推理阶段,或者在训练过程中完全禁用某一部分的梯度计算,它会截断计算图,不适合需要梯度回传到上游层的场景。模型状态: 冻结层与model.train()和model.eval()没有直接冲突。model.eval()主要影响nn.BatchNorm和nn.Dropout等层在训练和评估模式下的行为,而requires_grad控制的是参数是否更新。

通过理解和正确应用requires_grad = False,开发者可以灵活地控制PyTorch模型中各层的训练状态,从而实现更复杂的训练策略,例如微调预训练模型或进行部分模型的更新。

以上就是PyTorch中冻结中间层参数的深度解析与实践的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1368386.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 08:44:21
下一篇 2025年12月14日 08:44:30

相关推荐

  • Uniapp 中如何不拉伸不裁剪地展示图片?

    灵活展示图片:如何不拉伸不裁剪 在界面设计中,常常需要以原尺寸展示用户上传的图片。本文将介绍一种在 uniapp 框架中实现该功能的简单方法。 对于不同尺寸的图片,可以采用以下处理方式: 极端宽高比:撑满屏幕宽度或高度,再等比缩放居中。非极端宽高比:居中显示,若能撑满则撑满。 然而,如果需要不拉伸不…

    2025年12月24日
    400
  • 如何让小说网站控制台显示乱码,同时网页内容正常显示?

    如何在不影响用户界面的情况下实现控制台乱码? 当在小说网站上下载小说时,大家可能会遇到一个问题:网站上的文本在网页内正常显示,但是在控制台中却是乱码。如何实现此类操作,从而在不影响用户界面(UI)的情况下保持控制台乱码呢? 答案在于使用自定义字体。网站可以通过在服务器端配置自定义字体,并通过在客户端…

    2025年12月24日
    800
  • 如何在地图上轻松创建气泡信息框?

    地图上气泡信息框的巧妙生成 地图上气泡信息框是一种常用的交互功能,它简便易用,能够为用户提供额外信息。本文将探讨如何借助地图库的功能轻松创建这一功能。 利用地图库的原生功能 大多数地图库,如高德地图,都提供了现成的信息窗体和右键菜单功能。这些功能可以通过以下途径实现: 高德地图 JS API 参考文…

    2025年12月24日
    400
  • 如何使用 scroll-behavior 属性实现元素scrollLeft变化时的平滑动画?

    如何实现元素scrollleft变化时的平滑动画效果? 在许多网页应用中,滚动容器的水平滚动条(scrollleft)需要频繁使用。为了让滚动动作更加自然,你希望给scrollleft的变化添加动画效果。 解决方案:scroll-behavior 属性 要实现scrollleft变化时的平滑动画效果…

    2025年12月24日
    000
  • 如何为滚动元素添加平滑过渡,使滚动条滑动时更自然流畅?

    给滚动元素平滑过渡 如何在滚动条属性(scrollleft)发生改变时为元素添加平滑的过渡效果? 解决方案:scroll-behavior 属性 为滚动容器设置 scroll-behavior 属性可以实现平滑滚动。 html 代码: click the button to slide right!…

    2025年12月24日
    500
  • 如何选择元素个数不固定的指定类名子元素?

    灵活选择元素个数不固定的指定类名子元素 在网页布局中,有时需要选择特定类名的子元素,但这些元素的数量并不固定。例如,下面这段 html 代码中,activebar 和 item 元素的数量均不固定: *n *n 如果需要选择第一个 item元素,可以使用 css 选择器 :nth-child()。该…

    2025年12月24日
    200
  • 使用 SVG 如何实现自定义宽度、间距和半径的虚线边框?

    使用 svg 实现自定义虚线边框 如何实现一个具有自定义宽度、间距和半径的虚线边框是一个常见的前端开发问题。传统的解决方案通常涉及使用 border-image 引入切片图片,但是这种方法存在引入外部资源、性能低下的缺点。 为了避免上述问题,可以使用 svg(可缩放矢量图形)来创建纯代码实现。一种方…

    2025年12月24日
    100
  • 如何让“元素跟随文本高度,而不是撑高父容器?

    如何让 元素跟随文本高度,而不是撑高父容器 在页面布局中,经常遇到父容器高度被子元素撑开的问题。在图例所示的案例中,父容器被较高的图片撑开,而文本的高度没有被考虑。本问答将提供纯css解决方案,让图片跟随文本高度,确保父容器的高度不会被图片影响。 解决方法 为了解决这个问题,需要将图片从文档流中脱离…

    2025年12月24日
    000
  • 为什么 CSS mask 属性未请求指定图片?

    解决 css mask 属性未请求图片的问题 在使用 css mask 属性时,指定了图片地址,但网络面板显示未请求获取该图片,这可能是由于浏览器兼容性问题造成的。 问题 如下代码所示: 立即学习“前端免费学习笔记(深入)”; icon [data-icon=”cloud”] { –icon-cl…

    2025年12月24日
    200
  • 如何利用 CSS 选中激活标签并影响相邻元素的样式?

    如何利用 css 选中激活标签并影响相邻元素? 为了实现激活标签影响相邻元素的样式需求,可以通过 :has 选择器来实现。以下是如何具体操作: 对于激活标签相邻后的元素,可以在 css 中使用以下代码进行设置: li:has(+li.active) { border-radius: 0 0 10px…

    2025年12月24日
    100
  • 如何模拟Windows 10 设置界面中的鼠标悬浮放大效果?

    win10设置界面的鼠标移动显示周边的样式(探照灯效果)的实现方式 在windows设置界面的鼠标悬浮效果中,光标周围会显示一个放大区域。在前端开发中,可以通过多种方式实现类似的效果。 使用css 使用css的transform和box-shadow属性。通过将transform: scale(1.…

    2025年12月24日
    200
  • 为什么我的 Safari 自定义样式表在百度页面上失效了?

    为什么在 Safari 中自定义样式表未能正常工作? 在 Safari 的偏好设置中设置自定义样式表后,您对其进行测试却发现效果不同。在您自己的网页中,样式有效,而在百度页面中却失效。 造成这种情况的原因是,第一个访问的项目使用了文件协议,可以访问本地目录中的图片文件。而第二个访问的百度使用了 ht…

    2025年12月24日
    000
  • 如何用前端实现 Windows 10 设置界面的鼠标移动探照灯效果?

    如何在前端实现 Windows 10 设置界面中的鼠标移动探照灯效果 想要在前端开发中实现 Windows 10 设置界面中类似的鼠标移动探照灯效果,可以通过以下途径: CSS 解决方案 DEMO 1: Windows 10 网格悬停效果:https://codepen.io/tr4553r7/pe…

    2025年12月24日
    000
  • 使用CSS mask属性指定图片URL时,为什么浏览器无法加载图片?

    css mask属性未能加载图片的解决方法 使用css mask属性指定图片url时,如示例中所示: mask: url(“https://api.iconify.design/mdi:apple-icloud.svg”) center / contain no-repeat; 但是,在网络面板中却…

    2025年12月24日
    000
  • 如何用CSS Paint API为网页元素添加时尚的斑马线边框?

    为元素添加时尚的斑马线边框 在网页设计中,有时我们需要添加时尚的边框来提升元素的视觉效果。其中,斑马线边框是一种既醒目又别致的设计元素。 实现斜向斑马线边框 要实现斜向斑马线间隔圆环,我们可以使用css paint api。该api提供了强大的功能,可以让我们在元素上绘制复杂的图形。 立即学习“前端…

    2025年12月24日
    000
  • 图片如何不撑高父容器?

    如何让图片不撑高父容器? 当父容器包含不同高度的子元素时,父容器的高度通常会被最高元素撑开。如果你希望父容器的高度由文本内容撑开,避免图片对其产生影响,可以通过以下 css 解决方法: 绝对定位元素: .child-image { position: absolute; top: 0; left: …

    2025年12月24日
    000
  • CSS 帮助

    我正在尝试将文本附加到棕色框的左侧。我不能。我不知道代码有什么问题。请帮助我。 css .hero { position: relative; bottom: 80px; display: flex; justify-content: left; align-items: start; color:…

    2025年12月24日 好文分享
    200
  • 前端代码辅助工具:如何选择最可靠的AI工具?

    前端代码辅助工具:可靠性探讨 对于前端工程师来说,在HTML、CSS和JavaScript开发中借助AI工具是司空见惯的事情。然而,并非所有工具都能提供同等的可靠性。 个性化需求 关于哪个AI工具最可靠,这个问题没有一刀切的答案。每个人的使用习惯和项目需求各不相同。以下是一些影响选择的重要因素: 立…

    2025年12月24日
    300
  • 如何用 CSS Paint API 实现倾斜的斑马线间隔圆环?

    实现斑马线边框样式:探究 css paint api 本文将探究如何使用 css paint api 实现倾斜的斑马线间隔圆环。 问题: 给定一个有多个圆圈组成的斑马线图案,如何使用 css 实现倾斜的斑马线间隔圆环? 答案: 立即学习“前端免费学习笔记(深入)”; 使用 css paint api…

    2025年12月24日
    000
  • 如何使用CSS Paint API实现倾斜斑马线间隔圆环边框?

    css实现斑马线边框样式 想定制一个带有倾斜斑马线间隔圆环的边框?现在使用css paint api,定制任何样式都轻而易举。 css paint api 这是一个新的css特性,允许开发人员创建自定义形状和图案,其中包括斑马线样式。 立即学习“前端免费学习笔记(深入)”; 实现倾斜斑马线间隔圆环 …

    2025年12月24日
    100

发表回复

登录后才能评论
关注微信