Torchmeta:PyTorch的元学习库

Torchmeta:PyTorch的元学习库

作者 | sharmistha chatterjee

来源 | Medium

编辑 | 代码医生团队

介绍

元学习研究和开放源代码库提供了一种通过标准化基准和各种可用数据集对不同算法进行详细比较的方法,从而可以完全控制此评估的复杂性。但是,大多数在线可用的代码都有以下限制:

数据管道通常特定于一个数据集,而对另一个数据集进行测试需要大量的返工。元学习中的基准测试由数据集组成,这给数据管道增加了一层复杂性。因此大多数元学习项目都实现了适合其方法的自己的特定数据加载组件。输入级别缺乏标准会导致围绕每种元学习算法的机制产生差异,从而使比较过程更具挑战性。

为了解决这个限制,Google AI引入了Torchmeta,这是一个基于PyTorch深度学习框架构建的库,可以对多个数据集的元学习算法进行无缝且一致的评估。为了解释Torchmeta,使用了一些初步的概念,例如DataLoader和BatchLoader,可以解释为:

DataLoader是一种通用实用程序,可用作应用程序数据获取层的一部分,以通过批处理和缓存在各种远程数据源(例如数据库或Web服务)上提供简化且一致的API。

批处理是DataLoader的主要功能。批处理加载函数接受键列表,并返回一个Promise,该Promise解析为值列表DataLoader合并在单个执行框架内发生的所有单个加载(一旦解决了包装承诺,即执行),然后是具有全部功能的批处理函数要求的钥匙。

Torchmeta具有以下功能。Torchmeta通过少量的分类和回归为大多数标准基准提供了DataLoader,并提供了新的元数据集抽象。数据加载器与PyTorch的标准数据组件完全兼容,例如Dataset和DataLoader。Torchmeta为所有可用的基准提供了相同的界面,从而使不同数据集之间的转换尽可能无缝。Torchmeta还对PyTorch进行了一些扩展,以简化与元学习算法兼容的模型的开发,其中一些需要更高阶的区分。可用的基准有助于为开发新的元学习算法提供参考。Torchmeta提供了一个框架,研究人员可以围绕该框架构建自己的元学习算法,而不是使数据管道适应其方法。Torchmeta通过将元数据集与算法本身解耦来促进代码重用,从而提供了这一抽象层。

数次学习的数据加载器

快速学习很少能具有使用先验知识快速推广具有有限监督经验的新任务的能力。快速学习分为三类:

数据使用先验知识来增强监督经验。该模型通过先验知识约束假设空间,算法使用先验知识来更改对假设空间中最佳假设参数的搜索。

Torchmeta在其库中具有以下内容。

该库提供了与元学习文献中经典的几次快照分类和回归问题相对应的数据集。该界面旨在支持分类和回归的数据集之间的模块化,以简化对全套基准测试的评估过程。

为了平衡几次学习中固有的数据缺乏,元学习算法从称为元训练集的数据集D-meta = {D1,…,Dn}中获取一些先验知识。在几次学习中,每个元素Di仅包含几个输入/输出对(x,y),其中y取决于问题的性质。由于这些数据集可以包含过去执行的不同任务的示例。Torchmeta提供了一种解决方案,可以使用最少的问题特定组件来自动创建每个数据集Di。

极少回归

少有的回归问题中的大多数是通过不同功能的输入和输出之间的简单回归问题,其中每个功能对应一个任务。这些功能被参数化以允许任务之间的可变性,同时在各个任务之间保持不变的“主题”。例如,这些函数可以是形式为fi(x)= ai sin(x + bi)的正弦波,其中a和b在某些范围内变化。

在Torchmeta中,元训练集继承自名为MetaDataset的对象,每个数据集Di(i = 1,…,n,用户定义n)对应于该函数的特定参数选择,所有在元训练集创建时采样一次的参数。一旦知道了函数的参数,我们就可以通过在给定范围内对输入进行采样并将其提供给函数来创建数据集。

少拍分类

对于少有的分类问题,数据集Di的创建通常遵循两个步骤:

前N个类别是从大量候选项中取样的(对应于“ N向分类”中的N)。在下一步中,每个班级选择k个示例(对应于“ k-shot学习”中的k个)。这是一个分为两步的过程,它是作为继承自MetaDataset的CombinationMetaDataset对象的一部分而提供的,它为用户提供了针对特定问题的大量类候选者的用户规范。为了促进元学习的可重复性,每个任务都与一个唯一的标识符(类标识符的N元组)相关联。选择任务后,对象将返回数据集Di以及来自相应类集中的所有示例。Torchmeta还包括一些有用的功能,以增加诸如旋转图像之类的变体来增加班级候选人的数量。

