PyTorch二分类模型精度计算陷阱解析与跨框架对比实践

PyTorch二分类模型精度计算陷阱解析与跨框架对比实践

本文深入探讨了PyTorch二分类模型在精度计算时可能遇到的常见陷阱,特别是当与TensorFlow的评估结果进行对比时出现的显著差异。通过分析一个具体的案例,文章揭示了PyTorch中一个易被忽视的精度计算错误,并提供了正确的实现方式,旨在帮助开发者避免此类问题,确保模型评估的准确性和一致性。

1. 问题现象:PyTorch与TensorFlow的精度差异

深度学习模型开发过程中,开发者常会遇到在不同框架下实现相似模型时,评估指标出现显著差异的情况。一个典型的二分类问题中,我们观察到以下现象:使用pytorch实现的模型在测试集上仅获得约2.5%的精度,而结构和配置几乎相同的tensorflow模型却能达到约86%的精度。这种巨大的差异通常不是由模型性能本身引起,而是暗示了其中一个框架的评估逻辑可能存在根本性错误。

2. 模型结构与训练配置概览

为了更好地理解问题,我们首先审视两个框架中模型的结构和训练配置。

2.1 PyTorch模型与训练设置

PyTorch模型是一个简单的多层感知机(MLP),包含两个ReLU激活的隐藏层和一个Sigmoid激活的输出层,适用于二分类任务。

import torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataLoader, TensorDatasetfrom sklearn.model_selection import train_test_splitimport pandas as pdimport numpy as np# 假设数据加载和预处理已完成# data = pd.read_csv('your_data.csv')# train, test = train_test_split(data, test_size=0.056, random_state=42)# train_X_np = train[["A","B","C", "D"]].to_numpy()# test_X_np = test[["A","B", "C", "D"]].to_numpy()# train_Y_np = train[["label"]].to_numpy()# test_Y_np = test[["label"]].to_numpy()# train_X = torch.tensor(train_X_np, dtype=torch.float32)# test_X = torch.tensor(test_X_np, dtype=torch.float32)# train_Y = torch.tensor(train_Y_np, dtype=torch.float32)# test_Y = torch.tensor(test_Y_np, dtype=torch.float32)# train_dataset = TensorDataset(train_X, train_Y)# batch_size = 64# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)class SimpleClassifier(nn.Module):    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):        super(SimpleClassifier, self).__init__()        self.fc1 = nn.Linear(input_size, hidden_size1)        self.relu1 = nn.ReLU()        self.fc2 = nn.Linear(hidden_size1, hidden_size2)        self.relu2 = nn.ReLU()        self.fc3 = nn.Linear(hidden_size2, output_size)        self.sigmoid = nn.Sigmoid()    def forward(self, x):        x = self.relu1(self.fc1(x))        x = self.relu2(self.fc2(x))        x = self.sigmoid(self.fc3(x))        return x# input_size = train_X.shape[1]# hidden_size1 = 64# hidden_size2 = 32# output_size = 1# model = SimpleClassifier(input_size, hidden_size1, hidden_size2, output_size)# criterion = nn.BCELoss()# optimizer = optim.Adam(model.parameters(), lr=0.001)# # 原始PyTorch训练循环中的评估部分(存在错误)# num_epochs = 50# for epoch in range(num_epochs):#     # ... (训练代码略)#     with torch.no_grad():#         model.eval()#         predictions = model(test_X).squeeze()#         predictions_binary = (predictions.round()).float()#         accuracy = torch.sum(predictions_binary == test_Y) / (len(test_Y) * 100) # 错误在此行#         if(epoch%25 == 0):#           print("Epoch " + str(epoch) + " passed. Test accuracy is {:.2f}%".format(accuracy))

PyTorch模型使用nn.BCELoss作为损失函数,optim.Adam作为优化器。问题主要出现在评估阶段的精度计算逻辑。

2.2 TensorFlow模型与训练设置

TensorFlow模型同样使用Keras的Sequential API构建了一个相似的MLP结构。

