在少样本学习中,用SetFit进行文本分类

译者 | 陈峻

审校 | 重楼

在本文中,我将向您介绍“少样本(few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的setfit方法。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

在少样本学习中,用SetFit进行文本分类

传统的机器学习(ML)

在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。

在深入研究Sentence Transformer fine-tuning(SetFit)之前,我们有必要简要地回顾一下自然语言处理(Natural Language Processing,NLP)的一个重要方面,也就是:“少样本学习”。

少样本学习

少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。

值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binary classification)中,k 等于 2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么 n就等于10。总之,这个支持集可以被描述为双向10样本学习。

既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。

SetFit架构

由Hugging Face和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接–https://github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。

就输出而言,SetFit仅用到了客户评论(Customer Reviews,CR)情感分析数据集里、每个类别的八个标注示例。其结果就能够与由三千个示例组成的完整训练集上,经调优的RoBERTa Large的结果相同。值得强调的是,就体积而言,经微优的RoBERTa模型比SetFit模型大三倍。下图展示的是SetFit架构:

在少样本学习中,用SetFit进行文本分类

图片来源:https://www.php.cn/link/2456b9cd2668fa69e3c7ecd6f51866bf

用SetFit实现快速学习

SetFit的训练速度非常快,效率也极高。与GPT-3和T-FEW等大模型相比,其性能极具竞争力。请参见下图:

在少样本学习中,用SetFit进行文本分类SetFit与T-Few 3B模型的比较

如下图所示,SetFit在少样本学习方面的表现优于RoBERTa。

在少样本学习中,用SetFit进行文本分类

SetFit与RoBERT的比较,图片来源:https://www.php.cn/link/3ff4cea152080fd7d692a8286a587a67

数据集

下面,我们将用到由四个不同类别组成的独特电商数据集,它们分别是:书籍、服装与配件、电子产品、以及家居用品。该数据集的主要目的是将来自电商网站的产品描述归类到指定的标签下。

为了便于采用少样本的训练方法,我们将从四个类别中各选择八个样本,从而得到总共32个训练样本。而其余样本则将留作测试之用。简言之,我们在此使用的支持集是4向8样本学习。下图展示的是自定义电商数据集的示例:

在少样本学习中,用SetFit进行文本分类自定义电商数据集样本

我们采用名为“all-mpnet-base-v2”的Sentence Transformers预训练模型,将文本数据转换为各种向量嵌入。该模型可以为输入文本,生成维度为768的向量嵌入。

如下命令所示,我们将通过在conda环境(是一个开源的软件包管理系统和环境管理系统)中安装所需的软件包,来开始SetFit的实施。

!pip3 install SetFit !pip3 install sklearn !pip3 install transformers !pip3 install sentence-transformers

安装完软件包后,我们便可以通过如下代码加载数据集了。

from datasets import load_datasetdataset = load_dataset('csv', data_files={"train": 'E_Commerce_Dataset_Train.csv',"test": 'E_Commerce_Dataset_Test.csv'})

我们来参照下图,看看训练样本和测试样本数。

豆包爱学 豆包爱学

豆包旗下AI学习应用

豆包爱学 674 查看详情 豆包爱学

在少样本学习中,用SetFit进行文本分类训练和测试数据

我们使用sklearn软件包中的LabelEncoder,将文本标签转换为编码标签。

from sklearn.preprocessing import LabelEncoder le = LabelEncoder()

通过LabelEncoder,我们将对训练和测试数据集进行编码,并将编码后的标签添加到数据集的“标签”列中。请参见如下代码:

Encoded_Product = le.fit_transform(dataset["train"]['Label']) dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features)Encoded_Product = le.fit_transform(dataset["test"]['Label']) dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)

下面,我们将初始化SetFit模型和句子转换器(sentence-transformers)模型。

from setfit import SetFitModel, SetFitTrainer from sentence_transformers.losses import CosineSimilarityLossmodel_id = "sentence-transformers/all-mpnet-base-v2" model = SetFitModel.from_pretrained(model_id)trainer = SetFitTrainer(  model=model, train_dataset=dataset["train"], eval_dataset=dataset["test"], loss_class=CosineSimilarityLoss, metric="accuracy", batch_size=64, num_iteratinotallow=20, num_epochs=2, column_mapping={"Text": "text", "Label": "label"})

初始化完成两个模型后,我们现在便可以调用训练程序了。

trainer.train()

在完成了2个训练轮数(epoch)后,我们将在eval_dataset上,对训练好的模型进行评估。

trainer.evaluate()

经测试,我们的训练模型的最高准确率为87.5%。虽然87.5%的准确率并不算高,但是毕竟我们的模型只用了32个样本进行训练。也就是说,考虑到数据集规模的有限性,在测试数据集上取得87.5%的准确率,实际上是相当可观的。

此外,SetFit还能够将训练好的模型,保存到本地存储器中,以便后续从磁盘加载,用于将来的预测。

trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)

如下代码展示了根据新的数据进行的预测结果:

input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."]output = model(input)

