PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

pymc模型中,当使用自定义pytensor op定义对数似然并尝试结合blackjax采样器时,可能遭遇jax转换兼容性错误。本文将深入探讨如何实现自定义对数似然,分析blackjax集成时的挑战,并提供一种通过数学表达式重构来显著提升核心计算函数性能的通用优化策略,即使无法利用jax加速,也能有效缩短采样时间。

1. PyMC中自定义对数似然函数的实现

在贝叶斯建模中,有时标准分布无法满足特定需求,需要引入自定义的对数似然函数。PyMC(基于PyTensor)提供了一种机制,允许用户通过定义自定义的pytensor.Op来集成任意Python函数及其梯度。

1.1 定义自定义PyTensor Op

要将一个复杂的Python函数(例如,涉及外部库或数值求解器)集成到PyMC的计算图中,需要创建两个pytensor.Op类:一个用于计算函数值(对数似然),另一个用于计算其梯度。

LogLikeWithGrad (对数似然函数)

这个Op负责计算给定参数的对数似然值。它需要实现perform方法来执行实际的对数似然计算,并重载grad方法来指定如何计算梯度。

import pytensor.tensor as ptimport numpy as npfrom scipy.optimize import approx_fprime # 用于数值梯度class LogLikeWithGrad(pt.Op):    itypes = [pt.dvector]  # 输入是一个参数向量    otypes = [pt.dscalar]  # 输出是一个标量(对数似然值)    def __init__(self, loglike_function):        self.likelihood = loglike_function        self.loglike_grad_op = LogLikeGrad(loglike_function) # 初始化梯度Op    def perform(self, node, inputs, outputs):        (theta,) = inputs        logl = self.likelihood(theta)        outputs[0][0] = np.array(logl)    def grad(self, inputs, grad_outputs):        (theta,) = inputs        # 调用自定义的梯度Op来计算梯度        grads = self.loglike_grad_op(theta)        return [grad_outputs[0] * grads]

LogLikeGrad (对数似然梯度函数)

这个Op专门用于计算对数似然函数相对于其输入的梯度。在缺乏解析梯度的情况下,可以使用数值近似方法,例如scipy.optimize.approx_fprime。

class LogLikeGrad(pt.Op):    itypes = [pt.dvector]  # 输入是一个参数向量    otypes = [pt.dvector]  # 输出是一个梯度向量    def __init__(self, loglike_function):        self.likelihood = loglike_function    def perform(self, node, inputs, outputs):        (theta,) = inputs        # 使用数值方法近似梯度        grads = approx_fprime(theta, self.likelihood, epsilon=1e-8)        outputs[0][0] = grads

1.2 在PyMC模型中集成自定义似然

一旦定义了自定义的LogLikeWithGrad Op,就可以将其作为pm.Potential添加到PyMC模型中。pm.Potential允许用户在模型中引入任意的对数概率贡献。

import pymc as pm# 假设 applyMCMC 是你的核心对数似然计算函数# 并且 param_names 和 lower/upper_boundaries 已经定义# logl = LogLikeWithGrad(applyMCMC)with pm.Model() as model:    # 定义模型参数    for i, name in enumerate(param_names):        pm.Uniform(name, lower=lower_boundaries[0][i], upper=upper_boundaries[0][i])    # 将所有参数组合成一个PyTensor向量    theta = pt.as_tensor_variable([model[param] for param in param_names])    # 将自定义对数似然作为潜力项添加到模型中    pm.Potential("likelihood", logl(theta))    # 执行采样    # trace = pm.sample(draws=niter, step=pm.NUTS(), tune=500, cores=64, init="jitter+adapt_diag", progressbar=True)

2. Blackjax采样器与PyTensor Op的JAX兼容性挑战

PyMC 5.x 版本支持多种NUTS采样器后端,包括其默认的PyTensor后端以及基于JAX的Blackjax采样器,后者在GPU等加速设备上表现出色。然而,当模型中包含自定义的pytensor.Op时,尝试使用Blackjax采样器可能会遇到兼容性问题。

