PyTorch二分类模型准确率计算陷阱与修正:对比TensorFlow实践

PyTorch二分类模型准确率计算陷阱与修正:对比TensorFlow实践

本文旨在解决PyTorch二分类模型训练过程中,准确率计算可能出现的常见错误,导致结果远低于预期。通过对比TensorFlow的实现,我们将深入分析PyTorch代码中准确率计算的陷阱,并提供正确的计算公式与实践方法,确保模型性能评估的准确性。

1. 问题背景与现象分析

深度学习二分类任务中,模型性能通常通过准确率(accuracy)来衡量。然而,开发者在使用不同深度学习框架(如pytorch和tensorflow)实现相同模型时,可能会遇到准确率计算结果显著不同的情况。一个常见的问题是,pytorch代码计算出的准确率远低于预期,而tensorflow则表现正常。这往往不是模型本身的差异,而是准确率计算逻辑上的细微错误。

例如,在以下PyTorch二分类模型评估代码中,可能会出现准确率仅为2.5%的异常情况:

# 原始PyTorch准确率计算片段# ...with torch.no_grad():    model.eval()    predictions = model(test_X).squeeze() # 模型输出经过Sigmoid,范围在0-1之间    predictions_binary = (predictions.round()).float() # 四舍五入到0或1    accuracy = torch.sum(predictions_binary == test_Y) / (len(test_Y) * 100) # 错误的计算方式    if(epoch%25 == 0):      print("Epoch " + str(epoch) + " passed. Test accuracy is {:.2f}%".format(accuracy))# ...

而使用等效的TensorFlow代码,通常能得到合理的准确率(例如86%):

# TensorFlow模型训练与评估片段# ...model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])model.fit(train_X, train_Y, epochs=50, batch_size=64)loss, accuracy = model.evaluate(test_X, test_Y)print(f"Loss: {loss}, Accuracy: {accuracy}")# ...

这种差异的核心原因在于PyTorch代码中准确率计算公式的误用。

2. PyTorch准确率计算错误剖析

上述PyTorch代码中的准确率计算错误主要体现在以下一行:

accuracy = torch.sum(predictions_binary == test_Y) / (len(test_Y) * 100)

具体分析如下:

除法顺序错误

为了得到百分比形式的准确率,正确的计算流程应该是:(正确预测数 / 总样本数) * 100。然而,原始代码中的 /(len(test_Y) * 100) 实际上是将正确预测数除以 (总样本数 * 100),这导致结果被额外除以了100,从而使得准确率数值变得非常小(例如,86%的准确率会变成0.86%)。

torch.sum() 返回张量

torch.sum(predictions_binary == test_Y) 返回的是一个包含正确预测数量的张量(tensor),而不是一个标量(scalar)。虽然PyTorch在某些情况下可以自动进行类型转换,但为了代码的健壮性和清晰性,通常建议使用 .item() 方法将其转换为Python数值类型,尤其是在进行标量运算时。

3. PyTorch中二分类准确率的正确计算方法

要修正PyTorch中的准确率计算,我们需要调整公式以确保正确的百分比转换,并处理好张量到标量的转换。

修正后的准确率计算代码:

# 修正后的PyTorch准确率计算片段# ...with torch.no_grad():    model.eval()    # 确保模型输出和标签形状一致,这里假设test_Y是(N, 1)或(N,)    # 如果model(test_X)输出是(N, 1),则不需要.squeeze()    # 如果model(test_X)输出是(N, 1)且test_Y是(N,),则需要.squeeze()其中一个    # 这里我们假设test_Y是(N, 1),模型输出也是(N, 1),因此不使用.squeeze()    predictions = model(test_X) # 保持(N, 1)形状    predictions_binary = (predictions.round()).float() # 四舍五入到0或1,保持(N, 1)形状    # 计算正确预测的数量    correct_predictions = torch.sum(predictions_binary == test_Y).item()    # 获取总样本数    total_samples = test_Y.size(0) # 等同于 len(test_Y)    # 计算准确率百分比    accuracy = (correct_predictions / total_samples) * 100    if(epoch%25 == 0):      print("Epoch " + str(epoch) + " passed. Test accuracy is {:.2f}%".format(accuracy))# ...