可见,其预测输出为1,而标签的LabelEncoded值为“服装与配件”。由于传统的AI模型需要大量的训练资源(包括时间和数据),才能有稳定水准的输出。而我们的模型与之相比,既准确又高效。

至此,相信您已经基本掌握了“少样本学习”的概念,以及如何使用SetFit来进行文本分类等应用。当然,为了获得更深刻的理解,我强烈建议您选择一个实际场景,创建一个数据集,编写对应的代码,并将该过程延展到零样本学习、以及单样本学习上。

译者介绍

陈峻(Julian Chen)是51CTO社区的编辑,他在IT项目实施方面有十多年的经验,擅长管理内外部资源和风险,并专注于传播网络和信息安全的知识和经验

原文标题:Mastering Few-Shot Learning with SetFit for Text Classification,作者:Shyam Ganesh S)

以上就是在少样本学习中,用SetFit进行文本分类的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/457199.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月8日 01:56:04
下一篇 2025年11月8日 01:56:48

相关推荐

  • 如何部署一个机器学习模型到生产环境?

    部署机器学习模型需先序列化存储模型,再通过API服务暴露预测接口,接着容器化应用并部署至云平台或服务器,同时建立监控、日志和CI/CD体系,确保模型可扩展、可观测且可持续更新。 部署机器学习模型到生产环境,简单来说,就是让你的模型真正开始“干活”,为实际用户提供预测或决策支持。这并非只是把模型文件复…

    2025年12月14日
    000
  • 如何使用Python进行机器学习(Scikit-learn基础)?

    答案:Scikit-learn提供系统化机器学习流程,涵盖数据预处理、模型选择与评估。具体包括使用StandardScaler等工具进行特征缩放,SimpleImputer处理缺失值,OneHotEncoder编码类别特征,SelectKBest实现特征选择;根据问题类型选择分类、回归或聚类模型,结…

    2025年12月14日
    000
  • Python中如何使用sklearn进行机器学习?

    使用sklearn进行机器学习的步骤包括:1. 数据预处理,如标准化和处理缺失值;2. 模型选择和训练,使用决策树、随机森林等算法;3. 模型评估和调参,利用交叉验证和网格搜索;4. 处理类别不平衡问题。sklearn提供了从数据预处理到模型评估的全套工具,帮助用户高效地进行机器学习任务。 在Pyt…

    2025年12月14日
    000
  • 如何在Python中利用机器学习算法进行数据挖掘和预测

    如何在Python中利用机器学习算法进行数据挖掘和预测 引言随着大数据时代的到来,数据挖掘和预测成为了数据科学研究的重要组成部分。而Python作为一种简洁优雅的编程语言,拥有强大的数据处理和机器学习库,成为了数据挖掘和预测的首选工具。本文将介绍如何在Python中利用机器学习算法进行数据挖掘和预测…

    2025年12月13日
    000
  • 机器学习中的Python问题及解决策略

    机器学习是当前最热门的技术领域之一,而Python作为一种简洁、灵活、易于学习的编程语言,成为了机器学习领域最受欢迎的工具之一。然而,在机器学习中使用Python过程中,总会遇到一些问题和挑战。本文将介绍一些常见的机器学习中使用Python的问题,并提供一些解决策略和具体的代码示例。 Python版…

    2025年12月13日
    000
  • Python是机器学习的最佳选择吗?

    “哪种编程语言最好?”这是编程世界中最流行和最有争议的问题。这个问题的答案不是线性的或简单的,因为从技术上讲,每种编程语言都有自己的优点和缺点。不存在“最好”的编程语言,因为根据问题的不同,每种语言都比其他语言具有轻微的优势。当我们谈论机器学习时,毫无疑问Python是一种高度首选的语言,但有一些因…

    2025年12月13日
    000
  • PHP机器学习:PHP-ML基础

    php-ml是适用于php环境的机器学习库。1.它提供分类、回归、聚类等算法;2.通过composer安装使用;3.适合中小型项目,性能不及python但无需额外扩展;4.常用算法包括朴素贝叶斯、svm、knn等,选择需根据问题类型和数据特征决定;5.支持数据预处理与特征工程如标准化、缺失值处理、文…

    2025年12月10日 好文分享
    000
  • PHP 函数设计模式在机器学习中的应用

    函数设计模式在机器学习中通过工厂模式创建模型对象,建造者模式构建训练数据集,以及策略模式切换算法,实现可重用、可扩展和易维护的机器学习管道。 PHP 函数设计模式在机器学习中的应用 函数设计模式是一种设计原则,用于提高代码的可重用性和可维护性。在机器学习中,函数设计模式可以帮助我们创建灵活、可扩展的…

    2025年12月9日
    100
  • PHP函数在机器学习中的关键作用

    php在机器学习中扮演着关键角色,提供以下函数:线性回归:stats_regression_linear()聚类:kmeans()分类:svm_train() 和 svm_predict() PHP函数在机器学习中的关键作用 引言 PHP是一种通用脚本语言,在构建网站和应用程序时得到广泛使用。近年来…

    2025年12月9日
    000
  • PHP 函数如何扩展到机器学习?

    使用 phpml 库扩展 php 函数以利用机器学习技术:安装和加载 phpml 库。使用 k-近邻算法进行图像识别等实战应用。phpml 提供其他机器学习算法,如回归、分类和聚类。通过学习使用 phpml,开发者可以在 php 项目中轻松应用机器学习技术。 PHP 函数扩展到机器学习 随着机器学习…

    2025年12月9日
    000
  • HiDream-I1— 智象未来开源的文生图模型

    hidream-i1:一款强大的开源图像生成模型 HiDream-I1是由HiDream.ai团队开发的17亿参数开源图像生成模型,采用MIT许可证,在图像质量和对提示词的理解方面表现卓越。它支持多种风格,包括写实、卡通和艺术风格,广泛应用于艺术创作、商业设计、科研教育以及娱乐媒体等领域。 HiDr…

    2025年12月5日
    000
  • Tome怎样用叙述流提示自动排版_Tome用叙述流提示自动排版【叙述提示】

    通过输入具有逻辑顺序的叙述性文本,Tome可自动排版生成演示文稿。一、使用自然语言描述结构,如“从问题到解决方案”,系统识别关键节点并分页布局;二、插入含角色与场景的叙事段落,如“产品经理小李发现留存率下降并提出新方案”,触发故事板等可视化模板;三、利用“起初”“随后”“最终”等时间序列词引导生成线…

    2025年12月2日 科技
    000
  • Tome怎样用叙述流自动排版_Tome用叙述流自动排版【自动排版】

    启用叙述流模式后,T%ignore_a_1%me可根据语义自动调整排版:1. 创建页面并开启叙述流;2. 使用#、##、-等符号输入标题、子标题和列表;3. 插入图片或视频触发智能图文布局;4. 利用AI优化功能调整段落结构并同步更新样式。 ☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费…

    2025年12月2日 科技
    000
  • Copy.ai怎样用受众定位提示精准_Copy.ai用受众定位提示精准【受众提示】

    要提升Copy.ai内容的针对性,需精准设置受众提示。一、输入具体人口信息如年龄、地域、职业,避免模糊描述;二、补充心理与行为特征,如环保偏好、购买场景及痛点;三、设定使用情境与沟通语气,增强代入感;四、对比不同受众设定下的输出,优选更匹配营销目标的版本进行优化。 ☞☞☞AI 智能聊天, 问答助手,…

    2025年12月2日 科技
    000
  • Tome怎样用强调词提示突出重点_Tome用强调词提示突出重点【强调提示】

    通过加粗、高亮、符号和彩色字体四种方式可在Tome中突出关键信息:一、选中文本点击B按钮加粗;二、用黄色等醒目的背景色高亮重点内容;三、在文字旁添加!或*等符号或图标强化提示;四、使用红色等高饱和度颜色设置字体,增强视觉锚点,提升信息传达效率。 ☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, …

    2025年12月2日 科技
    000
  • StableDiffusion怎样用LoRA定制画风_StableDiffusion用LoRA定制画风【画风定制】

    通过加载LoRA模型可精准控制Stable Diffusion的生成画风,需将.safetensors文件放入models/loras/目录并重启WebUI;2. 在提示词中使用调用,结合正向提示词描述风格、反向提示词排除干扰,并调整权重值(0.5~1.0)优化效果;3. 可引入Textual In…

    2025年12月2日 科技
    000
  • 选择最适合数据的嵌入模型:OpenAI 和开源多语言嵌入的对比测试

    openai最近宣布推出他们的最新一代嵌入模型embedding v3,他们声称这是性能最出色的嵌入模型,具备更高的多语言性能。这一批模型被划分为两种类型:规模较小的text-embeddings-3-small和更为强大、体积较大的text-embeddings-3-large。 ☞☞☞AI 智能…

    2025年12月2日 科技
    000
  • 人工智能如何将数据中心转变为可持续性的动力

    ☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜ 数据中心历来是许多技术进步的支柱,现在面临的不仅仅是基础设施提供商的问题。人工智能的快速发展凸显了数据中心迫切需要更加敏捷、创新和协作,为这个新时代提供动力。 人工智能和机器学习的蓬勃发展,加上…

    2025年12月2日 科技
    000
  • MiMo-Embodied— 小米推出的跨领域具身大模型

    mimo-embodied 是小米推出的全球首个开源跨领域具身大模型,首次将自动驾驶与具身智能两大方向深度融合,具备出色的环境感知、任务规划和空间理解能力。该模型基于视觉语言模型(vlm)架构,采用四阶段训练方法——包括具身智能监督微调、自动驾驶监督微调、链式推理微调以及强化学习微调,显著增强了在不…

    2025年12月2日 科技
    000
  • llama3怎么启用多模态融合_llama3多模态融合启用指南及跨媒体处理详解

    要实现Llama3的多模态融合,需集成视觉编码器并调整模型架构。首先选用支持图像理解的Llama3变体如Bunny-Llama-3-8B-V,并从Hugging Face下载模型文件;接着安装transformers和torchvision库,使用CLIPVisionModel和CLIPImageP…

    2025年12月2日 科技
    000

发表回复

登录后才能评论
关注微信