PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

本文深入探讨了pytorch中`crossentropyloss`常见的`runtimeerror: expected scalar type long but found float`错误。该错误通常源于目标标签(target)的数据类型不符合`crossentropyloss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`crossentropyloss`,包括标签类型转换、输入顺序以及避免重复应用softmax等关键最佳实践,以确保模型训练的稳定性和准确性。

深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。

理解CrossEntropyLoss的工作原理

CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:

input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。

关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。

RuntimeError: expected scalar type Long but found Float 错误解析与修正

这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:

loss = criterion(output, labels.float())

尽管labels张量在创建时已经被明确指定为long类型:

labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())

但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。

修正方法:只需移除对labels的.float()调用,确保target张量保持其long类型即可。

# 错误代码# loss = criterion(output, labels.float())# 正确代码loss = criterion(output, labels)

训练循环中的常见误用及修正

除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。

1. 标签数据类型转换错误

在train_one_epoch函数内部,标签被错误地转换成了float类型:

labels = labels.to(device).float() # 错误:将标签转换为float类型

这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。

修正方法:确保标签在送入损失函数前是long类型。

labels = labels.to(device).long() # 正确:将标签转换为long类型

2. CrossEntropyLoss输入参数顺序和类型错误

在train_one_epoch函数中,计算损失的行是:

Spacely AI Spacely AI

为您的房间提供AI室内设计解决方案,寻找无限的创意

Spacely AI 67 查看详情 Spacely AI

loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符

这里存在两个问题:

参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。

修正方法:将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。

3. 预先应用Softmax的错误

在计算outputs时,代码中显式地应用了F.softmax:

outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax

由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:

冗余计算: 增加了不必要的计算开销。数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。

修正方法:直接将模型的原始输出(logits)传递给CrossEntropyLoss。

优化后的训练函数示例

综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:

import torchimport torch.nn as nnimport torch.nn.functional as Fimport time# 假设 model, optimizer, dataloaders, device 已经定义def train_one_epoch(model, optimizer, data_loader, device):    model.train()    running_loss = 0.0    start_time = time.time()    total = 0    correct = 0    # 确保 data_loader 是实际的 DataLoader 对象    # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader    current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取    if isinstance(data_loader, str):        current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典    for i, (inputs, labels) in enumerate(current_data_loader):        inputs = inputs.to(device)        # 核心修正:确保标签是long类型        labels = labels.to(device).long()         optimizer.zero_grad()        # 修正:直接使用模型的原始输出(logits),不应用Softmax        # 假设 model(inputs.float()) 返回的是 logits        logits = model(inputs.float())         # 打印形状以调试        # print("Inputs shape:", inputs.shape)        # print("Logits shape:", logits.shape)        # print("Labels shape:", labels.shape)        # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)        loss = criterion(logits, labels)         loss.backward()        optimizer.step()        # 计算准确率时,需要对logits应用argmax        _, predicted = torch.max(logits.data, 1)        total += labels.size(0)        correct += (predicted == labels).sum().item()        accuracy = 100 * correct / total        running_loss += loss.item()        if i % 10 == 0:    # print every 10 batches            batch_time = time.time()            speed = (i+1)/(batch_time-start_time)            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %                  (i, running_loss, speed, accuracy))            running_loss = 0.0            total = 0            correct = 0

验证模型函数 (val_model) 的注意事项

val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。

唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。

def val_model(model, data_loader, device): # 添加 device 参数    model.eval() # 修正:使用 model.eval()    start_time = time.time()    total = 0    correct = 0    current_data_loader = data_loader    if isinstance(data_loader, str):        current_data_loader = dataloaders[data_loader]    with torch.no_grad():        for i, (inputs, labels) in enumerate(current_data_loader):            inputs = inputs.to(device)            labels = labels.to(device).long() # 正确            outputs = model(inputs.float()) # 假设 model 输出 logits            _, predicted = torch.max(outputs.data, 1)            total += labels.size(0)            # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可            correct += (predicted == labels).sum().item()         accuracy = 100 * correct / total        print('Finished Testing')        print('Testing accuracy: %.1f %%' %(accuracy))

总结与最佳实践

处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:

目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。

遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。

以上就是PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/847285.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月27日 15:56:23
下一篇 2025年11月27日 15:57:05

相关推荐

  • 怎样用免费工具美化PPT_免费美化PPT的实用方法分享

    利用KIMI智能助手可免费将PPT美化为科技感风格,但需核对文字准确性;2. 天工AI擅长优化内容结构,提升逻辑性,适合高质量内容需求;3. SlidesAI支持语音输入与自动排版,操作便捷,利于紧急场景;4. Prezo提供多种模板,自动生成图文并茂幻灯片,适合学生与初创团队。 如果您有一份内容完…

    2025年12月6日 软件教程
    000
  • Pages怎么协作编辑同一文档 Pages多人实时协作的流程

    首先启用Pages共享功能,点击右上角共享按钮并选择“添加协作者”,设置为可编辑并生成链接;接着复制链接通过邮件或社交软件发送给成员,确保其使用Apple ID登录iCloud后即可加入编辑;也可直接在共享菜单中输入邮箱地址定向邀请,设定编辑权限后发送;最后在共享面板中管理协作者权限,查看实时在线状…

    2025年12月6日 软件教程
    100
  • REDMI K90系列正式发布,售价2599元起!

    10月23日,redmi k90系列正式亮相,推出redmi k90与redmi k90 pro max两款新机。其中,redmi k90搭载骁龙8至尊版处理器、7100mah大电池及100w有线快充等多项旗舰配置,起售价为2599元,官方称其为k系列迄今为止最完整的标准版本。 图源:REDMI红米…

    2025年12月6日 行业动态
    200
  • Linux中如何安装Nginx服务_Linux安装Nginx服务的完整指南

    首先更新系统软件包,然后通过对应包管理器安装Nginx,启动并启用服务,开放防火墙端口,最后验证欢迎页显示以确认安装成功。 在Linux系统中安装Nginx服务是搭建Web服务器的第一步。Nginx以高性能、低资源消耗和良好的并发处理能力著称,广泛用于静态内容服务、反向代理和负载均衡。以下是在主流L…

    2025年12月6日 运维
    000
  • Linux journalctl与systemctl status结合分析

    先看 systemctl status 确认服务状态,再用 journalctl 查看详细日志。例如 nginx 启动失败时,systemctl status 显示 Active: failed,journalctl -u nginx 发现端口 80 被占用,结合两者可快速定位问题根源。 在 Lin…

    2025年12月6日 运维
    100
  • 华为新机发布计划曝光:Pura 90系列或明年4月登场

    近日,有数码博主透露了华为2025年至2026年的新品规划,其中pura 90系列预计在2026年4月发布,有望成为华为新一代影像旗舰。根据路线图,华为将在2025年底至2026年陆续推出mate 80系列、折叠屏新机mate x7系列以及nova 15系列,而pura 90系列则将成为2026年上…

    2025年12月6日 行业动态
    100
  • TikTok视频无法下载怎么办 TikTok视频下载异常修复方法

    先检查链接格式、网络设置及工具版本。复制以https://www.tiktok.com/@或vm.tiktok.com开头的链接,删除?后参数,尝试短链接;确保网络畅通,可切换地区节点或关闭防火墙;更新工具至最新版,优先选用yt-dlp等持续维护的工具。 遇到TikTok视频下载不了的情况,别急着换…

    2025年12月6日 软件教程
    100
  • Linux如何优化系统性能_Linux系统性能优化的实用方法

    优化Linux性能需先监控资源使用,通过top、vmstat等命令分析负载,再调整内核参数如TCP优化与内存交换,结合关闭无用服务、选用合适文件系统与I/O调度器,持续按需调优以提升系统效率。 Linux系统性能优化的核心在于合理配置资源、监控系统状态并及时调整瓶颈环节。通过一系列实用手段,可以显著…

    2025年12月6日 运维
    000
  • 曝小米17 Air正在筹备 超薄机身+2亿像素+eSIM技术?

    近日,手机行业再度掀起超薄机型热潮,三星与苹果已相继推出s25 edge与iphone air等轻薄旗舰,引发市场高度关注。在此趋势下,多家国产厂商被曝正积极布局相关技术,加速抢占这一细分赛道。据业内人士消息,小米的超薄旗舰机型小米17 air已进入筹备阶段。 小米17 Pro 爆料显示,小米正在评…

    2025年12月6日 行业动态
    000
  • 「世纪传奇刀片新篇」飞利浦影音双11声宴开启

    百年声学基因碰撞前沿科技,一场有关声音美学与设计美学的影音狂欢已悄然引爆2025“双十一”! 当绝大多数影音数码品牌还在价格战中挣扎时,飞利浦影音已然开启了一场跨越百年的“声”活革命。作为拥有深厚技术底蕴的音频巨头,飞利浦影音及配件此次“双十一”精准聚焦“传承经典”与“设计美学”两大核心,为热爱生活…

    2025年12月6日 行业动态
    000
  • 荣耀手表5Pro 10月23日正式开启首销国补优惠价1359.2元起售

    荣耀手表5pro自9月25日开启全渠道预售以来,市场热度持续攀升,上市初期便迎来抢购热潮,一度出现全线售罄、供不应求的局面。10月23日,荣耀手表5pro正式迎来首销,提供蓝牙版与esim版两种选择。其中,蓝牙版本的攀登者(橙色)、开拓者(黑色)和远航者(灰色)首销期间享受国补优惠价,到手价为135…

    2025年12月6日 行业动态
    000
  • Vue.js应用中配置环境变量:灵活管理后端通信地址

    在%ignore_a_1%应用中,灵活配置后端api地址等参数是开发与部署的关键。本文将详细介绍两种主要的环境变量配置方法:推荐使用的`.env`文件,以及通过`cross-env`库在命令行中设置环境变量。通过这些方法,开发者可以轻松实现开发、测试、生产等不同环境下配置的动态切换,提高应用的可维护…

    2025年12月6日 web前端
    000
  • JavaScript动态生成日历式水平日期布局的优化实践

    本教程将指导如何使用javascript高效、正确地动态生成html表格中的日历式水平日期布局。重点解决直接操作`innerhtml`时遇到的标签闭合问题,通过数组构建html字符串来避免浏览器解析错误,并利用事件委托机制优化动态生成元素的事件处理,确保生成结构清晰、功能完善的日期展示。 在前端开发…

    2025年12月6日 web前端
    000
  • 环境搭建docker环境下如何快速部署mysql集群

    使用Docker Compose部署MySQL主从集群,通过配置文件设置server-id和binlog,编写docker-compose.yml定义主从服务并组网,启动后创建复制用户并配置主从连接,最后验证数据同步是否正常。 在Docker环境下快速部署MySQL集群,关键在于合理使用Docker…

    2025年12月6日 数据库
    000
  • Xbox删忍龙美女角色 斯宾塞致敬板垣伴信被喷太虚伪

    近日,海外游戏推主@HaileyEira公开发表言论,批评Xbox负责人菲尔·斯宾塞不配向已故的《死或生》与《忍者龙剑传》系列之父板垣伴信致敬。她指出,Xbox并未真正尊重这位传奇制作人的创作遗产,反而在宣传相关作品时对内容进行了审查和删减。 所涉游戏为年初推出的《忍者龙剑传2:黑之章》,该作采用虚…

    2025年12月6日 游戏教程
    000
  • 如何在mysql中分析索引未命中问题

    答案是通过EXPLAIN分析执行计划,检查索引使用情况,优化WHERE条件写法,避免索引失效,结合慢查询日志定位问题SQL,并根据查询模式合理设计索引。 当 MySQL 查询性能下降,很可能是索引未命中导致的。要分析这类问题,核心是理解查询执行计划、检查索引设计是否合理,并结合实际数据访问模式进行优…

    2025年12月6日 数据库
    000
  • VSCode入门:基础配置与插件推荐

    刚用VSCode,别急着装一堆东西。先把基础设好,再按需求加插件,效率高还不卡。核心就三步:界面顺手、主题舒服、功能够用。 设置中文和常用界面 打开软件,左边活动栏有五个图标,点最下面那个“扩展”。搜索“Chinese”,装上官方出的“Chinese (Simplified) Language Pa…

    2025年12月6日 开发工具
    000
  • VSCode性能分析与瓶颈诊断技术

    首先通过资源监控定位异常进程,再利用开发者工具分析性能瓶颈,结合禁用扩展、优化语言服务器配置及项目设置,可有效解决VSCode卡顿问题。 VSCode作为主流的代码编辑器,虽然轻量高效,但在处理大型项目或配置复杂扩展时可能出现卡顿、响应延迟等问题。要解决这些性能问题,需要系统性地进行性能分析与瓶颈诊…

    2025年12月6日 开发工具
    000
  • php查询代码怎么写_php数据库查询语句编写技巧与实例

    在PHP中进行数据库查询,最常用的方式是使用MySQLi或PDO扩展连接MySQL数据库。下面介绍基本的查询代码写法、编写技巧以及实用示例,帮助你高效安全地操作数据库。 1. 使用MySQLi进行查询(面向对象方式) 这是较为推荐的方式,适合大多数中小型项目。 // 创建连接$host = ‘loc…

    2025年12月6日 后端开发
    000
  • php数据库如何实现数据缓存 php数据库减少查询压力的方案

    答案:PHP结合Redis等内存缓存系统可显著提升Web应用性能。通过将用户信息、热门数据等写入内存缓存并设置TTL,先查缓存未命中再查数据库,减少数据库压力;配合OPcache提升脚本执行效率,文件缓存适用于小型项目,数据库缓冲池优化和读写分离进一步提升性能,推荐Redis为主并防范缓存穿透与雪崩…

    2025年12月6日 后端开发
    000

发表回复

登录后才能评论
关注微信