from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Dense# import numpy as np # 假设 train_X, train_Y, test_X, test_Y 已经准备好为 numpy 数组# # 假设数据加载和预处理已完成# # model_tf = Sequential()# # model_tf.add(Dense(64, input_dim=len(train_X[0]), activation='relu'))# # model_tf.add(Dense(32, activation='relu'))# # model_tf.add(Dense(1, activation='sigmoid'))# # Compile the model# # model_tf.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# # model_tf.fit(train_X, train_Y, epochs=50, batch_size=64, verbose=0)# # Evaluate the model# # loss_tf, accuracy_tf = model_tf.evaluate(test_X, test_Y, verbose=0)# # print(f"Loss: {loss_tf}, Accuracy: {accuracy_tf}")

TensorFlow模型在编译时直接指定了metrics=[‘accuracy’],这使得其在训练和评估时能够自动计算并报告正确的精度。

通过对比可以看出,两个框架的模型结构、损失函数和优化器选择都非常相似,主要的差异在于PyTorch的精度计算是手动实现,而TensorFlow则使用了内置的可靠指标。

千帆AppBuilder 千帆AppBuilder

百度推出的一站式的AI原生应用开发资源和工具平台,致力于实现人人都能开发自己的AI原生应用。

千帆AppBuilder 158 查看详情 千帆AppBuilder

3. PyTorch精度计算的症结所在

问题的核心在于PyTorch评估代码中的精度计算方式。

3.1 错误代码分析

原始PyTorch代码中的精度计算如下:

accuracy = torch.sum(predictions_binary == test_Y) / (len(test_Y) * 100)

让我们逐步分析这行代码:

predictions_binary == test_Y:这是一个布尔张量,表示每个预测是否与真实标签匹配。torch.sum(…):计算布尔张量中 True 的数量,即正确分类的样本数。len(test_Y):获取测试集中的总样本数。(len(test_Y) * 100):这是问题的关键所在。分母被错误地乘以了100。

正确的精度计算逻辑应该是:(正确分类样本数 / 总样本数) * 100%。例如,如果有86个正确预测和100个总样本,实际精度应为 (86 / 100) * 100% = 86%。然而,原始代码的计算是 (86 / (100 * 100)),即 86 / 10000 = 0.0086。如果再将其格式化为百分比,就会显示为 0.86%,或者在某些情况下,如果期望输出的是0-100的数值,则会是 0.86,与86%相去甚远。原始代码中 format(“{:.2f}%”.format(accuracy)) 会将 0.0086 格式化为 0.86%,而不是 86.00%。因此,PyTorch代码中2.5%的低精度实际上是由于计算公式中分母多乘了一个100,导致最终结果被额外缩小了100倍。

3.2 正确的精度计算方法

为了获得正确的百分比精度,我们需要修正计算公式:

# 假设 predictions_binary 是模型输出经过 Sigmoid 后,再四舍五入得到的二值预测 (0或1)# 假设 test_Y 是真实的二值标签 (0或1)# 计算正确预测的数量correct_predictions = (predictions_binary == test_Y).sum().item()# 获取总样本数total_samples = test_Y.size(0) # 或者 len(test_Y)# 计算精度(0-100

以上就是PyTorch二分类模型精度计算陷阱解析与跨框架对比实践的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/922321.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月29日 08:29:06
下一篇 2025年11月29日 08:29:25

