在 torch.vmap 中高效处理内部张量创建

在 torch.vmap 中高效处理内部张量创建

理解 torch.vmap 与内部张量创建的挑战

torch.vmap 是 PyTorch 提供的一个强大工具,它允许我们将一个处理单个样本的函数(即非批处理函数)转换为一个能够高效处理一批样本的函数,而无需手动管理批处理维度。这在编写通用代码和加速计算方面非常有用。然而,当被 vmap 向量化的函数内部需要创建新的张量,并且这些张量的形状依赖于批处理输入的形状时,就会遇到一个常见的陷阱。

考虑以下场景:我们有一个函数 polycompanion,它接收一个多项式系数张量,并计算其伴随矩阵。伴随矩阵的维度取决于多项式的次数。

import torchpoly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)def polycompanion(polynomial):    # polynomial.shape[-1] 是多项式系数的个数,例如 [a, b, c, d] 代表 ax^3 + bx^2 + cx + d    # 次数 deg = 系数个数 - 1 - 1 = 系数个数 - 2 (如果最后一个系数是常数项)    deg = polynomial.shape[-1] - 2    # 尝试创建伴随矩阵    companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)    # 填充单位矩阵部分    companion[1:, :-1] = torch.eye(deg, dtype=torch.float32)    # 填充最后一列    # 注意这里 polynomial[:-1] 表示除了最后一个系数以外的所有系数    # polynomial[-1] 表示最后一个系数    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]    return companion# 尝试使用 vmap 向量化polycompanion_vmap = torch.vmap(polycompanion)try:    print(polycompanion_vmap(poly_batched))except Exception as e:    print(f"Initial attempt failed: {e}")

上述代码在执行 polycompanion_vmap(poly_batched) 时会失败。原因是 polycompanion 函数内部通过 torch.zeros((deg+1, deg+1)) 创建了一个新的 companion 张量。尽管 deg 是从 polynomial(一个批处理输入)派生出来的,但 torch.zeros 本身创建的是一个普通的、非批处理的张量。当 vmap 试图对这个非批处理的 companion 张量执行批处理操作(例如,将其与从 polynomial 派生的批处理张量进行索引或赋值)时,就会出现维度不匹配或类型不兼容的问题,因为 vmap 期望所有参与运算的张量都带有批处理维度。

为什么 torch.zeros 不会自动批处理?

torch.vmap 的核心机制是跟踪批处理维度,并将操作提升到批处理层面。它能识别作为 vmap 输入的张量及其通过各种张量操作(如加法、乘法、切片等)派生出的张量,并为它们自动添加和管理批处理维度。然而,像 torch.zeros 这种从零开始创建新张量的操作,其默认行为是创建一个标准张量,不包含任何批处理维度信息。即使其形状参数 (deg+1, deg+1) 是基于批处理输入计算得出的,torch.zeros 也无法“感知”到外部的 vmap 上下文,从而无法自动生成一个 BatchedTensor。

torch.zeros_like 是一个例外,因为它基于一个已存在的张量来创建新张量。如果这个已存在的张量是 BatchedTensor,那么 torch.zeros_like 也能创建出一个 BatchedTensor。但在本例中,我们没有一个现成的 BatchedTensor 可以作为 zeros_like 的模板来创建 companion。

规避方案:预分配与外部传递

一种可行的(但不理想的)规避方法是,在调用 vmap 之前,手动创建一个带有批处理维度的 companion 张量,并将其作为函数的额外输入传递给 vmap。

def polycompanion_workaround(polynomial, companion_template):    # 注意:这里的 deg 现在从 companion_template 的形状推断,因为它已经有了批处理维度    deg = companion_template.shape[-1] - 1     # 在传入的 companion_template 上进行就地修改    companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32)    companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1]    return companion_templatepolycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)# 预先创建批处理的 companion 模板# poly_batched.shape[0] 是批次大小# poly_batched.shape[-1]-1 是伴随矩阵的行/列维度companion_init_shape = (poly_batched.shape[0], poly_batched.shape[-1] - 1, poly_batched.shape[-1] - 1)pre_batched_companion = torch.zeros(companion_init_shape, dtype=torch.float32)print("--- Workaround Output ---")print(polycompanion_vmap_workaround(poly_batched, pre_batched_companion))

这种方法虽然能够正确输出结果,但存在明显缺点:

函数签名改变:polycompanion 函数现在需要一个额外的 companion_template 参数,这破坏了其原始的、独立处理单个样本的语义。外部依赖:在调用 vmap 之前,必须手动计算并创建具有正确批处理维度的 pre_batched_companion 张量,增加了代码的复杂性和耦合性。

推荐解决方案:利用 clone 和 concatenate

为了在 vmap 上下文中优雅地创建和填充张量,我们可以避免在非批处理的 torch.zeros 张量上进行就地修改。相反,我们将伴随矩阵视为由两部分组成:一个包含单位矩阵的左侧部分,以及一个由多项式系数计算得出的右侧(最后一列)部分。然后,我们分别构建这两部分,并使用 torch.concatenate 将它们合并。

