理解 Transformers 中的交叉熵损失及 Masked Label 问题

理解 transformers 中的交叉熵损失及 masked label 问题

本文旨在深入解析 Hugging Face Transformers 库中,使用 GPT-2 等 Decoder-Only 模型计算交叉熵损失时,如何正确使用 masked label,并解释了常见的困惑。通过具体示例和代码,详细阐述了 target_ids 的构建方法,以及如何结合 ignore_index 来控制损失计算的范围,从而避免不必要的计算偏差,并提供了手动计算损失的替代方案。

在使用 Hugging Face Transformers 库进行自然语言处理任务时,尤其是使用 GPT-2 等 Decoder-Only 模型时,理解交叉熵损失的计算方式和 masked label 的作用至关重要。本文将深入探讨 target_ids 的正确构建方法,以及如何利用 ignore_index 来精确控制损失计算的范围,从而避免常见的错误和困惑。

Decoder-Only 模型、输入和目标

在 Hugging Face Transformers 库中,Decoder-Only 模型(如 GPT-2)主要依赖 input_ids、label_ids 和 attention_mask 进行训练。其中,input_ids 代表输入序列的 token IDs,label_ids 代表目标序列的 token IDs,而 attention_mask 用于指示哪些 token 应该被模型关注。

假设我们有一个输入 “The answer is:”,我们希望模型学习回答 “42”。将这个句子转化为 token IDs,假设 “The answer is: 42” 对应的 IDs 是 [464, 3280, 318, 25, 5433](其中 “:” 是 25,” 42″ 是 5433)。

为了让模型学习预测 “42”,我们需要设置 label_ids 为 [-100, -100, -100, -100, 5433]。这样,模型就不会学习到 “The answer” 后面应该跟着 “is:”,因为这些位置的损失被忽略了。

注意: Decoder-Only 模型要求输入和输出具有相同的形状。这与 Encoder-Decoder 模型不同,后者可以接受 “The answer is:” 作为输入,而 “42” 作为输出。

-100 是 torch.nn.CrossEntropyLoss 的默认 ignore_index。使用 “忽略” 比 “mask” 更准确,因为 “mask” 暗示模型看不到这些输入,或者原始输入被替换为特殊的 “” token。

理解问题的根源

原始问题中,代码 target_ids[:, :-seq_len] = -100 试图将 target_ids 中除了最后 seq_len 个元素之外的所有元素设置为 -100。然而,由于 target_ids 的长度为 seq_len,所以实际上没有任何元素被修改,导致损失计算结果不变。

迭代数据集时的正确方法

在使用滑动窗口迭代数据集时,masked label 的应用需要在不同的迭代步骤中进行调整。以下是一个示例:

第一次迭代:

max_length = 1024stride = 512end_loc = 1024input_ids = tokens[0 : 1024]target_ids = input_ids.clone()target_ids[:-1024] = -100  # 实际上没有修改任何元素assert torch.equal(target_ids, input_ids)trg_len = 1024prev_end_loc = 1024

在第一次迭代中,由于 target_ids[:-1024] 实际上等于 target_ids[:0],因此 target_ids 没有被修改,损失是基于所有 1024 个 token 计算的。

第二次及后续迭代:

begin_loc = 512end_loc = 1536trg_len = 1536 - 1024  # 512input_ids = tokens[512 : 1536]  # 注意:tokens 512-1024 已经被模型看到过target_ids = tokens[512 : 1536].clone()target_ids[:-512] = -100  # 将已经见过的 token 对应的 label 设置为 -100

从第二次迭代开始,target_ids 的前 512 个元素(对应于模型已经见过的 token)被设置为 -100,损失仅基于后 512 个 token 计算。

手动计算损失

如果需要更精细地控制损失计算过程,可以直接从模型获取 logits,然后手动计算交叉熵损失。

from torch.nn import CrossEntropyLossoutputs = model(encodings.input_ids, labels=None)logits = outputs.logitslabels = target_ids.to(logits.device)# 调整 logits 和 labels 的形状,使其匹配shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# 计算损失loss_fct = CrossEntropyLoss(reduction='mean')loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))print(loss.item())

这段代码首先从模型获取 logits,然后将 logits 和 labels 的形状进行调整,使其能够匹配。最后,使用 CrossEntropyLoss 计算损失。

总结:

