Vision Transformer多标签分类:损失函数与评估策略深度解析

vision transformer多标签分类:损失函数与评估策略深度解析

本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。

从单标签到多标签:核心概念转变

深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。

这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。

损失函数的选择与实现

在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。

然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。

推荐的损失函数:torch.nn.BCEWithLogitsLoss

对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。

其主要优点包括:

独立处理每个类别: BCEWithLogitsLoss会对模型输出的每个Logit独立地计算二元交叉熵,这与多标签任务中各类别独立存在的特性相符。数值稳定性: 它直接作用于模型的原始Logits输出,内部处理了Sigmoid激活,避免了先手动计算Sigmoid再计算交叉熵可能导致的数值溢出或下溢问题。

使用BCEWithLogitsLoss的注意事项:

模型输出: 模型的最终输出层应该是一个全连接层,输出维度等于类别的总数,且不应在其后接Softmax激活函数。例如,如果你的模型有7个类别,最终输出应为形状(batch_size, 7)的Logits张量。标签格式: 标签(target)必须是与模型输出Logits形状相同的浮点型(torch.float)张量。它通常是一个“多热编码”(multi-hot encoding)向量,其中1表示该类别存在,0表示该类别不存在。例如,[0, 1, 1, 0, 0, 1, 0]表示第二个、第三个和第六个类别存在。

代码示例:替换损失函数

假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。

import torchimport torch.nn as nn# 假设模型输出的Logits,形状为 (batch_size, num_classes)# 这里以 batch_size = 2, num_classes = 7 为例logits = torch.randn(2, 7) # 模拟模型输出的原始Logits# 假设对应的多标签,形状也为 (batch_size, num_classes)# 注意:标签必须是浮点型 (torch.float)labels = torch.tensor([    [0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签    [1, 0, 1, 1, 0, 0, 0]  # 第二个样本的标签]).float()# 实例化 BCEWithLogitsLossloss_function = nn.BCEWithLogitsLoss()# 计算损失loss = loss_function(logits, labels)print(f"Logits:n{logits}")print(f"Labels:n{labels}")print(f"Calculated Loss: {loss.item()}")# 原始训练循环中的应用# pred = model(images.to(device))# loss = loss_function(pred, labels.to(device))# loss.backward()# optimizer.step()

多标签分类的评估策略

在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。

因此,我们需要采用更全面的指标来评估多标签分类模型的性能。

1. 从Logits到预测结果

百度GBI 百度GBI

百度GBI-你的大模型商业分析助手

百度GBI 104 查看详情 百度GBI

在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。

# 假设 logits 是模型输出的Logits# 例如:logits = torch.randn(batch_size, num_classes)# 1. 应用Sigmoid函数将Logits转换为概率probabilities = torch.sigmoid(logits)# 2. 设定阈值,将概率转换为二元预测 (0或1)threshold = 0.5predictions = (probabilities > threshold).float()print(f"Probabilities:n{probabilities}")print(f"Predictions (threshold={threshold}):n{predictions}")

2. 常用评估指标

以下是多标签分类中常用的评估指标:

精确率(Precision)、召回率(Recall)、F1分数(F1-score):

精确率: 预测为正例的样本中,有多少是真正的正例。召回率: 实际为正例的样本中,有多少被模型预测为正例。F1分数: 精确率和召回率的调和平均值,综合衡量模型的性能。这些指标可以针对每个类别独立计算(Per-class),也可以通过微平均(Micro-average)或宏平均(Macro-average)来汇总所有类别的结果。Micro-average: 汇总所有类别的TP、FP、FN后再计算总体的Precision、Recall、F1。它更侧重于样本级别的性能,受样本数量较多的类别影响较大。Macro-average: 先计算每个类别的Precision、Recall、F1,然后取这些值的平均。它给予每个类别相同的权重,不受类别样本数量不平衡的影响。

平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):