关键在于:

静态部分:对于伴随矩阵中相对固定的部分(如单位矩阵),我们可以先在一个非批处理的 torch.zeros 张量上构建。动态部分:对于依赖于批处理输入的部分(如最后一列),我们直接从批处理输入 polynomial 计算。合并:使用 torch.concatenate 将这两部分合并。concatenate 是一种张量操作,vmap 能够很好地处理其批处理行为。

以下是改进后的 polycompanion 函数:

def polycompanion_optimized(polynomial):    deg = polynomial.shape[-1] - 2    # 1. 创建一个基础的非批处理张量来填充单位矩阵部分    # 这是一个临时的、非批处理的张量    base_matrix = torch.zeros((deg + 1, deg + 1), dtype=torch.float32)    base_matrix[1:, :-1] = torch.eye(deg, dtype=torch.float32)    # 2. 提取 base_matrix 的左侧部分,并进行克隆    # clone() 创建了一个新的张量,虽然它仍然是非批处理的,    # 但在 vmap 上下文中,当它与批处理张量拼接时,vmap 会正确处理    left_part = base_matrix[:, :-1].clone()    # 3. 计算伴随矩阵的最后一列    # 这一部分完全从批处理输入 polynomial 派生,因此 vmap 会将其视为批处理张量    # polynomial[:-1] 是 (deg+1,) 形状    # polynomial[-1] 是标量    # 结果是一个 (deg+1,) 形状的张量    last_column_values = -1. * polynomial[:-1] / polynomial[-1]    # 4. 扩展最后一列的维度,使其可以与 left_part 进行拼接    # last_column_values 是 (deg+1,),我们需要将其变为 (deg+1, 1)    last_column_reshaped = last_column_values[:, None]     # 5. 使用 concatenate 组合左右两部分    # vmap 会识别 left_part 和 last_column_reshaped,并为它们在批次维度上执行拼接    final_companion = torch.concatenate([left_part, last_column_reshaped], dim=1)    return final_companionpolycompanion_vmap_optimized = torch.vmap(polycompanion_optimized)print("n--- Optimized Solution Output ---")print(polycompanion_vmap_optimized(poly_batched))

输出:

tensor([[[ 0.0000,  0.0000, -0.2500],         [ 1.0000,  0.0000, -0.5000],         [ 0.0000,  1.0000, -0.7500]],        [[ 0.0000,  0.0000, -0.2500],         [ 1.0000,  0.0000, -0.5000],         [ 0.0000,  1.0000, -0.7500]]])

这个解决方案成功地生成了批处理的伴随矩阵,同时保持了 polycompanion_optimized 函数的简洁性,使其能够独立处理单个样本,并且不需要外部预分配张量。

注意事项与最佳实践

函数式编程思维:在使用 torch.vmap 时,尽量采用函数式编程的思维,即函数主要通过返回新张量来完成操作,而不是通过就地修改输入张量。这有助于 vmap 更好地跟踪张量的依赖关系和批处理维度。避免在 vmap 内部进行就地修改:除非你确切知道自己在做什么,并且只对批处理输入进行就地修改,否则应避免在 vmap 内部对非批处理张量进行就地修改。clone() 的作用:在上述解决方案中,clone() 是关键。它创建了一个 base_matrix 切片的新副本。虽然 base_matrix 本身是非批处理的,但通过 clone() 得到的 left_part 可以被 concatenate 操作正确地与批处理的 last_column_reshaped 结合。维度匹配:当使用 torch.concatenate 或 torch.stack 时,确保所有参与拼接的张量在非拼接维度上形状一致。[:, None] 技巧常用于为张量添加一个维度,使其符合拼接要求。性能考量:虽然 concatenate 方案解决了功能问题,但频繁创建和拼接中间张量可能会带来一定的性能开销。对于极致性能敏感的场景,可能需要权衡 vmap 的便利性与手动批处理的优化潜力。然而,对于大多数情况,vmap 带来的代码简化和潜在加速(尤其是在支持的后端)是值得的。

总结

在 torch.vmap 中处理函数内部的张量创建是一个常见的挑战。通过理解 vmap 对批处理张量的期望,并采用 clone() 结合 torch.concatenate 的策略,我们能够优雅地构建出所需的批处理张量,而无需妥协函数的简洁性或引入复杂的外部依赖。这种方法体现了在 PyTorch 中进行高效张量操作的灵活性和强大功能,是掌握 torch.vmap 的一个重要技巧。

以上就是在 torch.vmap 中高效处理内部张量创建的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1377848.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
Pandas DataFrame中基于条件创建新列的字符串处理技巧
上一篇 2025年12月14日 18:08:08
Tkinter Entry 控件在获取焦点时自动清除默认文本的教程
下一篇 2025年12月14日 18:08:20