理解 Decoder-Only 模型中 target_ids 的构建方式,以及如何利用 ignore_index 来控制损失计算的范围,是使用 Hugging Face Transformers 库进行自然语言处理任务的关键。通过正确设置 target_ids,可以避免不必要的计算偏差,并提高模型的训练效果。

以上就是理解 Transformers 中的交叉熵损失及 Masked Label 问题的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1374744.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 14:25:12
下一篇 2025年12月14日 14:25:28

相关推荐

  • HTML、CSS 和 JavaScript 中的简单侧边栏菜单

    构建一个简单的侧边栏菜单是一个很好的主意,它可以为您的网站添加有价值的功能和令人惊叹的外观。 侧边栏菜单对于客户找到不同项目的方式很有用,而不会让他们觉得自己有太多选择,从而创造了简单性和秩序。 今天,我将分享一个简单的 HTML、CSS 和 JavaScript 源代码来创建一个简单的侧边栏菜单。…

    2025年12月24日
    200
  • 前端代码辅助工具:如何选择最可靠的AI工具?

    前端代码辅助工具:可靠性探讨 对于前端工程师来说,在HTML、CSS和JavaScript开发中借助AI工具是司空见惯的事情。然而,并非所有工具都能提供同等的可靠性。 个性化需求 关于哪个AI工具最可靠,这个问题没有一刀切的答案。每个人的使用习惯和项目需求各不相同。以下是一些影响选择的重要因素: 立…

    2025年12月24日
    000
  • 带有 HTML、CSS 和 JavaScript 工具提示的响应式侧边导航栏

    响应式侧边导航栏不仅有助于改善网站的导航,还可以解决整齐放置链接的问题,从而增强用户体验。通过使用工具提示,可以让用户了解每个链接的功能,包括设计紧凑的情况。 在本教程中,我将解释使用 html、css、javascript 创建带有工具提示的响应式侧栏导航的完整代码。 对于那些一直想要一个干净、简…

    2025年12月24日
    000
  • 布局 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在这里查看视觉效果: 固定导航 – 布局 – codesandbox两列 – 布局 – codesandbox三列 – 布局 – codesandbox圣杯 &#8…

    2025年12月24日
    000
  • 隐藏元素 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看隐藏元素的视觉效果 – codesandbox 隐藏元素 hiding elements hiding elements hiding elements hiding elements hiding element…

    2025年12月24日
    400
  • 居中 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看垂直中心 – codesandbox 和水平中心的视觉效果。 通过 css 居中 垂直居中 centering centering centering centering centering centering立即…

    2025年12月24日 好文分享
    300
  • 如何在 Laravel 框架中轻松集成微信支付和支付宝支付?

    如何用 laravel 框架集成微信支付和支付宝支付 问题:如何在 laravel 框架中集成微信支付和支付宝支付? 回答: 建议使用 easywechat 的 laravel 版,easywechat 是一个由腾讯工程师开发的高质量微信开放平台 sdk,已被广泛地应用于许多 laravel 项目中…

    2025年12月24日
    000
  • 如何在移动端实现子 div 在父 div 内任意滑动查看?

    如何在移动端中实现让子 div 在父 div 内任意滑动查看 在移动端开发中,有时我们需要让子 div 在父 div 内任意滑动查看。然而,使用滚动条无法实现负值移动,因此需要采用其他方法。 解决方案: 使用绝对布局(absolute)或相对布局(relative):将子 div 设置为绝对或相对定…

    2025年12月24日
    000
  • 移动端嵌套 DIV 中子 DIV 如何水平滑动?

    移动端嵌套 DIV 中子 DIV 滑动 在移动端开发中,遇到这样的问题:当子 DIV 的高度小于父 DIV 时,无法在父 DIV 中水平滚动子 DIV。 无限画布 要实现子 DIV 在父 DIV 中任意滑动,需要创建一个无限画布。使用滚动无法达到负值,因此需要使用其他方法。 相对定位 一种方法是将子…

    2025年12月24日
    000
  • 移动端项目中,如何消除rem字体大小计算带来的CSS扭曲?

    移动端项目中消除rem字体大小计算带来的css扭曲 在移动端项目中,使用rem计算根节点字体大小可以实现自适应布局。但是,此方法可能会导致页面打开时出现css扭曲,这是因为页面内容在根节点字体大小赋值后重新渲染造成的。 解决方案: 要避免这种情况,将计算根节点字体大小的js脚本移动到页面的最前面,即…

    2025年12月24日
    000
  • Nuxt 移动端项目中 rem 计算导致 CSS 变形,如何解决?

    Nuxt 移动端项目中解决 rem 计算导致 CSS 变形 在 Nuxt 移动端项目中使用 rem 计算根节点字体大小时,可能会遇到一个问题:页面内容在字体大小发生变化时会重绘,导致 CSS 变形。 解决方案: 可将计算根节点字体大小的 JS 代码块置于页面最前端的 标签内,确保在其他资源加载之前执…

    2025年12月24日
    200
  • Nuxt 移动端项目使用 rem 计算字体大小导致页面变形,如何解决?

    rem 计算导致移动端页面变形的解决方法 在 nuxt 移动端项目中使用 rem 计算根节点字体大小时,页面会发生内容重绘,导致页面打开时出现样式变形。如何避免这种现象? 解决方案: 移动根节点字体大小计算代码到页面顶部,即 head 中。 原理: flexível.js 也遇到了类似问题,它的解决…

    2025年12月24日
    000
  • 形状 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看 codesandbox 的视觉效果。 通过css绘制各种形状 如何在 css 中绘制正方形、梯形、三角形、异形三角形、扇形、圆形、半圆、固定宽高比、0.5px 线? shapes 0.5px line .square { w…

    2025年12月24日
    000
  • 有哪些美观的开源数字大屏驾驶舱框架?

    开源数字大屏驾驶舱框架推荐 问题:有哪些美观的开源数字大屏驾驶舱框架? 答案: 资源包 [弗若恩智能大屏驾驶舱开发资源包](https://www.fanruan.com/resource/152) 软件 [弗若恩报表 – 数字大屏可视化组件](https://www.fanruan.c…

    2025年12月24日
    000
  • 网站底部如何实现飘彩带效果?

    网站底部飘彩带效果的 js 库实现 许多网站都会在特殊节日或活动中添加一些趣味性的视觉效果,例如点击按钮后散发的五彩缤纷的彩带。对于一个特定的网站来说,其飘彩带效果的实现方式可能有以下几个方面: 以 https://dub.sh/ 网站为例,它底部按钮点击后的彩带效果是由 javascript 库实…

    2025年12月24日
    000
  • 网站彩带效果背后是哪个JS库?

    网站彩带效果背后是哪个js库? 当你访问某些网站时,点击按钮后,屏幕上会飘出五颜六色的彩带,营造出庆祝的氛围。这些效果是通过使用javascript库实现的。 问题: 哪个javascript库能够实现网站上点击按钮散发彩带的效果? 答案: 根据给定网站的源代码分析: 可以发现,该网站使用了以下js…

    好文分享 2025年12月24日
    100
  • 产品预览卡项目

    这个项目最初是来自 Frontend Mentor 的挑战,旨在使用 HTML 和 CSS 创建响应式产品预览卡。最初的任务是设计一张具有视觉吸引力和功能性的产品卡,能够无缝适应各种屏幕尺寸。这涉及使用 CSS 媒体查询来确保布局在不同设备上保持一致且用户友好。产品卡包含产品图像、标签、标题、描述和…

    2025年12月24日
    100
  • 如何利用 echarts-gl 绘制带发光的 3D 图表?

    如何绘制带发光的 3d 图表,类似于 echarts 中的示例? 为了实现类似的 3d 图表效果,需要引入 echarts-gl 库:https://github.com/ecomfe/echarts-gl。 echarts-gl 专用于在 webgl 环境中渲染 3d 图形。它提供了各种 3d 图…

    2025年12月24日
    000
  • 如何在 Element UI 的 el-rate 组件中实现 5 颗星 5 分制与百分制之间的转换?

    如何在el-rate中将5颗星5分制的分值显示为5颗星百分制? 要实现该效果,只需使用 el-rate 组件的 allow-half 属性。在设置 allow-half 属性后,获得的结果乘以 20 即可得到0-100之间的百分制分数。如下所示: score = score * 20; 动态显示鼠标…

    2025年12月24日
    100
  • CSS 最佳实践:后端程序员重温 CSS 时常见的三个疑问?

    CSS 最佳实践:提升代码质量 作为后端程序员,在重温 CSS/HTML 时,你可能会遇到一些关于最佳实践的问题。以下将解答三个常见问题,帮助你编写更规范、清晰的 CSS 代码。 1. margin 设置策略 当相邻元素都设置了 margin 时,通常情况下应为上一个元素设置 margin-bott…

    2025年12月24日
    000

发表回复

登录后才能评论
关注微信