理解 Transformers 中的交叉熵损失与 Masked Label 问题

理解 transformers 中的交叉熵损失与 masked label 问题

本文旨在深入解析 Hugging Face Transformers 库中,针对 Decoder-Only 模型(如 GPT-2)计算交叉熵损失时,如何正确使用 labels 参数进行 Masked Label 的设置。通过具体示例和代码,详细解释了 target_ids 的构造方式,以及如何避免常见的错误,并提供了自定义计算损失的方案。

在使用 Hugging Face Transformers 库训练或评估 Decoder-Only 模型(例如 GPT-2)时,交叉熵损失是一个核心概念。labels 参数在计算损失中扮演着关键角色,尤其是在需要对部分 token 进行 Masking 的场景下。本文将深入探讨 labels 参数的使用,以及如何避免常见的错误配置。

Decoder-Only 模型中的输入与目标

在 Hugging Face 中,Decoder-Only 模型通常需要 input_ids 和 labels 作为输入。attention_mask 虽然重要,但在此处不重点讨论。核心思想是,对于 Decoder-Only 模型,输入和目标需要具有相同的形状。

例如,假设输入是 “The answer is:”,我们希望模型学习到 “42” 作为答案。那么,完整的文本序列为 “The answer is: 42″,其对应的 token IDs 可能为 [464, 3280, 318, 25, 5433] (其中 “:” 对应 25,” 42″ 对应 5433)。

为了让模型学习预测 “42”,我们需要设置 labels 为 [-100, -100, -100, -100, 5433]。这里的 -100 是 torch.nn.CrossEntropyLoss 的 ignore_index,意味着这些位置的损失将被忽略。换句话说,模型不会学习 “The answer” 后面跟着 “is:” 这样的关系,而是专注于学习在给定 “The answer is:” 的前提下,应该预测 “42”。

注意: Decoder-Only 模型要求输入和输出具有相同的形状。这与 Encoder-Decoder 模型不同,后者可以有 “The answer is:” 作为输入,而 “42” 作为输出。

常见错误与正确做法

在问题中,作者尝试使用 target_ids[:, :-seq_len] = -100 来 Masking labels,但结果并未如预期。问题在于,当 seq_len 等于输入序列的长度时,这条语句实际上没有修改任何元素。

正确的做法是,根据实际需求,有选择性地将 target_ids 中的某些位置设置为 -100。例如,在迭代处理文本数据时,可能需要忽略之前已经见过的 token,而只计算当前新 token 的损失。

以下是一个示例,展示了如何在迭代过程中正确地 Masking labels:

max_length = 1024stride = 512# 假设 tokens 是一个包含完整文本 token IDs 的列表# 第一次迭代end_loc = max_lengthinput_ids = tokens[0:end_loc]target_ids = input_ids.clone()# 第一次迭代时,不需要 Masking,因此 target_ids 与 input_ids 相同# 第二次及后续迭代begin_loc = strideend_loc = begin_loc + max_lengthinput_ids = tokens[begin_loc:end_loc]target_ids = input_ids.clone()target_ids[:max_length - stride] = -100  # Masking 之前已经见过的 token

在这个例子中,每次迭代都会处理长度为 max_length 的文本片段,但只有最后 stride 个 token 的损失会被计算,之前的 token 通过 Masking 被忽略。

自定义计算损失

如果不想依赖模型内部的损失计算方式,也可以手动计算交叉熵损失。这种方法提供了更大的灵活性,可以更好地控制损失计算的细节。

以下是一个自定义计算损失的示例代码:

from transformers import GPT2LMHeadModel, GPT2TokenizerFastimport torchfrom torch.nn import CrossEntropyLossmodel_id = "gpt2-large"model = GPT2LMHeadModel.from_pretrained(model_id)tokenizer = GPT2TokenizerFast.from_pretrained(model_id)encodings = tokenizer("She felt his demeanor was sweet and endearing.", return_tensors="pt")target_ids = encodings.input_ids.clone()outputs = model(encodings.input_ids, labels=None) # 不传入 labelslogits = outputs.logitslabels = target_ids.to(logits.device)# Shift logits 和 labels,使它们对齐shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# 计算交叉熵损失loss_fct = CrossEntropyLoss(reduction='mean')loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))print(loss.item())

在这个例子中,我们首先不将 labels 传入模型,而是获取模型的 logits 输出。然后,手动将 logits 和 labels 进行对齐(shift),并使用 CrossEntropyLoss 计算损失。reduction=’mean’ 表示计算所有 token 的平均损失。

注意事项:

shift_logits 和 shift_labels 的目的是使预测的 logits 与对应的真实 label 对齐。contiguous() 方法用于确保张量在内存中是连续存储的,这对于某些操作是必需的。可以根据需要调整 CrossEntropyLoss 的 reduction 参数,例如设置为 ‘sum’ 来计算所有 token 的损失之和。