关键修正点:

torch.sum(…).item():将布尔张量的求和结果(正确预测数)转换为Python标量。/ total_samples:计算正确预测的比例。* 100:将比例转换为百分比。

4. 完整的PyTorch二分类模型训练与评估示例

以下是一个集成了正确准确率计算的完整PyTorch二分类模型训练与评估示例:

import torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataLoader, TensorDatasetfrom sklearn.model_selection import train_test_splitimport pandas as pdimport numpy as np# 1. 数据准备 (模拟数据)# 假设你的数据加载和预处理如下:# data = pd.read_csv('your_data.csv')# data['label'] = (data['some_feature'] > threshold).astype(int) # 示例标签生成# ...# 这里使用模拟数据以确保代码可运行np.random.seed(42)num_samples = 1000data = pd.DataFrame({    'A': np.random.rand(num_samples),    'B': np.random.rand(num_samples),    'C': np.random.rand(num_samples),    'D': np.random.rand(num_samples),    'label': np.random.randint(0, 2, num_samples)})train, test = train_test_split(data, test_size=0.2, random_state=42) # 调整test_sizetrain_X = train[["A","B","C", "D"]].to_numpy()test_X = test[["A","B", "C", "D"]].to_numpy()train_Y = train[["label"]].to_numpy()test_Y = test[["label"]].to_numpy()train_X = torch.tensor(train_X, dtype=torch.float32)test_X = torch.tensor(test_X, dtype=torch.float32)train_Y = torch.tensor(train_Y, dtype=torch.float32) # 保持(N, 1)形状test_Y = torch.tensor(test_Y, dtype=torch.float32)   # 保持(N, 1)形状batch_size = 64train_dataset = TensorDataset(train_X, train_Y)# test_dataset = TensorDataset(test_X, test_Y) # 评估时通常直接使用test_X, test_Ytrain_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 如果需要批量评估,也可以使用# 2. 模型定义class SimpleClassifier(nn.Module):    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):        super(SimpleClassifier, self).__init__()        self.fc1 = nn.Linear(input_size, hidden_size1)        self.relu1 = nn.ReLU()        self.fc2 = nn.Linear(hidden_size1, hidden_size2)        self.relu2 = nn.ReLU()        self.fc3 = nn.Linear(hidden_size2, output_size)        self.sigmoid = nn.Sigmoid()    def forward(self, x):        x = self.relu1(self.fc1(x))        x = self.relu2(self.fc2(x))        x = self.sigmoid(self.fc3(x)) # 输出范围0-1        return xinput_size = train_X.shape[1]hidden_size1 = 64hidden_size2 = 32output_size = 1 # 二分类输出model = SimpleClassifier(input_size,

以上就是PyTorch二分类模型准确率计算陷阱与修正:对比TensorFlow实践的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1375925.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 15:27:43
下一篇 2025年12月14日 15:27:54