下图展示了元学习器的作用,在元测试中,另一个不相交的任务集Tt〜p(T)(p(T)->任务T的分布)用于测试元学习者。每个Tt都作用于N个数据集,其中数据集= {D train Tt,D test Tt}。学习者从训练集D train Tt和测试集D test Tt上学习。Tt的平均损耗被视为元学习测试误差。

Torchmeta:PyTorch的元学习库

训练和测试数据集拆分

在元学习中,每个数据集Di分为两部分:训练集(或支持集),用于使模型适应当前的任务;测试集(或查询集),用于评估和元优化。当任务保持不变时,这两个部分不会重叠,在训练和测试集中都没有任何示例。Torchmeta在数据集上引入了一个称为Splitter的包装器,该包装器负责创建训练和测试数据集,以及可选地对数据进行混排。

为了实例化基于Mini Imagenet的5向1发分类问题的元训练集,使用:

数据集= torchmeta.datasets.MiniImagenet(“数据”,num_classes_per_task = 5,meta_train = True,下载= True)

数据集= torchmeta.transforms.ClassSplitter(数据集,num_train_per_class = 1,num_test_per_class = 15,shuffle = True)

除了元训练集之外,大多数基准测试还提供了元测试集,用于对元学习算法的总体评估(以及可能的元验证集)。创建MetaDataset对象时,可以使用meta_test = True(或meta_val = True)而不是meta_train = True来选择这些不同的元数据集。

元数据加载器

可以迭代一些镜头分类和回归问题中的元训练集对象,以生成PyTorch数据集对象,该对象包含在任何标准数据管道(与DataLoader组合)中。

元学习算法在批次任务上运行效果更好。与在PyTorch中将示例与DataLoader一起批处理的方式类似,Torchmeta公开了一个MetaDataLoader,该对象可以在迭代时产生大量任务。这样的元数据加载器能够输出一个大张量,其中包含批处理中来自不同任务的所有示例,如下所示:

数据集= torchmeta.datasets.helpers.miniimagenet(“数据”,镜头= 1,方式= 5,meta_train = True,下载= True)

数据加载器= torchmeta.utils.data.BatchMetaDataLoader(数据集,batch_size = 16)

元学习模块

下图显示了使用学习者的损失和错误信号进行元学习的顺序步骤。

Torchmeta:PyTorch的元学习库

元学习者的学习步骤:来源:

LuckyCola工具库 LuckyCola工具库

LuckyCola工具库是您工作学习的智能助手,提供一系列AI驱动的工具,旨在为您的生活带来便利与高效。

LuckyCola工具库 19 查看详情 LuckyCola工具库

https : //arxiv.org/pdf/1904.05046.pdf

在元学习中,PyTorch中的模型是由称为模块的基本组件创建的,该基本组件等效于神经网络中包含该层的计算图及其参数的一层。这些模块将其参数视为其计算图的组成部分,足以训练带有反向传播的模型。

但是,一些元学习算法需要通过参数更新(例如梯度更新)进行反向传播,以进行元优化(或“外环”),因此涉及高阶微分。

因此,适应PyTorch中的现有模块至关重要,以便它们可以处理任意计算图来替代这些参数。因此,Torchmeta扩展了现有模块,并保留了提供新参数作为附加输入的选项。这些新对象称为MetaModule,它们的默认行为(即,未指定任何其他参数)等同于它们的PyTorch对应对象。否则,如果指定了额外的参数(例如,梯度下降的一步的结果),则MetaModule会将它们视为计算图的一部分,并且反向传播将按预期进行。

Torchmeta:PyTorch的元学习库

上图描述了带有或不带有附加参数的线性模块(称为MetaLinear)的扩展如何工作,以及对梯度的影响。左图显示了元模块作为参数W和b的容器的实例,以及带有占位符的重量和偏差参数的计算图。中间的图显示了MetaLinear元模块的默认行为,其中的占位符用W&b替换,这等效于PyTorch的Linear模块。右图显示了如何使用完整的计算图填充这些占位符,就像一个梯度下降步骤。在后一种情况下,外循环更新中必需的外循环相对于W的坡度可以正确地一直流到参数W。

下面的代码演示了如何从Torchmeta的现有数据集中生成训练,验证和测试元数据集。

代码语言:javascript代码运行次数:0运行复制