2.1 NotImplementedError 的根源

当尝试通过 pm.sample(nuts_sampler=”blackjax”) 使用Blackjax时,如果自定义的LogLikeWithGrad Op没有对应的JAX转换实现,PyTensor的JAX后端会抛出 NotImplementedError,错误信息通常为 No JAX conversion for the given Op: LogLikeWithGrad。

这是因为Blackjax采样器依赖于JAX的即时编译(JIT)能力,而JAX只能编译其能够理解的操作。自定义的pytensor.Op本质上是一个Python对象,PyTensor需要一个明确的规则来告诉JAX如何将其转换为JAX可执行的操作。对于标准PyTensor操作,这些转换已经内置,但对于用户自定义的Op,则需要手动提供。

2.2 解决方案方向

解决此问题通常需要以下两种方法之一:

重写Op以完全使用JAX操作:如果自定义对数似然函数的核心逻辑可以用JAX原生的操作(如jax.numpy)表达,那么可以直接在JAX中构建这个似然函数,并将其集成到PyMC模型中。这通常涉及将所有PyTensor代码替换为JAX代码。为自定义Op提供JAX转换规则:这是一种更高级的方法,涉及为自定义pytensor.Op编写一个JAX转换函数,告诉PyTensor的JAX后端如何将该Op转换为JAX的计算图。这通常通过注册JAX调度函数来实现,但实现起来较为复杂,且文档相对较少。

在许多情况下,特别是当自定义似然函数依赖于复杂外部库(如物理模拟器)时,直接将其完全转换为JAX操作可能非常困难或不可能。此时,即使无法利用Blackjax的JAX加速,我们仍然可以通过优化核心计算逻辑来提升采样性能。

3. 提升PyMC模型计算性能的通用策略:数学表达式优化

即使无法直接利用JAX的GPU加速,通过对核心数学计算函数进行细致的优化,也能显著提升PyMC模型的采样速度。这种优化策略侧重于减少冗余计算、避免重复的函数调用以及利用局部变量缓存中间结果。

3.1 识别并消除冗余计算

在复杂的数学表达式中,往往存在重复计算相同子表达式的情况。通过将这些子表达式的结果存储在局部变量中,可以避免多次计算,从而提高效率。

以原始代码中的dH和du函数为例,它们包含大量重复的幂运算和乘法:

(1 + z) 的不同幂次 ((1 + z)**2, (1 + z)**3, (1 + z)**4, (1 + z)**5)math.pi * Rho_mOmega_k * PhiPhi * Phi (即 Phi**2)

3.2 优化示例:dH 和 du 函数

以下是针对 dH 和 du 函数的优化版本,通过引入局部变量来缓存重复计算的中间结果:

import mathimport timeit # 用于性能测试# 假设 Rho_m, Phi, u, omega_BD, Omega_k, z 为示例输入# (为了测试方便,这里使用任意值,实际应是模型参数)Rho_m = -1.0Phi = 0.1u = 3.0omega_BD = 4.0Omega_k = -5.0z = 6.0# 原始 du 函数 (为对比而保留,实际代码中应替换为优化版本)def du_original(Rho_m, Phi, u, omega_BD, Omega_k, z):    return (        24 * math.pi * Rho_m * Phi**3        + (1 + z)        * u        * Phi**2        * (            8 * math.pi * (-3 + omega_BD) * Rho_m            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi        )        - 3        * (1 + z) ** 2        * u**2        * Phi        * (            -4 * math.pi * omega_BD * Rho_m            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi        )        - omega_BD        * u**3        * (            4 * math.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi        )    ) / (        (1 + z) ** 2        * (3 + 2 * omega_BD)        * Phi**2        * (8 * math.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)    )# 优化后的 du 函数def du_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):    # 缓存幂次和重复乘法    Phi_pow2 = Phi * Phi    Phi_pow3 = Phi_pow2 * Phi    one_plus_z = 1 + z    one_plus_z_pow2 = one_plus_z * one_plus_z    one_plus_z_pow3 = one_plus_z_pow2 * one_plus_z    one_plus_z_pow4 = one_plus_z_pow3 * one_plus_z    one_plus_z_pow5 = one_plus_z_pow4 * one_plus_z    # 缓存其他重复子表达式    one_plus_z_pow2_times_3 = 3 * one_plus_z_pow2    pi_times_Rho_m = math.pi * Rho_m    Omega_k_times_Phi = Omega_k * Phi    u_pow2 = u * u    u_pow3 = u_pow2 * u    omg1 = (3 + 2 * omega_BD) # (3 + 2 * omega_BD)    omg = omg1 * Omega_k_times_Phi # (3 + 2 * omega_BD) * Omega_k * Phi    omg2 = omega_BD * pi_times_Rho_m # omega_BD * math.pi * Rho_m    return (        24 * pi_times_Rho_m * Phi_pow3        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3 * omg)        - one_plus_z_pow2_times_3 * u_pow2 * Phi * (-4 * omg2 + one_plus_z_pow4 * omg)        - omega_BD * u_pow3 * (4 * one_plus_z_pow3 * (pi_times_Rho_m + omg2) + one_plus_z_pow5 * omg)    ) / (        one_plus_z_pow2 * omg1 * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3 * Omega_k_times_Phi)    )# 原始 dH 函数 (为对比而保留)def dH_original(Rho_m, Phi, u, omega_BD, Omega_k, z):    val = (-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi    )    if val >= 0:        return -(            (                (1 + z)                * (16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)                * (                    (1 + z) * omega_BD * u**3                    - 2                    * omega_BD                    * u                    * ((1 + z) * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)                    * Phi                    - 6 * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2                )                + (                    6                    * Phi                    * (                        -8 * math.pi * Rho_m                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)                    )                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))                )                / (1 + z)            )            / (                2                * math.sqrt(val)                * (                    (1 + z) ** 2 * omega_BD * u**2                    + 6 * (1 + z) * u * Phi                    - 6 * Phi**2                )                ** 2            )        )    else:        return None# 优化后的 dH 函数def dH_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):    # 缓存常用变量和幂次    Phi_pow2 = Phi * Phi    Phi_pow2_times_6 = Phi_pow2 * 6    Phi_times_6 = Phi * 6    one_plus_z = 1 + z    one_plus_z_pow2 = one_plus_z * one_plus_z    one_plus_z_times_u = one_plus_z * u    pi_times_Rho_m = math.pi * Rho_m    Omega_k_times_Phi = Omega_k * Phi    u_pow2 = u * u    u_pow3 = u_pow2 * u    # 重新计算 duu (如果 duu 在 dH 内部被调用多次,直接内联其计算可进一步优化)    # 此处为简洁起见,仍调用 du_optimized,但注意实际场景可内联    # 或者,如原答案所示,直接将 du_optimized 的计算逻辑复制到此处    # duu 的内联计算部分 (来自 du_optimized)    Phi_pow3_du = Phi_pow2 * Phi    one_plus_z_pow3_du = one_plus_z_pow2 * one_plus_z    one_plus_z_pow4_du = one_plus_z_pow3_du * one_plus_z    one_plus_z_pow5_du = one_plus_z_pow4_du * one_plus_z    one_plus_z_pow2_times_3_du = 3 * one_plus_z_pow2    omg1_du = (3 + 2 * omega_BD)    omg_du = omg1_du * Omega_k_times_Phi    omg2_du = omega_BD * pi_times_Rho_m    duu = (        24 * pi_times_Rho_m * Phi_pow3_du        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3_du * omg_du)        - one_plus_z_pow2_times_3_du * u_pow2 * Phi * (-4 * omg2_du + one_plus_z_pow4_du * omg_du)        - omega_BD * u_pow3 * (4 * one_plus_z_pow3_du * (pi_times_Rho_m + omg2_du) + one_plus_z_pow5_du * omg_du)    ) / (        one_plus_z_pow2 * omg1_du * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3_du * Omega_k_times_Phi)    )    # duu 内联计算结束    val1 = (-16 * pi_times_Rho_m - 6 * one_plus_z_pow2 * Omega_k_times_Phi)    val = val1 / (6 * one_plus_z_times_u + (one_plus_z_pow2 * omega_BD * u_pow2) / Phi - Phi_times_6)    if val >= 0:        Phi_times_2 = Phi + Phi        val2 = (one_plus_z_pow2 * omega_BD * u_pow2 + 6 * (one_plus_z_times_u * Phi - Phi_pow2))        # 优化分子中的复杂项        term1_numerator = one_plus_z_pow2 * val1 * (omega_BD * u * (one_plus_z * u_pow2 - Phi_times_2 * (one_plus_z * duu + u)) - duu * Phi_pow2_times_6)        term2_numerator = Phi_times_6 * (-8 * pi_times_Rho_m + one_plus_z_pow2 * Omega_k * (one_plus_z_times_u + Phi_times_2)) * (Phi_pow2_times_6 - one_plus_z_times_u * (one_plus_z_times_u * omega_BD + Phi_times_6))        return (term1_numerator - term2_numerator) / (2 * one_plus_z * math.sqrt(val) * val2 * val2)    else:        return None# 性能测试t_original = timeit.timeit('dH_original(Rho_m, Phi, u, omega_