通过理解 Decoder-Only 模型的输入和目标,以及正确使用 labels 参数进行 Masking,可以更有效地训练和评估这些模型。同时,自定义计算损失的方法提供了更大的灵活性,可以满足不同的需求。

以上就是理解 Transformers 中的交叉熵损失与 Masked Label 问题的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1374684.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 14:21:47
下一篇 2025年12月14日 14:22:10

相关推荐

  • Go语言高并发场景下newdefer引发的内存激增问题解析与优化

    在高并发Go应用中,尤其涉及大量`defer`语句和潜在`panic`恢复的场景,可能会遭遇`newdefer`导致的内存激增。本文将深入剖析`newdefer`内存泄漏的成因,结合`pprof`工具的诊断方法,并提供通过Go版本升级和优化错误处理机制来解决此类问题的专业指导。 Go应用中newde…

    2025年12月16日
    000
  • Go语言中如何使用反射遍历结构体字段

    本文介绍了如何使用Go语言的`reflect`包来遍历结构体中的字段。通过反射,我们可以动态地获取结构体的字段数量和字段值,并将其转换为`interface{}`类型进行处理。这在需要通用处理结构体字段的场景下非常有用,例如序列化、反序列化或数据验证等。 在Go语言中,有时我们需要遍历结构体的字段,…

    2025年12月16日
    000
  • Go语言中字符串与字节切片的比较及用户输入处理实践

    本文深入探讨go语言中`string`类型与`[]byte`切片的本质区别及其在用户输入处理中的影响。通过分析`bufio.readbytes`的行为,揭示了比较用户输入与字符串字面量时常见的问题根源,并提供了包含换行符处理、跨平台兼容性以及更推荐的`bufio.scanner`解决方案。旨在帮助开…

    2025年12月16日
    000
  • Golang如何使用reflect遍历map

    首先通过reflect.ValueOf获取map的反射值,再使用MapKeys遍历键并用MapIndex获取对应值,最后通过Interface方法还原为接口类型进行输出,实现对任意类型map的遍历。 在Go语言中,可以使用reflect包来遍历任意类型的map,尤其是在处理未知类型或需要泛型能力的场…

    2025年12月16日
    000
  • Golang如何定义全局变量与局部变量

    全局变量在函数外定义,作用域为整个包,如GlobalCounter;局部变量在函数内定义,仅在函数或代码块内有效,如calculate中的sum和count。 在Go语言中,全局变量和局部变量的定义主要通过变量声明的位置来区分。理解它们的作用域和生命周期对编写清晰、安全的代码非常重要。 全局变量的定…

    2025年12月16日
    000
  • Go语言中用户输入字符串与字节切片的比较及换行符处理指南

    本文深入探讨go语言中处理用户输入时,字符串(string)与字节切片([]byte)比较的常见问题。重点解释了两种数据类型的本质区别,并揭示了`bufio.newreader`读取操作中换行符(`n`或`rn`)被包含在内的陷阱。通过示例代码,提供了正确比较用户输入字符串的解决方案,并强调了跨平台…

    2025年12月16日
    000
  • Golang如何修改指针指向的值

    在Go中通过解引用指针并赋值即可修改其指向的值,如ptr=30;2. 函数中传入指针可修改外部变量,需确保指针非nil且已初始化。 在Go语言中,修改指针指向的值非常直接。你只需要使用星号 * 来解引用指针,然后赋新值即可。下面详细说明如何操作。 理解指针的基本概念 指针是一个变量,它存储另一个变量…

    2025年12月16日
    000
  • Golang如何使用gRPC实现多服务通信_Golang gRPC多服务通信实践详解

    使用Golang构建微服务时,gRPC基于HTTP/2和Protocol Buffers实现高效通信;2. 多服务间需定义清晰的proto接口并分文件管理;3. 通过protoc生成Go代码,可将多个服务注册到同一gRPC Server;4. 服务间通过gRPC客户端调用,如Order服务调用Use…

    2025年12月16日
    000
  • Go语言:高效将外部命令标准输出重定向到文件

    本文详细介绍了在go语言中如何将`exec.cmd`执行外部命令的标准输出直接重定向到一个文件。通过将目标文件句柄赋值给`cmd.stdout`字段,可以实现高效且简洁的输出捕获,避免了手动处理管道和并发的复杂性,是处理此类场景的推荐方法。 在Go语言中,执行外部命令是常见的操作,例如调用shell…

    2025年12月16日
    000
  • MySQL INSERT 语句可读性优化:利用 SET 语法提升代码清晰度

    本文探讨了在mysql中优化`insert`语句可读性的方法。针对传统`insert … values`语法在处理大量列时难以匹配值与列名的问题,推荐使用`insert … set`语法。这种方式能显著提升sql语句的清晰度,使开发者更容易理解和维护代码,尤其适用于go等语言…

    2025年12月16日
    000
  • Go语言中用户输入字符串与字节切片的比较及常见陷阱解析

    本文深入探讨go语言中`string`类型与`[]byte`切片的本质区别,并着重解析在处理用户输入时,`bufio.reader.readbytes`方法因包含换行符而导致的比较失败问题。通过详细解释类型特性和提供修正后的代码示例,文章旨在帮助开发者正确比较用户输入,并处理跨平台换行符及编码兼容性…

    2025年12月16日
    000
  • Golang如何写入CSV文件_Golang CSV文件写入实践详解

    Go语言通过encoding/csv包写入CSV文件,需用csv.NewWriter创建写入器并调用Write或WriteAll写入数据,每行以[]string格式传入,示例中先写入表头再批量写入记录,关键步骤包括创建文件、写入数据、延迟调用writer.Flush()确保缓冲区数据落盘。逐行写入适…

    2025年12月16日
    000
  • 如何在Golang中实现网络数据加密_Golang网络数据加密方法汇总

    答案:Golang中实现网络数据加密主要通过TLS、对称加密(如AES)和非对称加密(如RSA)结合的方式。1. 使用crypto/tls包配置证书可启用HTTPS加密,保护HTTP、gRPC等通信;2. 在TCP/UDP层可采用AES-GCM对数据加密,需共享密钥并使用随机IV防止重放攻击;3. …

    2025年12月16日
    000
  • Golang JSON 序列化:通过结构体标签控制字段输出与安全实践

    本教程将详细介绍在 Go 语言中如何高效且安全地将结构体数组序列化为 JSON。核心内容是利用 Go 的 `encoding/json` 包提供的结构体标签(`json:”-“`)来精确控制哪些字段应被包含或排除在最终的 JSON 输出中,尤其适用于处理敏感数据,确保数据传输…

    2025年12月16日
    000
  • Golang如何实现文件加锁与并发访问控制_Golang文件加锁并发控制实践详解

    Go语言中通过文件加锁机制解决多进程并发访问问题,使用syscall.Flock实现独占锁或共享锁,推荐采用github.com/go-flock/flock等第三方库简化跨平台操作,结合最小化锁持有时间、统一锁协议等最佳实践,确保文件读写安全与一致性。 在Go语言开发中,当多个进程或协程需要同时访…

    2025年12月16日
    000
  • Golang如何开发基础的事件管理系统

    答案:Go语言中通过观察者模式实现事件管理系统,核心为事件总线。定义Event结构体与事件类型常量,构建包含handlers映射和读写锁的EventBus,提供Subscribe注册处理器、Publish异步触发回调。示例中用户创建事件触发邮件通知,主函数演示注册与发布流程,系统支持解耦、并发与扩展…

    2025年12月16日
    000
  • 如何在Golang中处理Web服务器日志_Golang Web服务器日志处理方法汇总

    使用标准库log可实现基础日志输出,结合文件写入和中间件记录请求信息;2. 采用zap、logrus或slog进行结构化日志,提升可读性与分析效率;3. 通过中间件统一记录请求响应详情,包括状态码、耗时等;4. 利用rotatelogs或logrotate实现日志轮转,避免磁盘占满;5. 合理配置多…

    2025年12月16日
    000
  • 如何在Golang中使用gRPC进行安全认证

    Golang中gRPC安全认证通过TLS加密和认证机制实现,需配置双向证书认证并启用客户端与服务端证书校验,结合Per-RPC Credentials传递Token,使用拦截器在服务端验证authorization头,确保通信安全。 在Golang中使用gRPC进行安全认证,核心方式是通过TLS加密…

    2025年12月16日
    000
  • Go语言文件传输安全:深度解析FTP、SFTP、SCP与FTPS

    本文深入探讨了Go语言中文件传输的安全性问题,特别关注了传统FTP(如`goftp`库)的固有风险。我们将详细分析FTP明文传输的弱点,并介绍更安全的替代方案,包括基于SSH的SFTP和SCP,以及基于SSL/TLS的FTPS。文章还将提供在Go语言中实现这些安全协议的指导和示例,旨在帮助开发者构建…

    2025年12月16日
    000
  • Go语言:高效读取文本文件并按行处理的全面指南

    本教程详细介绍了在go语言中读取文本文件并将其内容按行存储到字符串切片中的两种主要方法。我们将探讨使用`ioutil.readfile`结合`strings.split`的简洁方式,以及利用`bufio.scanner`进行高效逐行处理的策略,并提供相应的代码示例和最佳实践,帮助开发者根据文件大小和…

    2025年12月16日
    000

发表回复

登录后才能评论
关注微信