相关推荐

  • 用了一个星期的S25 Ultra,我有这些体验想和你分享一下

    三星galaxy s25 ultra:轻薄机身与ai赋能的完美融合 “均衡的手机千篇一律,有趣的手机万里挑一。”在手机市场同质化竞争日益激烈的今天,这句话或许道出了许多消费者的内心呼声。然而,三星Galaxy S系列却始终凭借其均衡的配置和体验,成为市场上的佼佼者。而全新发布的三星Galaxy S2…

    2025年12月6日 硬件教程
    000
  • 荣耀开始安排 6.3-6.5 英寸中小尺寸机型?两款新机曝光

    荣耀将推出中小尺寸屏幕新机型!据数码闲聊站爆料,荣耀计划发布两款中端机型,分别采用6.5英寸左右1.5k直屏和6.78英寸左右1.5k等深四曲屏,均配备7000毫安时以上大电池,并搭载骁龙7 gen 4处理器(sm7750),预计上半年发布。 爆料显示,荣耀正在积极布局中小尺寸手机市场,目前已启动6…

    2025年12月6日 硬件教程
    000
  • 怎样用免费工具美化PPT_免费美化PPT的实用方法分享

    利用KIMI智能助手可免费将PPT美化为科技感风格,但需核对文字准确性;2. 天工AI擅长优化内容结构,提升逻辑性,适合高质量内容需求;3. SlidesAI支持语音输入与自动排版,操作便捷,利于紧急场景;4. Prezo提供多种模板,自动生成图文并茂幻灯片,适合学生与初创团队。 如果您有一份内容完…

    2025年12月6日 软件教程
    000
  • Pages怎么协作编辑同一文档 Pages多人实时协作的流程

    首先启用Pages共享功能,点击右上角共享按钮并选择“添加协作者”,设置为可编辑并生成链接;接着复制链接通过邮件或社交软件发送给成员,确保其使用Apple ID登录iCloud后即可加入编辑;也可直接在共享菜单中输入邮箱地址定向邀请,设定编辑权限后发送;最后在共享面板中管理协作者权限,查看实时在线状…

    2025年12月6日 软件教程
    000
  • REDMI K90系列正式发布,售价2599元起!

    10月23日,redmi k90系列正式亮相,推出redmi k90与redmi k90 pro max两款新机。其中,redmi k90搭载骁龙8至尊版处理器、7100mah大电池及100w有线快充等多项旗舰配置,起售价为2599元,官方称其为k系列迄今为止最完整的标准版本。 图源:REDMI红米…

    2025年12月6日 行业动态
    000
  • Linux中如何安装Nginx服务_Linux安装Nginx服务的完整指南

    首先更新系统软件包,然后通过对应包管理器安装Nginx,启动并启用服务,开放防火墙端口,最后验证欢迎页显示以确认安装成功。 在Linux系统中安装Nginx服务是搭建Web服务器的第一步。Nginx以高性能、低资源消耗和良好的并发处理能力著称,广泛用于静态内容服务、反向代理和负载均衡。以下是在主流L…

    2025年12月6日 运维
    000
  • Linux journalctl与systemctl status结合分析

    先看 systemctl status 确认服务状态,再用 journalctl 查看详细日志。例如 nginx 启动失败时,systemctl status 显示 Active: failed,journalctl -u nginx 发现端口 80 被占用,结合两者可快速定位问题根源。 在 Lin…

    2025年12月6日 运维
    000
  • 华为新机发布计划曝光:Pura 90系列或明年4月登场

    近日,有数码博主透露了华为2025年至2026年的新品规划,其中pura 90系列预计在2026年4月发布,有望成为华为新一代影像旗舰。根据路线图,华为将在2025年底至2026年陆续推出mate 80系列、折叠屏新机mate x7系列以及nova 15系列,而pura 90系列则将成为2026年上…

    2025年12月6日 行业动态
    000
  • Linux如何优化系统性能_Linux系统性能优化的实用方法

    优化Linux性能需先监控资源使用,通过top、vmstat等命令分析负载,再调整内核参数如TCP优化与内存交换,结合关闭无用服务、选用合适文件系统与I/O调度器,持续按需调优以提升系统效率。 Linux系统性能优化的核心在于合理配置资源、监控系统状态并及时调整瓶颈环节。通过一系列实用手段,可以显著…

    2025年12月6日 运维
    000
  • 首款鸿蒙电脑惊艳亮相,华为重构电脑产业新格局

    华为鸿蒙电脑技术与生态沟通会隆重举行,首款鸿蒙电脑惊艳登场,这一标志性事件预示着华为在电脑领域迈出了具有深远影响的关键一步,为国产电脑产业带来了全新的革新与发展契机。 鸿蒙电脑的推出并非一朝一夕之功,而是华为经过五年精心策划的结果。在此期间,华为汇聚了超过10000名顶尖工程师,与20多家专业研究所…

    2025年12月6日 硬件教程
    000
  • 曝小米17 Air正在筹备 超薄机身+2亿像素+eSIM技术?

    近日,手机行业再度掀起超薄机型热潮,三星与苹果已相继推出s25 edge与iphone air等轻薄旗舰,引发市场高度关注。在此趋势下,多家国产厂商被曝正积极布局相关技术,加速抢占这一细分赛道。据业内人士消息,小米的超薄旗舰机型小米17 air已进入筹备阶段。 小米17 Pro 爆料显示,小米正在评…

    2025年12月6日 行业动态
    000
  • 「世纪传奇刀片新篇」飞利浦影音双11声宴开启

    百年声学基因碰撞前沿科技,一场有关声音美学与设计美学的影音狂欢已悄然引爆2025“双十一”! 当绝大多数影音数码品牌还在价格战中挣扎时,飞利浦影音已然开启了一场跨越百年的“声”活革命。作为拥有深厚技术底蕴的音频巨头,飞利浦影音及配件此次“双十一”精准聚焦“传承经典”与“设计美学”两大核心,为热爱生活…

    2025年12月6日 行业动态
    000
  • 荣耀手表5Pro 10月23日正式开启首销国补优惠价1359.2元起售

    荣耀手表5pro自9月25日开启全渠道预售以来,市场热度持续攀升,上市初期便迎来抢购热潮,一度出现全线售罄、供不应求的局面。10月23日,荣耀手表5pro正式迎来首销,提供蓝牙版与esim版两种选择。其中,蓝牙版本的攀登者(橙色)、开拓者(黑色)和远航者(灰色)首销期间享受国补优惠价,到手价为135…

    2025年12月6日 行业动态
    000
  • 软硬一体、AI牵引斑马智行推动国产心片释放算力效能

    堆砌了硬件的智能座舱,为何仍难逃“卡顿、无聊”的用户诟病?在刚刚落幕的2025年中国工程学会年会上,行业达成共识:芯片算力只是燃料,真正决定汽车智能化上限的,是基础软件与ai大模型。 多位专家在会上指出,软件定义汽车已迈入“云端一体大模型”新阶段。以AI为核心的软件能力正成为提升用户体验的关键驱动力…

    2025年12月6日 行业动态
    000
  • 环境搭建docker环境下如何快速部署mysql集群

    使用Docker Compose部署MySQL主从集群,通过配置文件设置server-id和binlog,编写docker-compose.yml定义主从服务并组网,启动后创建复制用户并配置主从连接,最后验证数据同步是否正常。 在Docker环境下快速部署MySQL集群,关键在于合理使用Docker…

    2025年12月6日 数据库
    000
  • 解决MongoDB连接错误:正确使用MongoClient进行数据库连接

    本教程旨在解决初次使用mongodb时常见的“mongodb.connect is not a function”错误。我们将详细介绍如何使用mongodb官方驱动中的`mongoclient`类建立稳定的数据库连接,并结合express.js框架,采用现代化的`async/await`语法实现高效…

    2025年12月6日 web前端
    000
  • 李楠谈iPhone Air:如果是乔布斯的话 估计早就做出来了

    10月25日消息,怒喵科技创始人李楠称,iphone air的续航表现与iphone 17相当,他感慨道:“如果是乔布斯在位,这台设备或许早就问世了。如果能提前几年推出,市场反响可能会更加热烈。” 他还评价说,iPhone Air是近十年来最出色的iPhone产品。无论是在材质选择、工艺精度、整体设…

    2025年12月6日 行业动态
    000
  • Xbox删忍龙美女角色 斯宾塞致敬板垣伴信被喷太虚伪

    近日,海外游戏推主@HaileyEira公开发表言论,批评Xbox负责人菲尔·斯宾塞不配向已故的《死或生》与《忍者龙剑传》系列之父板垣伴信致敬。她指出,Xbox并未真正尊重这位传奇制作人的创作遗产,反而在宣传相关作品时对内容进行了审查和删减。 所涉游戏为年初推出的《忍者龙剑传2:黑之章》,该作采用虚…

    2025年12月6日 游戏教程
    000
  • 如何在mysql中分析索引未命中问题

    答案是通过EXPLAIN分析执行计划,检查索引使用情况,优化WHERE条件写法,避免索引失效,结合慢查询日志定位问题SQL,并根据查询模式合理设计索引。 当 MySQL 查询性能下降,很可能是索引未命中导致的。要分析这类问题,核心是理解查询执行计划、检查索引设计是否合理,并结合实际数据访问模式进行优…

    2025年12月6日 数据库
    000
  • VSCode入门:基础配置与插件推荐

    刚用VSCode,别急着装一堆东西。先把基础设好,再按需求加插件,效率高还不卡。核心就三步:界面顺手、主题舒服、功能够用。 设置中文和常用界面 打开软件,左边活动栏有五个图标,点最下面那个“扩展”。搜索“Chinese”,装上官方出的“Chinese (Simplified) Language Pa…

    2025年12月6日 开发工具
    000

发表回复

登录后才能评论
关注微信