以上就是PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1381260.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
解决Django 404错误:深入理解URL配置与调试
上一篇 2025年12月14日 22:48:38
使用 Polars LazyFrame 进行列级乘法
下一篇 2025年12月14日 22:48:47

相关推荐

  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    900
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    300
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    300
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 深入理解 Express.js 中 next() 参数的作用与中间件机制

    本文深入探讨 express.js 中间件函数中的 `next()` 参数。它负责将控制权传递给请求-响应周期中的下一个中间件或路由处理程序。文章将详细解释 `next()` 的工作原理、中间件的注册与执行顺序,以及不正确使用 `next()` 可能导致请求挂起的风险,并通过代码示例和实际应用场景,…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000
  • PHP动态生成表单输入与POST数据获取实践指南

    本教程详细阐述了如何在php中根据动态数据源(如数据库值)生成多个表单输入框,并演示了如何通过post方法准确无误地获取这些动态生成的输入值。文章强调了正确的输入框命名策略,避免了常见的命名误区,并提供了完整的代码示例,确保开发者能够高效处理动态表单数据。 动态生成表单输入 在Web开发中,我们经常…

    2026年5月10日
    000
  • Python递归函数追踪与性能考量:以序列打印为例

    本文深入探讨了Python中一种递归打印序列元素的方法,并着重演示了如何通过引入缩进参数来有效追踪递归函数的执行流程和参数变化。通过实际代码示例,文章揭示了递归调用可能带来的潜在性能开销,特别是对调用栈空间的需求,以及Python默认递归深度限制可能导致的错误,为读者提供了理解和优化递归算法的实用见…

    2026年5月10日
    000
  • python中zip函数详解 python多序列压缩zip函数应用场景

    zip函数的应用场景包括:1) 同时遍历多个序列,2) 合并多个列表的数据,3) 数据分析和科学计算中的元素运算,4) 处理csv文件,5) 性能优化。zip函数是一个强大的工具,能够简化代码并提高处理多个序列时的效率。 在Python中,zip函数是一个非常有用的工具,它能够将多个可迭代对象打包成…

    2026年5月10日
    300
  • Python中怎样使用pymongo?

    在python中使用pymongo可以轻松地与mongodb数据库进行交互。1)安装pymongo:pip install pymongo。2)连接到mongodb:from pymongo import mongoclient; client = mongoclient(‘mongod…

    2026年5月10日
    000
  • Golang空接口如何应用在项目中

    空接口可用于接收任意类型值,常见于日志函数、通用数据结构、JSON动态解析及配置驱动逻辑,提升代码灵活性,但需配合类型断言确保安全,避免滥用以降低维护成本。 空接口 interface{} 在 Go 语言中是一个非常灵活的类型,它可以存储任何类型的值。虽然它牺牲了一部分类型安全,但在实际项目中合理使…

    2026年5月10日
    100
  • JavaScript计算器开发:解决数值显示与初始化问题

    本教程深入探讨了使用JavaScript构建计算器时常见的数值显示异常问题,特别是由于类属性未初始化导致的`Cannot read properties of undefined`错误。我们将详细分析问题根源,并通过在构造函数中调用初始化方法来解决该问题,同时优化显示逻辑,确保计算器功能稳定且界面显…

    2026年5月10日
    000
  • Python 函数参数类型:如何使用可变参数和动态参数?

    python 中的参数类型:关键词参数、可变参数和动态参数 在 python 中,函数的参数可以分为以下几种类型: 关键词参数(kw)**:这些参数具有名称,并且在调用函数时明确指定。可变参数(*args):这些参数没有名称,允许函数接受任意数量的位置参数。它们将被收集到一个元组中。动态参数(kwa…

    2026年5月10日
    000
  • Circle为何在凌晨向Solana新增铸造5亿枚USDC?USDC增发原因与对SOL生态影响深度解析

    近日,链上数据显示,Circle 在凌晨向 Solana 链新增铸造了 5亿枚USDC。此次大规模增发引起市场关注,投资者需要了解背后的原因以及对 Solana 生态的潜在影响。 USDC增发原因分析 增发 USDC 的主要原因可能包括: 满足市场需求:近期 Solana 上交易活动活跃,USDC …

    2026年5月10日
    000
  • JavaScript 高效判断页面所有复选框状态的技巧与实践

    本文旨在提供一套高效且专业的javascript方法,用于判断网页中所有复选框的选中状态。我们将探讨如何利用`array.some()`快速确定是否有未选中的复选框(进而判断是否全部选中),以及如何使用`array.filter()`统计选中和未选中的复选框数量。通过优化dom元素选择和数组操作,提…

    2026年5月10日
    100
  • pycharm解析器怎么添加 解析器添加详细流程

    在pycharm中添加解析器的步骤包括:1) 打开pycharm并进入设置,2) 选择project interpreter,3) 点击齿轮图标并选择add,4) 选择解析器类型并配置路径,5) 点击ok完成添加。添加解析器后,选择合适的类型和版本,配置环境变量,并利用解析器的功能提高开发效率。 在…

    2026年5月10日
    100
  • python中numpy的用法

    NumPy是Python中用于科学计算的强大库,它提供了以下功能:多维数组处理矩阵运算快速傅里叶变换(FFT)线性代数随机数生成 NumPy在Python中的强大功能 NumPy是Python中用于科学计算的一个强大且灵活的库。它提供了用于处理多维数组和矩阵的一组高效工具,是数据分析和机器学习项目的…

    2026年5月10日
    100
  • 从 JavaScript 获取 URL 并在 PHP DataGrid 中使用

    本文档旨在指导开发者如何从 JavaScript 函数中获取 URL,并将其动态应用于 PHP DataGrid。通过前端 JavaScript 动态生成 API 地址,并将其传递给后端的 PHP DataGrid,实现数据根据用户会话动态加载。 动态配置 DataGrid 的 URL 在构建动态 …

    2026年5月10日
    100
  • python如何捕获所有类型的异常_python try except捕获所有异常的方法

    答案:捕获所有异常推荐使用except Exception as e,可捕获常规错误并记录日志,避免影响程序正常退出;需拦截系统信号时才用except BaseException as e。 在Python中,要捕获所有类型的异常,最常见且推荐的方法是使用 except Exception as e…

    2026年5月10日
    300

发表回复

登录后才能评论
关注微信