AP: 衡量单个类别在不同召回率下的精确率表现,通常通过计算PR曲线下面积获得。AP值越高,说明模型在该类别上的性能越好。mAP: 对所有类别的AP值取平均,是衡量多标签分类模型整体性能的一个非常重要的指标,尤其在目标检测等领域广泛使用。

Jaccard Index (IoU) / Jaccard Similarity Score:

衡量预测集合与真实标签集合的相似度,计算公式为交集大小除以并集大小。对于多标签分类,可以计算每个样本的预测标签集合与真实标签集合的Jaccard相似度,然后取平均。

Hamming Loss:

衡量预测结果与真实标签不一致的标签比例。Hamming Loss越低越好。

3. 使用torchmetrics或scikit-learn进行评估

在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。

torchmetrics示例 (推荐用于PyTorch训练循环中):

import torchfrom torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision# 假设真实标签和预测概率# num_classes = 7num_labels = 7num_samples = 10target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)# 或者直接使用Logits,让metrics内部处理Sigmoidpredicted_logits = torch.randn(num_samples, num_labels)# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')# MultilabelF1Score 期望输入是 (preds, target)# preds: 概率 (float) 或 原始logits (float)# target: 真实标签 (int 或 float, 0/1)f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)# 计算F1分数# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Scoref1_macro_val = f1_score_macro(predicted_probs, target_labels.long())print(f"Micro F1 Score: {f1_micro_val.item()}")print(f"Macro F1 Score: {f1_macro_val.item()}")# 实例化mAP# MultilabelAveragePrecision 期望输入是 (preds, target)# preds: 概率 (float)# target: 真实标签 (int 或 float, 0/1)map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)# 计算mAPmap_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAPprint(f"mAP: {map_val.item()}")# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long# 如果输入是logits,则需要确保metrics内部会执行sigmoid# 更好的做法是,统一将模型输出转换为概率再传入metricsprobs_from_logits = torch.sigmoid(predicted_logits)f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())map_val_logits = map_metric(probs_from_logits, target_labels.long())print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")print(f"mAP (from logits): {map_val_logits.item()}")

总结与注意事项

将ViT从单标签多分类转换为多标签分类,关键在于以下几点:

模型输出层: 确保模型的最终全连接层输出与类别数量相等的Logits,并且不带Softmax激活。损失函数: 使用torch.nn.BCEWithLogitsLoss作为损失函数,它能独立处理每个类别的预测。标签格式: 真实标签应为多热编码的浮点型张量,形状与模型输出的Logits相同。评估指标: 采用适合多标签任务的评估指标,如Micro/Macro F1分数、mAP、Jaccard Index等,并结合torchmetrics或scikit-learn等库进行高效计算。阈值选择: 在将概率转换为二元预测时,阈值的选择对最终的精确率和召回率有显著影响,可能需要通过验证集进行调优。类别不平衡: 在多标签任务中,类别不平衡问题可能更复杂(例如,某些标签总是同时出现,某些标签非常稀有)。可以考虑使用加权BCE损失、Focal Loss或采样策略来缓解。

通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。

以上就是Vision Transformer多标签分类:损失函数与评估策略深度解析的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/589903.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月10日 15:24:58
下一篇 2025年11月10日 15:27:52