相关推荐

  • 优化 QLoRA 训练:解决大 Batch Size 导致训练时间过长的问题

    本文将深入探讨在使用 QLoRA(Quantization-aware Low-Rank Adaptation)微调 openlm-research/open_llama_7b_v2 模型时,增大 per_device_train_batch_size 导致训练时间显著增加的问题。我们将分析可能的原…

    好文分享 2025年12月14日
    000
  • python静态方法的用法

    静态方法是通过@staticmethod装饰器定义的、不依赖实例或类状态的工具函数,适合用于逻辑相关但无需访问属性的场景,如数据验证、数学计算等。 静态方法在 Python 中是一种特殊的方法类型,它不属于实例也不属于类,而是作为一个独立的函数被定义在类的内部。它的主要作用是将逻辑上相关的函数组织到…

    2025年12月14日
    000
  • 使用 cppyy 调用 C++ 库时 destroyModel 函数参数传递错误

    在使用 cppyy 调用 C++ 库时,遇到 TypeError: could not convert argument 1 错误,通常是因为 C++ 函数的参数类型与 Python 传递的参数类型不匹配。特别是当 C++ 函数的参数类型是引用时,cppyy 的默认行为可能无法正确处理。 问题描述 …

    2025年12月14日
    000
  • Python对象序列化:将类与实例属性递归转换为嵌套字典

    本文探讨了如何将Python类及其嵌套实例的类属性和实例属性递归地转换为一个结构化的字典。针对Python内置__dict__无法捕获类属性和嵌套对象深层属性的问题,我们提出并实现了一个Serializable基类,通过自定义的to_dict()方法,有效解决了对象及其复杂属性结构的序列化难题,最终…

    2025年12月14日
    000
  • python中Sobel算子是什么

    Sobel算子通过3×3卷积核计算图像梯度实现边缘检测,使用Gx和Gy分量结合幅值与方向判断边缘,具有抗噪性强、定位准确的优点,常用作图像处理预处理步骤。 Sobel算子是图像处理和计算机视觉中常用的一种边缘检测算子,主要用于检测图像中的梯度变化,从而识别出图像的边缘。它通过计算图像在水平和垂直方向…

    2025年12月14日
    000
  • python负值如何使用?

    负值在Python中用于数值计算和反向索引。-5+3得-2,-1表示最后一个元素,如text[-1]输出o,lst[-3]取20;切片nums[-3:]得[3,4,5],[::-1]可反转列表;注意索引越界会报错。 Python中的负值使用非常直接,主要用于数值计算、索引操作和控制流程等场景。负值就…

    2025年12月14日
    000
  • python3.5如何安装

    答案:Python 3.5 可在 Windows、macOS 和 Linux 上安装。Windows 用户从官网下载安装包并勾选添加到 PATH;macOS 建议使用官方安装包或 Homebrew 安装;Linux(Ubuntu)可通过 deadsnakes PPA 安装。安装后通过 python3…

    2025年12月14日
    000
  • Python装饰器的应用场景

    装饰器通过封装横切逻辑提升代码复用性,如@login_required实现权限校验,@log_calls记录函数调用,@timing统计执行耗时,@lru_cache缓存结果,实现认证、日志、性能优化等功能。 Python装饰器是一种强大的语言特性,它允许你在不修改原函数代码的前提下,为函数添加额外…

    2025年12月14日
    000
  • python单元测试中的函数整理

    Python单元测试核心函数来自unittest模块,包括断言方法如assertEqual、assertTrue;setUp和tearDown用于测试前后环境准备与清理;@skip等装饰器支持条件跳过;unittest.mock提供Mock、patch实现依赖模拟;通过unittest.main()…

    2025年12月14日
    000
  • 基于OpenCV的视频帧拼接防抖技术教程

    基于OpenCV的视频帧拼接防抖技术教程 本文旨在解决使用OpenCV进行多摄像头视频帧拼接时出现的抖动问题。通过继承Stitcher类并重写initialize_stitcher()和stitch()方法,实现仅在第一帧进行相机标定,后续帧沿用标定结果,从而避免因每帧独立标定导致的画面扭曲和抖动。…

    2025年12月14日
    000
  • python实例方法的使用注意

    实例方法必须定义在类中并接收self参数,通过实例调用以操作对象状态,避免误用为静态函数。 在Python中,实例方法是最常见的方法类型,它依赖于类的实例来调用和操作数据。正确使用实例方法不仅能提升代码可读性,还能避免常见错误。以下是使用实例方法时需要注意的关键点。 必须定义在类中并接收self参数…

    2025年12月14日
    000
  • python赋值运算符是什么

    Python赋值运算符用于将值赋予变量,基础赋值运算符为=,如a=10、b=a+5;复合赋值运算符结合算术或位运算与赋值,如+=、-=、=、/=、%=、*=、//=,以及位运算赋值&=、|=、^=、>>=等,使代码更简洁。 Python赋值运算符用于将值赋予变量。最基础的赋值运算…

    2025年12月14日
    000
  • 利用部分字符串在列表中查找完整值

    本文介绍如何在Python列表中,通过提供部分字符串来查找包含该字符串的完整元素。通过遍历列表中的元素,并使用字符串的in操作符进行匹配,可以高效地找到目标值。本文提供了一个可复用的函数示例,并讨论了其适用场景和潜在的优化方向。 在处理从HTML页面解析或其他数据源获取的列表数据时,经常会遇到需要根…

    2025年12月14日
    000
  • 将类和实例属性转换为嵌套字典的 Python 教程

    本文介绍如何将 Python 类及其实例的属性,包括嵌套的类和实例属性,转换为一个字典。通过自定义 Serializable 类和 to_dict() 方法,可以方便地将类和实例的属性以嵌套字典的形式进行展示。同时,本文也讨论了该方法的一些局限性,例如处理循环引用和非序列化对象的情况。 实现 Ser…

    2025年12月14日
    000
  • 标题:Python Turtle 教程:理解条件判断中的逻辑错误

    本教程旨在帮助读者理解 Python 中条件判断语句的逻辑运算,并通过 Turtle 模块的示例,深入剖析 or 运算符在条件判断中可能出现的陷阱。我们将分析一个 Turtle 随机移动并改变方向的场景,重点讲解如何正确地使用 or 运算符来判断 Turtle 是否超出边界,并提供修改后的代码示例,…

    2025年12月14日
    000
  • Python AWS Lambda 函数请求超时及连接重置问题排查与解决

    第一段引用上面的摘要:本文旨在解决 AWS Lambda 函数中使用 Python requests.get() 方法时遇到的超时和连接重置问题。通过分析网络配置,特别是 Lambda 函数的 VPC 设置,解释了为何会出现这些问题,并提供了两种解决方案:配置 NAT 网关以允许 Lambda 函数…

    2025年12月14日
    000
  • 解决dput上传Debian包时SSL证书验证失败问题:自签名证书的临时方案

    本教程针对使用dput向GitLab上传Debian包时,因自签名SSL证书导致的“SSL: CERTIFICATE_VERIFY_FAILED”错误,提供了一个直接修改dput脚本以临时禁用SSL验证的解决方案。此方法适用于受控环境,但需注意其安全风险。 问题描述:dput上传与SSL证书验证失败…

    2025年12月14日
    000
  • 在Pyomo中动态扩展约束

    本文档旨在帮助Pyomo初学者了解如何在Pyomo中实现类似Pulp中动态扩展约束的功能。由于Pyomo的表达式不可变性,直接修改约束表达式较为复杂。本文将介绍如何利用命名表达式(Expression)以及元组表示法来灵活地构建和修改约束,并提供示例代码和注意事项,帮助读者掌握在Pyomo中实现动态…

    2025年12月14日
    000
  • python防止栈溢出的解决

    递归深度过大导致栈溢出时,可通过增加递归限制或改用迭代解决。1. 使用sys.setrecursionlimit()可提高递归深度,但受限于系统资源;2. 将递归算法转为迭代形式,如阶乘计算,避免调用堆栈增长,提升效率与安全性。 Python中防止栈溢出主要出现在递归调用过深的情况下。由于Pytho…

    2025年12月14日
    000
  • 解决preview-generator安装失败问题:Windows平台安装指南

    摘要 本文针对在Windows系统中使用pip安装preview-generator包时遇到的常见错误,提供详细的排查和解决方案。preview-generator依赖多个非Python库,在Windows上的安装配置较为复杂。本文将引导你安装必要的依赖项,并提供替代方案,帮助你成功生成文件预览。 …

    2025年12月14日
    000

发表回复

登录后才能评论
关注微信