from torchmeta.datasets import Omniglot, MiniImagenet, CIFARFS, FC100, TieredImagenet, TCGAfrom torchmeta.transforms import Categorical, ClassSplitter, Rotationfrom torchvision.transforms import Compose, Resize, ToTensorfrom torchmeta.utils.data import BatchMetaDataLoader dataset = Omniglot("data",                   # Number of ways                   num_classes_per_task=5,                   # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)                   transform=Compose([Resize(28), ToTensor()]),                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))                   target_transform=Categorical(num_classes=5),                   # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)                   class_augmentations=[Rotation([90, 180, 270])],                   meta_train=True,                   download=True)                   dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)for batch in dataloader:    train_inputs, train_targets = batch["train"]    print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)     test_inputs, test_targets = batch["test"]    print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)    print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)

下图显示了下载后从Omnichlot和MiniImagenet从Torchmeta的数据集中生成的元学习数据集。

此处Omniglot数据集包含50个字母。将其分为30个字母的背景集和20个字母的评估集。在将背景大小调整为28x28张量后,应该使用背景集学习有关字符的一般知识(例如,特征学习,元学习)。此外,将标签传送到整数Glagolitic / character01”,“ Sanskrit / character14”,……)到(0,1,..,n)。

MiniImageNet包含60,000个84x84 RGB图像,每个类别600个图像。使用Torchmeta,可以生成HDF5格式的元学习数据集。

Torchmeta具有以HDF5格式下载数据集的功能,该功能允许:

要将包含HDF5文件的文件夹(包括子文件夹)用作数据源,在数据集中维护一个简单的HDF5组层次结构,启用延迟数据加载(即应DataLoader的请求),以便允许使用不适合内存的数据集,配备了数据缓存以加快数据加载过程,并且允许对源或目标数据集进行自定义转换。

Torchmeta:PyTorch的元学习库

用于定义Torchmeta数据集(例如Omniglot)的元学习参数的TieredImagenetClassDataset包含来自34个类别的图像。元训练/验证/测试拆分超过20/6/8个类别。每个类别包含10到30个类别。按类别划分(而不是按类别划分)可确保所有训练课程与测试课程完全不同(不同于Mini-Imagenet)。它带有以下一组参数,这些参数定义了训练,验证和测试数据集的划分以及应用于它们的转换和增强技术

num_classes_per_task(int):每个任务的类数,对应于“ N向”分类中的“ N”。

meta_train:bool(`False`):使用数据集的元火车拆分。如果设置为True,则必须将参数meta_val和meta_test设置为False。这三个参数中的一个必须正确设置为“ True”。

meta_val:bool(`False`):使用数据集的元验证拆分。如果设置为True,则参数meta_train和metatest必须设置为False。这三个参数中只有一个必须设置为“ True”。

meta_test:bool(`False`):使用数据集的元测试拆分。如果设置为True,则参数meta_train和meta_val必须设置为False。这三个参数中只有一个必须设置为“ True”。

meta_split:{'train','val','test'}中的字符串,可选要使用的拆分名称,如果所有三个都设置为False,则覆盖参数meta_train,metaval和metatest。

transform:可调用的,可选的:获取“ PIL”图像并返回转换后版本的函数/转换。

target_transform:可调用,可选:接受目标并返回转换版本的函数/转换。

dataset_transform:可调用,可选:函数/转换,它接受数据集(即任务),并返回其转换后的版本。-> torchmeta.transforms.ClassSplitter()。

class_augmentations:可调用的,可选的列表:使用新类扩展数据集的函数列表。这些类是现有类的转换。

download:bool(默认值:False)如果为True,则下载pickle文件并处理根目录(位于tieredimagenet文件夹下)中的数据集。如果数据集已经可用,则不会再次下载/处理数据集。

结论

在此博客中,了解了Google AI最新发布的库Torchmeta,它提供了哪些功能以及可以解决什么样的元学习问题。可以浏览其他PyTorch元学习库,例如元Agonistic机器学习,以学习可以快速适应新任务的网络初始化。

https://github.com/dragen1860/MAML-Pytorch

如下图所示,在Torchmeta中很少有镜头学习可用于图像分类。

Torchmeta:PyTorch的元学习库

参考

https://github.com/markdtw/meta-learning-lstm-pytorch

https://arxiv.org/abs/1909.06576

https://docs.graphene-python.org/en/latest/execution/dataloader/

以上就是Torchmeta:PyTorch的元学习库的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/367585.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
算力架构・生态协同・服务革新:华为助力行业数智化“三维破局”
上一篇 2025年11月6日 06:06:45
CentOS HDFS配置中资源分配策略
下一篇 2025年11月6日 06:06:46

