
本文详细解析 numpy `einsum` 在处理多张量求和时的内部机制。通过逐步分解求和过程和提供等效的显式循环实现,帮助读者理解 `einsum` 如何根据索引字符串高效地执行元素乘法、重排和特定维度上的求和操作,从而掌握其在复杂张量运算中的应用细节。
NumPy 的 einsum 函数提供了一种极其灵活且高效的方式来执行张量运算,包括点积、转置、求和、矩阵乘法等。其核心在于通过一个简洁的字符串表达式来定义输入张量的索引关系以及输出张量的索引顺序。然而,当涉及到多个张量的复杂求和(收缩)操作时,理解其内部元素的组合和求和过程可能会变得有些抽象。本文将深入探讨 np.einsum(‘ijk,jil->kl’, a, b) 这一特定操作的细节,帮助读者透彻理解其背后的机制。
einsum 索引符号解析
首先,我们来解析 np.einsum(‘ijk,jil->kl’, a, b) 中的索引字符串:
ijk: 表示第一个输入张量 a 的维度索引。a 是一个三维张量,其维度顺序为 i、j、k。jil: 表示第二个输入张量 b 的维度索引。b 也是一个三维张量,其维度顺序为 j、i、l。->kl: 表示输出张量的维度索引。输出将是一个二维张量,其维度顺序为 k、l。
理解操作规则:
元素乘法: einsum 会对所有具有相同索引的维度进行“匹配”。例如,a 的第一个维度是 i,b 的第二个维度也是 i;a 的第二个维度是 j,b 的第一个维度也是 j。这意味着在执行元素乘法时,a[i, j, k] 将与 b[j, i, l] 进行匹配并相乘。求和(收缩): 任何出现在输入索引字符串中但未出现在输出索引字符串中的索引,都将被求和(收缩)。在本例中,i 和 j 出现在输入 ijk 和 jil 中,但未出现在输出 kl 中,因此 i 和 j 这两个维度将被求和。输出维度: 出现在输出索引字符串 kl 中的索引 k 和 l 将构成输出张量的维度。
简而言之,np.einsum(‘ijk,jil->kl’, a, b) 的数学表达式等价于:$$ text{output}_{kl} = sum_i sumj text{a}{ijk} cdot text{b}_{jil} $$
案例分析:逐步分解求和过程
为了更直观地理解 einsum 的求和细节,我们可以通过一个技巧来逐步分解它。这个技巧是先执行所有元素的乘法而不进行任何求和,然后手动执行求和步骤。
假设我们有以下两个 NumPy 张量:
import numpy as npa = np.arange(8.).reshape(4, 2, 1)b = np.arange(16.).reshape(2, 4, 2)print("张量 a 的形状:", a.shape) # (4, 2, 1)print("张量 b 的形状:", b.shape) # (2, 4, 2)
步骤一:生成所有未求和的乘积
我们可以通过在输出索引中包含所有输入索引来阻止 einsum 进行求和。对于 ijk,jil->kl,如果我们将输出定义为 ijkl,则 einsum 将返回所有 a[i,j,k] * b[j,i,l] 的乘积,但不会进行任何求和。
# 生成所有元素的乘积,不进行求和intermediate_products = np.einsum('ijk,jil->ijkl', a, b)print("n所有未求和的乘积 (形状: i, j, k, l):")print(intermediate_products)print("形状:", intermediate_products.shape) # (4, 2, 1, 2)
在这个 intermediate_products 张量中,每个元素 [i, j, k, l] 都对应着 a[i, j, k] * b[j, i, l] 的乘积。例如,intermediate_products[0, 0, 0, 0] 对应 a[0, 0, 0] * b[0, 0, 0]。
步骤二:逐步执行求和
现在,我们知道 i 和 j 是需要被求和的维度。在 intermediate_products 张量中,i 对应轴 0,j 对应轴 1。我们可以逐个对这些轴进行求和。
首先,对 j 轴(轴 1)进行求和:
# 对 j 轴 (轴 1) 进行求和sum_over_j = intermediate_products.sum(axis=1)print("n对 j 轴求和后的结果 (形状: i, k, l):")print(sum_over_j)print("形状:", sum_over_j.shape) # (4, 1, 2)
接下来,对 i 轴(轴 0)进行求和:
# 对 i 轴 (轴 0) 进行求和final_result = sum_over_j.sum(axis=0)print("n对 i 轴求和后的最终结果 (形状: k, l):")print(final_result)print("形状:", final_result.shape) # (1, 2)
为了验证,我们可以直接运行原始的 einsum 操作:
original_einsum_result = np.einsum('ijk,jil->kl', a, b)print("n原始 einsum 结果 (形状: k, l):")print(original_einsum_result)print("形状:", original_einsum_result.shape) # (1, 2)# 验证结果是否一致print("n逐步求和结果与原始 einsum 结果是否一致:", np.allclose(final_result, original_einsum_result))
通过这种逐步分解的方式,我们清晰地看到了 einsum 如何先进行元素乘法,然后对指定维度进行求和,最终得到结果。
案例分析:显式循环实现
另一种理解 einsum 细节的方式是将其转换为等效的显式循环。这有助于我们从最基本的元素层面观察操作。
def sum_array_explicit_loop(A, B): # 获取张量 A 的形状 (i_len, j_len, k_len) i_len_a, j_len_a, k_len_a = A.shape # 获取张量 B 的形状,这里我们只关心与输出相关的维度 (j_len, i_len, l_len) # 实际上,B 的形状是 (j_len_b, i_len_b, l_len_b) # 为了匹配 einsum 的索引,B 的实际形状是 (j_len_from_B, i_len_from_B, l_len_from_B) # 我们需要确保 A 和 B 的匹配维度长度一致 j_len_b, i_len_b, l_len_b = B.shape # 检查维度兼容性(einsum 会自动处理) if not (j_len_a == j_len_b and i_len_a == i_len_b): raise ValueError("张量维度不兼容") # 初始化结果张量,其形状为 (k_len, l_len) ret = np.zeros((k_len_a, l_len_b)) # 遍历所有可能的 i, j, k, l 组合 # i 和 j 是将被求和的维度 # k 和 l 是输出张量的维度 for i in range(i_len_a): # 遍历 A 的第一个维度 (i) for j in range(j_len_a): # 遍历 A 的第二个维度 (j) for k in range(k_len_a): # 遍历 A 的第三个维度 (k) for l in range(l_len_b): # 遍历 B 的第三个维度 (l) # 执行元素乘法并累加到 ret[k, l] # 注意 B 的索引是 j, i, l,与 einsum 字符串 'jil' 对应 ret[k, l] += A[i, j, k] * B[j, i, l] return ret# 使用显式循环计算结果explicit_loop_result = sum_array_explicit_loop(a, b)print("n显式循环计算结果:")print(explicit_loop_result)# 验证结果是否与原始 einsum 一致print("显式循环结果与原始 einsum 结果是否一致:", np.allclose(explicit_loop_result, original_einsum_result))
通过这个显式循环,我们可以清晰地看到:
外层循环 for i in range(i_len_a) 和 for j in range(j_len_a) 对应了 i 和 j 这两个被求和的维度。内层循环 for k in range(k_len_a) 和 for l in range(l_len_b) 对应了输出张量的维度。核心操作 ret[k, l] += A[i, j, k] * B[j, i, l] 直接反映了 einsum 字符串 ijk,jil->kl 的含义:A 以 i,j,k 索引,B 以 j,i,l 索引,它们的乘积被累加到以 k,l 索引的结果张量中。当 i 和 j 的循环完成时,所有对应的乘积都已被累加到 ret[k, l] 中,从而实现了对 i 和 j 的求和。
总结与注意事项
einsum 的强大与简洁: einsum 通过其索引字符串提供了一种声明式的方式来描述复杂的张量操作,极大地简化了代码并提高了可读性。性能优势: 尽管显式循环有助于理解,但在实际应用中,NumPy 的 einsum 函数通常会利用底层的 C/Fortran 优化,比纯 Python 循环快得多。索引是核心: 理解 einsum 的关键在于掌握其索引规则:重复索引: 在输入字符串中重复但不在输出字符串中的索引表示求和(收缩)维度。非重复索引: 在输入字符串中不重复或在输出字符串中出现的索引表示输出维度。顺序: 输出字符串中索引的顺序决定了输出张量的维度顺序。多功能性: einsum 不仅可以处理复杂的求和,还可以用于实现转置 (‘ij->ji’)、点积 (‘i,i->’)、矩阵乘法 (‘ij,jk->ik’)、元素乘法 (‘ij,ij->ij’) 等多种张量操作。
通过本文的详细解析,相信读者对 np.einsum 在处理多张量求和时的内部工作机制有了更深入的理解。掌握 einsum 将使您能够更高效、更灵活地处理各种张量计算任务。
以上就是深入理解 NumPy einsum:多张量求和与索引机制详解的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1378118.html
微信扫一扫
支付宝扫一扫