相关推荐

  • HTML、CSS 和 JavaScript 中的简单侧边栏菜单

    构建一个简单的侧边栏菜单是一个很好的主意,它可以为您的网站添加有价值的功能和令人惊叹的外观。 侧边栏菜单对于客户找到不同项目的方式很有用,而不会让他们觉得自己有太多选择,从而创造了简单性和秩序。 今天,我将分享一个简单的 HTML、CSS 和 JavaScript 源代码来创建一个简单的侧边栏菜单。…

    2025年12月24日
    200
  • 前端代码辅助工具:如何选择最可靠的AI工具?

    前端代码辅助工具:可靠性探讨 对于前端工程师来说,在HTML、CSS和JavaScript开发中借助AI工具是司空见惯的事情。然而,并非所有工具都能提供同等的可靠性。 个性化需求 关于哪个AI工具最可靠,这个问题没有一刀切的答案。每个人的使用习惯和项目需求各不相同。以下是一些影响选择的重要因素: 立…

    2025年12月24日
    300
  • 带有 HTML、CSS 和 JavaScript 工具提示的响应式侧边导航栏

    响应式侧边导航栏不仅有助于改善网站的导航,还可以解决整齐放置链接的问题,从而增强用户体验。通过使用工具提示,可以让用户了解每个链接的功能,包括设计紧凑的情况。 在本教程中,我将解释使用 html、css、javascript 创建带有工具提示的响应式侧栏导航的完整代码。 对于那些一直想要一个干净、简…

    2025年12月24日
    000
  • 布局 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在这里查看视觉效果: 固定导航 – 布局 – codesandbox两列 – 布局 – codesandbox三列 – 布局 – codesandbox圣杯 &#8…

    2025年12月24日
    000
  • 隐藏元素 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看隐藏元素的视觉效果 – codesandbox 隐藏元素 hiding elements hiding elements hiding elements hiding elements hiding element…

    2025年12月24日
    400
  • 居中 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看垂直中心 – codesandbox 和水平中心的视觉效果。 通过 css 居中 垂直居中 centering centering centering centering centering centering立即…

    2025年12月24日 好文分享
    300
  • 如何在 Laravel 框架中轻松集成微信支付和支付宝支付?

    如何用 laravel 框架集成微信支付和支付宝支付 问题:如何在 laravel 框架中集成微信支付和支付宝支付? 回答: 建议使用 easywechat 的 laravel 版,easywechat 是一个由腾讯工程师开发的高质量微信开放平台 sdk,已被广泛地应用于许多 laravel 项目中…

    2025年12月24日
    000
  • 如何在移动端实现子 div 在父 div 内任意滑动查看?

    如何在移动端中实现让子 div 在父 div 内任意滑动查看 在移动端开发中,有时我们需要让子 div 在父 div 内任意滑动查看。然而,使用滚动条无法实现负值移动,因此需要采用其他方法。 解决方案: 使用绝对布局(absolute)或相对布局(relative):将子 div 设置为绝对或相对定…

    2025年12月24日
    000
  • 移动端嵌套 DIV 中子 DIV 如何水平滑动?

    移动端嵌套 DIV 中子 DIV 滑动 在移动端开发中,遇到这样的问题:当子 DIV 的高度小于父 DIV 时,无法在父 DIV 中水平滚动子 DIV。 无限画布 要实现子 DIV 在父 DIV 中任意滑动,需要创建一个无限画布。使用滚动无法达到负值,因此需要使用其他方法。 相对定位 一种方法是将子…

    2025年12月24日
    000
  • 移动端项目中,如何消除rem字体大小计算带来的CSS扭曲?

    移动端项目中消除rem字体大小计算带来的css扭曲 在移动端项目中,使用rem计算根节点字体大小可以实现自适应布局。但是,此方法可能会导致页面打开时出现css扭曲,这是因为页面内容在根节点字体大小赋值后重新渲染造成的。 解决方案: 要避免这种情况,将计算根节点字体大小的js脚本移动到页面的最前面,即…

    2025年12月24日
    000
  • Nuxt 移动端项目中 rem 计算导致 CSS 变形,如何解决?

    Nuxt 移动端项目中解决 rem 计算导致 CSS 变形 在 Nuxt 移动端项目中使用 rem 计算根节点字体大小时,可能会遇到一个问题:页面内容在字体大小发生变化时会重绘,导致 CSS 变形。 解决方案: 可将计算根节点字体大小的 JS 代码块置于页面最前端的 标签内,确保在其他资源加载之前执…

    2025年12月24日
    200
  • Nuxt 移动端项目使用 rem 计算字体大小导致页面变形,如何解决?

    rem 计算导致移动端页面变形的解决方法 在 nuxt 移动端项目中使用 rem 计算根节点字体大小时,页面会发生内容重绘,导致页面打开时出现样式变形。如何避免这种现象? 解决方案: 移动根节点字体大小计算代码到页面顶部,即 head 中。 原理: flexível.js 也遇到了类似问题,它的解决…

    2025年12月24日
    000
  • 形状 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看 codesandbox 的视觉效果。 通过css绘制各种形状 如何在 css 中绘制正方形、梯形、三角形、异形三角形、扇形、圆形、半圆、固定宽高比、0.5px 线? shapes 0.5px line .square { w…

    2025年12月24日
    000
  • 有哪些美观的开源数字大屏驾驶舱框架?

    开源数字大屏驾驶舱框架推荐 问题:有哪些美观的开源数字大屏驾驶舱框架? 答案: 资源包 [弗若恩智能大屏驾驶舱开发资源包](https://www.fanruan.com/resource/152) 软件 [弗若恩报表 – 数字大屏可视化组件](https://www.fanruan.c…

    2025年12月24日
    000
  • 网站底部如何实现飘彩带效果?

    网站底部飘彩带效果的 js 库实现 许多网站都会在特殊节日或活动中添加一些趣味性的视觉效果,例如点击按钮后散发的五彩缤纷的彩带。对于一个特定的网站来说,其飘彩带效果的实现方式可能有以下几个方面: 以 https://dub.sh/ 网站为例,它底部按钮点击后的彩带效果是由 javascript 库实…

    2025年12月24日
    000
  • 网站彩带效果背后是哪个JS库?

    网站彩带效果背后是哪个js库? 当你访问某些网站时,点击按钮后,屏幕上会飘出五颜六色的彩带,营造出庆祝的氛围。这些效果是通过使用javascript库实现的。 问题: 哪个javascript库能够实现网站上点击按钮散发彩带的效果? 答案: 根据给定网站的源代码分析: 可以发现,该网站使用了以下js…

    好文分享 2025年12月24日
    100
  • 产品预览卡项目

    这个项目最初是来自 Frontend Mentor 的挑战,旨在使用 HTML 和 CSS 创建响应式产品预览卡。最初的任务是设计一张具有视觉吸引力和功能性的产品卡,能够无缝适应各种屏幕尺寸。这涉及使用 CSS 媒体查询来确保布局在不同设备上保持一致且用户友好。产品卡包含产品图像、标签、标题、描述和…

    2025年12月24日
    100
  • 如何利用 echarts-gl 绘制带发光的 3D 图表?

    如何绘制带发光的 3d 图表,类似于 echarts 中的示例? 为了实现类似的 3d 图表效果,需要引入 echarts-gl 库:https://github.com/ecomfe/echarts-gl。 echarts-gl 专用于在 webgl 环境中渲染 3d 图形。它提供了各种 3d 图…

    2025年12月24日
    000
  • 如何在 Element UI 的 el-rate 组件中实现 5 颗星 5 分制与百分制之间的转换?

    如何在el-rate中将5颗星5分制的分值显示为5颗星百分制? 要实现该效果,只需使用 el-rate 组件的 allow-half 属性。在设置 allow-half 属性后,获得的结果乘以 20 即可得到0-100之间的百分制分数。如下所示: score = score * 20; 动态显示鼠标…

    2025年12月24日
    100
  • CSS 最佳实践:后端程序员重温 CSS 时常见的三个疑问?

    CSS 最佳实践:提升代码质量 作为后端程序员,在重温 CSS/HTML 时,你可能会遇到一些关于最佳实践的问题。以下将解答三个常见问题,帮助你编写更规范、清晰的 CSS 代码。 1. margin 设置策略 当相邻元素都设置了 margin 时,通常情况下应为上一个元素设置 margin-bott…

    2025年12月24日
    000

发表回复

登录后才能评论
关注微信