相关推荐

  • 修复Django电商项目中AJAX过滤产品列表图片不显示问题

    在Django电商项目中,当使用AJAX动态加载过滤后的产品列表时,常遇到图片无法正常显示的问题。这通常是由于前端模板中图片加载方式(如data-setbg属性结合JavaScript库)与AJAX动态内容更新机制不兼容所致。解决方案是直接在AJAX返回的HTML中使用标准的标签来渲染图片,确保浏览…

    2026年5月10日
    000
  • 开源免费PHP工具 PHP开发效率提升利器

    推荐开源免费PHP开发工具以提升效率:VS Code、Sublime Text轻量高效,PhpStorm专业强大;调试用Xdebug、Kint、Ray;依赖管理选Composer;代码质量工具包括PHPStan、Psalm、PHP_CodeSniffer;数据库管理可用%ignore_a_1%MyA…

    2026年5月10日
    000
  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    100
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    000
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • c++中的SFINAE技术是什么_c++模板编程中的SFINAE原理与应用

    SFINAE 是“替换失败不是错误”的原则,指模板实例化时若参数替换导致错误,只要存在其他合法候选,编译器不报错而是继续重载决议。它用于条件启用模板、类型检测等场景,如通过 decltype 或 enable_if 控制函数重载,实现类型特征判断。尽管 C++20 引入 Concepts 简化了部分…

    2026年5月10日
    000
  • Golang gRPC流式请求异常处理

    在Golang的gRPC流式通信中,必须通过context.Context处理异常。应监听上下文取消或超时,及时释放资源,设置合理超时,避免连接长时间挂起,并在goroutine中通过context控制生命周期。 在使用 Golang 和 gRPC 实现流式通信时,异常处理是确保服务健壮性的关键部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • vscode上怎么运行html_vscode上运行html步骤【指南】

    首先保存文件为.html格式,再通过浏览器或Live Server插件打开预览;推荐安装Live Server实现本地服务器运行与实时刷新,提升开发体验。 在 VS Code 上运行 HTML 文件并不需要复杂的配置,只需几个简单步骤即可预览页面效果。VS Code 本身是一个代码编辑器,不直接运行…

    2026年5月10日
    100
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    000
  • 修复点击时按钮抖动:CSS垂直对齐实践

    本文探讨了在Web开发中,交互式按钮(如播放/暂停按钮)在点击时发生意外垂直位移的问题。通过分析CSS样式变化对元素布局的影响,我们发现这是由于按钮不同状态下的边框样式和内边距改变,以及默认的垂直对齐行为共同作用所致。核心解决方案是利用CSS的vertical-align属性,将其设置为middle…

    2026年5月10日
    000
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 《魔兽世界》将于6月11日开启国服回归技术测试

    《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试

    《%ign%ignore_a_1%re_a_1%》官方宣布,将于6月11日开启国服回归技术测试,时间为7天,并称可以在6月内正式开服,玩家们可以访问官网下载战网客户端并预下载“巫妖王之怒”客户端,技术测试详情见下图。 WordAi WordAI是一个AI驱动的内容重写平台 53 查看详情 以上就是《…

    2026年5月10日 用户投稿
    200
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    000
  • 前端缓存策略与JavaScript存储管理

    根据数据特性选择合适的存储方式并制定清晰的读写与清理逻辑,能显著提升前端性能;合理运用Cookie、localStorage、sessionStorage、IndexedDB及Cache API,结合缓存策略与定期清理机制,可在保证用户体验的同时避免安全与性能隐患。 前端缓存和JavaScript存…

    2026年5月10日
    100
  • HTML5网页如何实现手势操作 HTML5网页移动端交互的处理技巧

    首先利用原生touch事件实现滑动判断,再通过preventDefault解决滚动冲突,接着引入Hammer.js处理复杂手势,最后通过优化点击区域、避免事件冲突和增加视觉反馈提升体验。 在移动端浏览器中,HTML5网页可以通过触摸事件实现手势操作,提升用户体验。虽然原生JavaScript提供了基…

    2026年5月10日
    000
  • 创建指定大小并填充特定数据的Golang文件教程

    本文将介绍如何使用Golang创建一个指定大小的文件,并用特定数据填充它。我们将使用 `os` 包提供的函数来创建和截断文件,从而实现快速生成大文件的目的。示例代码展示了如何创建一个10MB的文件,并将其填充为全零数据。掌握这些方法,可以方便地在例如日志系统或磁盘队列等场景中,预先创建测试文件或初始…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信