
本文旨在解释为什么在 Numba 编译的函数中添加 break 语句有时会导致性能显著下降,并提供一种通过分块处理数据来避免此问题的方法。文章将深入探讨 LLVM 编译器在代码向量化方面的限制,并提供实际代码示例和性能测试结果,帮助读者理解并解决类似问题。
在 Numba 中,性能优化很大程度上依赖于 LLVM 编译器将 Python 代码转换为高效的机器码。然而,某些代码模式可能会阻止 LLVM 进行有效的向量化,从而导致性能下降。一个典型的例子是在循环中使用 break 语句。
考虑以下两个 Numba 函数,它们的功能相似,但一个包含 break 语句:
import numbaimport numpy as npfrom timeit import timeit@numba.njitdef count_in_range(arr, min_value, max_value): count = 0 for a in arr: if min_value < a < max_value: count += 1 return count@numba.njitdef count_in_range2(arr, min_value, max_value): count = 0 for a in arr: if min_value < a < max_value: count += 1 break # <---- break here return countrng = np.random.default_rng(0)arr = rng.random(10 * 1000 * 1000)# To compare on even conditions, choose the condition that does not terminate early.min_value = 0.5max_value = min_value - 1e-10assert not np.any(np.logical_and(min_value <= arr, arr <= max_value))n = 100for f in (count_in_range, count_in_range2): f(arr, min_value, max_value) elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n print(f"{f.__name__}: {elapsed * 1000:.3f} ms")
这段代码中,count_in_range 函数统计数组 arr 中位于 min_value 和 max_value 之间的元素的数量。count_in_range2 函数的功能类似,但它在找到第一个满足条件的元素后会立即跳出循环。令人惊讶的是,count_in_range2 函数的性能通常比 count_in_range 函数差得多。
原因分析:LLVM 向量化失败
Numba 使用 LLVM 编译器工具链将 Python 代码编译为本地代码。LLVM 会尝试自动向量化循环,即使用 SIMD (Single Instruction, Multiple Data) 指令并行处理多个数据元素。然而,当循环中存在 break 语句时,LLVM 通常无法进行有效的向量化。
为了更深入地了解这一点,我们可以使用 Clang (一个基于 LLVM 的 C++ 编译器) 来编译等效的 C++ 代码。以下是 count_in_range 函数的 C++ 版本:
#include #include #include int64_t count_in_range(const std::vector& arr, double min_value, double max_value){ int64_t count = 0; for(int64_t i=0 ; i<arr.size() ; ++i) { double a = arr[i]; if (min_value < a && a < max_value) { count += 1; } } return count;}
使用 Clang 编译此代码会生成使用 SIMD 指令的汇编代码,表明循环已成功向量化。但是,如果在 C++ 代码中添加 break 语句,则生成的汇编代码将不再使用 SIMD 指令,导致性能下降。
解决方案:分块处理
为了解决这个问题,我们可以将数组分成小块,并对每个块进行处理。这样,LLVM 仍然可以向量化块内的循环,并且我们仍然可以在找到第一个满足条件的元素后提前退出。
以下是修改后的 Numba 函数,它使用分块处理:
@numba.njitdef count_in_range_faster(arr, min_value, max_value): count = 0 for i in range(0, arr.size, 16): if arr.size - i >= 16: # Optimized SIMD-friendly computation of 1 chunk of size 16 tmp_view = arr[i:i+16] for j in range(0, 16): if min_value < tmp_view[j] 0: return 1 else: # Fallback implementation (variable-sized chunk) for j in range(i, arr.size): if min_value < arr[j] 0: return 1 return 0
在这个版本中,我们将数组分成大小为 16 的块。对于每个块,我们迭代其元素并检查它们是否满足条件。如果在任何块中找到满足条件的元素,我们立即返回。
性能测试
在配备 Xeon W-2255 CPU 的机器上使用 Numba 0.56.0 进行了性能测试,结果如下:
count_in_range: 7.112 mscount_in_range2: 35.317 mscount_in_range_faster: 5.827 ms
结果表明,count_in_range_faster 函数的性能明显优于 count_in_range2 函数,甚至略优于原始的 count_in_range 函数。
总结
在 Numba 函数中添加 break 语句可能会阻止 LLVM 进行有效的向量化,导致性能下降。一种解决方案是将数据分成小块并对每个块进行处理。这样,LLVM 仍然可以向量化块内的循环,并且我们仍然可以在找到第一个满足条件的元素后提前退出。在实际应用中,应该根据具体情况选择合适的块大小,以获得最佳性能。
以上就是Numba 函数中添加 break 语句导致性能显著下降的原因及解决方案的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1376679.html
微信扫一扫
支付宝扫一扫