
在使用`torch.vmap`进行函数向量化时,直接在被向量化的函数内部使用`torch.zeros`创建新的张量并期望其自动获得批处理维度是一个常见挑战。本文将深入探讨这一问题,并提供一种优雅的解决方案:通过结合`clone()`和`torch.concatenate`,可以有效地在`vmap`环境中创建和填充具有正确批处理维度的张量,从而避免手动传递预先创建的批处理张量,实现代码的简洁与高效。
torch.vmap与批处理张量创建的挑战
torch.vmap是PyTorch中一个强大的工具,它允许用户对批量输入高效地应用一个单样本函数,而无需手动编写循环或调整张量维度。然而,当被向量化的函数需要在内部创建新的张量时,一个常见的陷阱是这些新创建的张量并不会自动继承批处理维度。
考虑一个计算多项式伴随矩阵的函数polycompanion。这个函数需要根据输入多项式polynomial的维度创建一个新的零矩阵companion,然后填充其部分内容。
import torchpoly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)def polycompanion(polynomial): # 计算伴随矩阵的维度 deg = polynomial.shape[-1] - 2 # 创建一个 (deg+1, deg+1) 的零矩阵 companion = torch.zeros((deg + 1, deg + 1), dtype=torch.float32) # 填充单位矩阵部分 companion[1:, :-1] = torch.eye(deg, dtype=torch.float32) # 填充最后一列,这部分依赖于输入多项式 companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1] return companion# 尝试使用 vmap 向量化该函数polycompanion_vmap = torch.vmap(polycompanion)# 预期会遇到问题,因为 companion 不是 BatchedTensor# print(polycompanion_vmap(poly_batched))# 上述代码会因 vmap 无法处理非 BatchedTensor 的原地操作而失败
在上述代码中,torch.vmap在执行polycompanion时,polynomial是一个BatchedTensor。然而,companion = torch.zeros((deg + 1, deg + 1))创建的companion张量并不是BatchedTensor。当尝试对companion进行原地修改,特别是当修改操作涉及polynomial(一个BatchedTensor)时,vmap无法正确地跟踪和应用批处理语义,导致运行时错误。
常见的“丑陋”解决方案及其局限性
为了规避这个问题,一种常见的(但不推荐的)做法是预先在vmap外部创建批处理的零张量,并将其作为参数传递给被向量化的函数。
import torchpoly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)def polycompanion_workaround(polynomial, companion_template): # 注意:这里的 deg 需要根据 companion_template 的形状来推断,或者与 polynomial 保持一致 # 为了简化,我们假设 companion_template 已经有正确的形状 deg = companion_template.shape[-1] - 1 # 假设 companion_template 已经是 (deg+1, deg+1) # 在 template 上进行原地操作 companion_template[1:, :-1] = torch.eye(deg, dtype=torch.float32) companion_template[:, -1] = -1. * polynomial[:-1] / polynomial[-1] return companion_templatepolycompanion_vmap_workaround = torch.vmap(polycompanion_workaround)# 预先创建批处理的零张量batch_size = poly_batched.shape[0]companion_dim = poly_batched.shape[-1] - 1 # (deg+1)initial_companion = torch.zeros(batch_size, companion_dim, companion_dim, dtype=torch.float32)# 传递预创建的批处理张量output_workaround = polycompanion_vmap_workaround(poly_batched, initial_companion)print("Workaround Output:")print(output_workaround)
输出:
Workaround Output:tensor([[[ 0.0000, 0.0000, -0.2500], [ 1.0000, 0.0000, -0.5000], [ 0.0000, 1.0000, -0.7500]], [[ 0.0000, 0.0000, -0.2500], [ 1.0000, 0.0000, -0.5000], [ 0.0000, 1.0000, -0.7500]]])
这种方法虽然能工作,但它破坏了函数的封装性,使得函数签名的设计变得复杂,且在函数内部无法动态决定新张量的批处理大小,不够灵活。
优雅的解决方案:clone()与torch.concatenate
解决此问题的关键在于,对于需要批处理的张量,我们必须确保其批处理维度在vmap的上下文中是明确的。如果一个张量的一部分内容依赖于批处理输入,而另一部分是固定的,我们可以将它们分别处理,然后合并。
核心思路是:
创建非批处理的固定部分(例如单位矩阵部分)。创建批处理的动态部分(例如最后一列,它依赖于polynomial)。使用clone()确保非批处理部分可以被独立地操作和复制。使用torch.concatenate将这两部分沿着正确的维度合并,同时利用None来添加缺失的维度以进行匹配。
import torchpoly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=torch.float32)def polycompanion_refined(polynomial): deg = polynomial.shape[-1] - 2 # 1. 创建一个非批处理的零矩阵作为基础 companion_base = torch.zeros((deg + 1, deg + 1), dtype=torch.float32) # 2. 填充单位矩阵部分(这部分是固定的,不依赖于批处理) # 注意:这里我们只填充除了最后一列之外的部分 companion_base[1:, :-1] = torch.eye(deg, dtype=torch.float32) # 3. 计算最后一列,这部分是依赖于 polynomial (BatchedTensor) 的,因此会是 BatchedTensor last_column_batched = -1. * polynomial[:-1] / polynomial[-1] # 4. 准备合并: # - companion_base[:, :-1] 是非批处理的,需要 clone 以便后续操作。 # clone() 确保 vmap 可以对每个批次独立处理这个副本。 # - last_column_batched 是一个一维的 BatchedTensor,形状为 (batch_size, deg+1)。 # 为了与 companion_base[:, :-1] (形状为 (deg+1, deg)) 合并, # 需要将其扩展为 (batch_size, deg+1, 1) 的形状,通过 [:, None] 实现。 _companion = torch.concatenate([ companion_base[:, :-1].clone(), # 克隆非批处理的左侧部分 last_column_batched[:, None] # 批处理的右侧列,添加一个维度使其可合并 ], dim=1) # 沿着列维度合并 return _companionpolycompanion_vmap_refined = torch.vmap(polycompanion_refined)output_refined = polycompanion_vmap_refined(poly_batched)print("nRefined Solution Output:")print(output_refined)
输出:
Refined Solution Output:tensor([[[ 0.0000, 0.0000, -0.2500], [ 1.0000, 0.0000, -0.5000], [ 0.0000, 1.0000, -0.7500]], [[ 0.0000, 0.0000, -0.2500], [ 1.0000, 0.0000, -0.5000], [ 0.0000, 1.0000, -0.7500]]])
注意事项与总结
torch.zeros_like的适用性:如果新张量的形状可以直接从一个批处理输入张量派生,并且所有元素都初始化为零,那么torch.zeros_like(batched_input)可以很好地工作,因为它会创建一个BatchedTensor。然而,在伴随矩阵的例子中,我们需要一个特定形状的零矩阵,其大小与输入张量的最后一个维度相关,但并非完全相同,且后续需要部分填充。因此,zeros_like在此场景下并不直接适用。clone()的重要性:在vmap环境中,当一个张量(如companion_base[:, :-1])不是BatchedTensor但需要与BatchedTensor(如last_column_batched)合并时,对其调用clone()可以有效地为每个批次创建一个独立的副本。这使得vmap能够独立地处理每个批次的合并操作,而不会因为原始张量不是批处理的而产生冲突。维度匹配:torch.concatenate要求所有输入张量在非合并维度上具有相同的形状。在我们的例子中,last_column_batched是一个形状为(batch_size, deg+1)的一维批处理张量。为了与形状为(deg+1, deg)的companion_base[:, :-1].clone()合并,我们需要将last_column_batched的形状调整为(batch_size, deg+1, 1),这通过[:, None]索引实现,它在最后一个维度上添加了一个新的维度。
通过这种clone()和torch.concatenate的组合,我们能够在torch.vmap的上下文中,在函数内部灵活且优雅地创建和填充新的批处理张量,从而保持代码的简洁性和功能性,避免了不必要的外部参数传递。这种模式对于在vmap函数中构建复杂张量结构非常有用。
以上就是在torch.vmap中高效创建与操作批处理张量的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1377852.html
微信扫一扫
支付宝扫一扫