相关推荐

  • composer require-dev和require有什么不同_Composer Require与Require-Dev区别解析

    require用于声明项目运行必需的依赖,如框架、数据库组件和第三方SDK,这些包会随项目部署到生产环境;2. require-dev用于声明仅在开发和测试阶段需要的工具,如PHPUnit、PHPStan、Faker等,不会默认部署到生产环境;3. 安装时composer install根据环境决定…

    2026年5月10日
    1000
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    000
  • Debian syslog性能优化技巧有哪些

    提升Debian系统syslog (通常基于rsyslog)性能,关键在于精简配置和高效处理日志。以下策略能有效优化日志管理,提升系统整体性能: 精简配置,高效加载: 在rsyslog配置文件中,仅加载必要的输入、输出和解析模块。 使用全局指令设置日志级别和格式,避免不必要的处理。 自定义模板: 创…

    2026年5月10日
    000
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • c++中的SFINAE技术是什么_c++模板编程中的SFINAE原理与应用

    SFINAE 是“替换失败不是错误”的原则,指模板实例化时若参数替换导致错误,只要存在其他合法候选,编译器不报错而是继续重载决议。它用于条件启用模板、类型检测等场景,如通过 decltype 或 enable_if 控制函数重载,实现类型特征判断。尽管 C++20 引入 Concepts 简化了部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • 理解编程指令:当结果正确,但实现方式不符要求时

    本文探讨了在编程实践中,即使程序输出了正确的结果,但若其实现方式未能严格遵循既定指令,仍可能被视为“不正确”的问题。我们将通过具体示例,对比直接求和与累加求和两种实现策略,强调理解和遵守编程规范的重要性,以确保代码的健壮性、可维护性及符合项目要求。 在软件开发过程中,我们经常会遇到这样的情况:编写的…

    2026年5月10日
    000
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 《魔兽世界》将于6月11日开启国服回归技术测试

    《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试

    《%ign%ignore_a_1%re_a_1%》官方宣布,将于6月11日开启国服回归技术测试,时间为7天,并称可以在6月内正式开服,玩家们可以访问官网下载战网客户端并预下载“巫妖王之怒”客户端,技术测试详情见下图。 WordAi WordAI是一个AI驱动的内容重写平台 53 查看详情 以上就是《…

    2026年5月10日 用户投稿
    200
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    000
  • 网站标题关键词更新后,搜索引擎为何仍显示旧标题?

    网站标题更新后,搜索引擎为何显示旧标题? 网站SEO优化中,站长常修改网站标题关键词,期望搜索结果显示自定义标题。然而,即使更新标签、meta keywords、meta description和结构化数据中的name属性后,搜索结果仍显示旧标题,这令人费解。本文将对此进行解释。 问题:站长修改了网…

    2026年5月10日
    100
  • 创建指定大小并填充特定数据的Golang文件教程

    本文将介绍如何使用Golang创建一个指定大小的文件,并用特定数据填充它。我们将使用 `os` 包提供的函数来创建和截断文件,从而实现快速生成大文件的目的。示例代码展示了如何创建一个10MB的文件,并将其填充为全零数据。掌握这些方法,可以方便地在例如日志系统或磁盘队列等场景中,预先创建测试文件或初始…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000
  • 使用 WebCodecs VideoDecoder 实现精确逐帧回退

    本文档旨在解决在使用 WebCodecs VideoDecoder 进行视频解码时,实现精确逐帧回退的问题。通过比较帧的时间戳与目标帧的时间戳,可以避免渲染中间帧,从而提高用户体验。本文将提供详细的解决方案和示例代码,帮助开发者实现精确的视频帧控制。 在使用 WebCodecs VideoDecod…

    2026年5月10日
    000
  • 如何插入查询结果数据_SQL插入Select查询结果方法

    如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法

    使用INSERT INTO…SELECT语句可高效插入数据,通过NOT EXISTS、LEFT JOIN、MERGE语句或唯一约束避免重复;表结构不一致时可通过别名、类型转换、默认值或计算字段处理;结合存储过程可提升可维护性,支持参数化与动态SQL。 将查询结果数据插入到另一个表中,可以…

    2026年5月10日 用户投稿
    000
  • PHP动态生成表单输入与POST数据获取实践指南

    本教程详细阐述了如何在php中根据动态数据源(如数据库值)生成多个表单输入框,并演示了如何通过post方法准确无误地获取这些动态生成的输入值。文章强调了正确的输入框命名策略,避免了常见的命名误区,并提供了完整的代码示例,确保开发者能够高效处理动态表单数据。 动态生成表单输入 在Web开发中,我们经常…

    2026年5月10日
    000
  • Discord.py 交互按钮超时与持久化解决方案

    本教程旨在解决Discord.py中交互按钮在一段时间后出现“This Interaction Failed”错误的问题。我们将深入探讨视图(View)的超时机制,并提供通过正确设置timeout参数以及利用bot.add_view()方法实现按钮持久化的具体方案,确保您的机器人交互功能稳定可靠,即…

    2026年5月10日
    000
  • Debian Copilot的社区活跃度如何

    debian copilot是codeberg社区维护的ai助手,旨在为debian用户提供服务。尽管搜索结果中没有直接提供关于debian copilot社区支持活跃度的具体数据,但我们可以通过debian社区的整体活跃度和特点来推断其活跃性。 Debian社区的一般情况: Debian拥有详尽的…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信