PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

pymc模型中,当使用自定义pytensor op定义对数似然并尝试结合blackjax采样器时,可能遭遇jax转换兼容性错误。本文将深入探讨如何实现自定义对数似然,分析blackjax集成时的挑战,并提供一种通过数学表达式重构来显著提升核心计算函数性能的通用优化策略,即使无法利用jax加速,也能有效缩短采样时间。

1. PyMC中自定义对数似然函数的实现

在贝叶斯建模中,有时标准分布无法满足特定需求,需要引入自定义的对数似然函数。PyMC(基于PyTensor)提供了一种机制,允许用户通过定义自定义的pytensor.Op来集成任意Python函数及其梯度。

1.1 定义自定义PyTensor Op

要将一个复杂的Python函数(例如,涉及外部库或数值求解器)集成到PyMC的计算图中,需要创建两个pytensor.Op类:一个用于计算函数值(对数似然),另一个用于计算其梯度。

LogLikeWithGrad (对数似然函数)

这个Op负责计算给定参数的对数似然值。它需要实现perform方法来执行实际的对数似然计算,并重载grad方法来指定如何计算梯度。

import pytensor.tensor as ptimport numpy as npfrom scipy.optimize import approx_fprime # 用于数值梯度class LogLikeWithGrad(pt.Op):    itypes = [pt.dvector]  # 输入是一个参数向量    otypes = [pt.dscalar]  # 输出是一个标量(对数似然值)    def __init__(self, loglike_function):        self.likelihood = loglike_function        self.loglike_grad_op = LogLikeGrad(loglike_function) # 初始化梯度Op    def perform(self, node, inputs, outputs):        (theta,) = inputs        logl = self.likelihood(theta)        outputs[0][0] = np.array(logl)    def grad(self, inputs, grad_outputs):        (theta,) = inputs        # 调用自定义的梯度Op来计算梯度        grads = self.loglike_grad_op(theta)        return [grad_outputs[0] * grads]

LogLikeGrad (对数似然梯度函数)

这个Op专门用于计算对数似然函数相对于其输入的梯度。在缺乏解析梯度的情况下,可以使用数值近似方法,例如scipy.optimize.approx_fprime。

class LogLikeGrad(pt.Op):    itypes = [pt.dvector]  # 输入是一个参数向量    otypes = [pt.dvector]  # 输出是一个梯度向量    def __init__(self, loglike_function):        self.likelihood = loglike_function    def perform(self, node, inputs, outputs):        (theta,) = inputs        # 使用数值方法近似梯度        grads = approx_fprime(theta, self.likelihood, epsilon=1e-8)        outputs[0][0] = grads

1.2 在PyMC模型中集成自定义似然

一旦定义了自定义的LogLikeWithGrad Op,就可以将其作为pm.Potential添加到PyMC模型中。pm.Potential允许用户在模型中引入任意的对数概率贡献。

import pymc as pm# 假设 applyMCMC 是你的核心对数似然计算函数# 并且 param_names 和 lower/upper_boundaries 已经定义# logl = LogLikeWithGrad(applyMCMC)with pm.Model() as model:    # 定义模型参数    for i, name in enumerate(param_names):        pm.Uniform(name, lower=lower_boundaries[0][i], upper=upper_boundaries[0][i])    # 将所有参数组合成一个PyTensor向量    theta = pt.as_tensor_variable([model[param] for param in param_names])    # 将自定义对数似然作为潜力项添加到模型中    pm.Potential("likelihood", logl(theta))    # 执行采样    # trace = pm.sample(draws=niter, step=pm.NUTS(), tune=500, cores=64, init="jitter+adapt_diag", progressbar=True)

2. Blackjax采样器与PyTensor Op的JAX兼容性挑战

PyMC 5.x 版本支持多种NUTS采样器后端,包括其默认的PyTensor后端以及基于JAX的Blackjax采样器,后者在GPU等加速设备上表现出色。然而,当模型中包含自定义的pytensor.Op时,尝试使用Blackjax采样器可能会遇到兼容性问题。

2.1 NotImplementedError 的根源

当尝试通过 pm.sample(nuts_sampler=”blackjax”) 使用Blackjax时,如果自定义的LogLikeWithGrad Op没有对应的JAX转换实现,PyTensor的JAX后端会抛出 NotImplementedError,错误信息通常为 No JAX conversion for the given Op: LogLikeWithGrad。

这是因为Blackjax采样器依赖于JAX的即时编译(JIT)能力,而JAX只能编译其能够理解的操作。自定义的pytensor.Op本质上是一个Python对象,PyTensor需要一个明确的规则来告诉JAX如何将其转换为JAX可执行的操作。对于标准PyTensor操作,这些转换已经内置,但对于用户自定义的Op,则需要手动提供。

2.2 解决方案方向

解决此问题通常需要以下两种方法之一:

重写Op以完全使用JAX操作:如果自定义对数似然函数的核心逻辑可以用JAX原生的操作(如jax.numpy)表达,那么可以直接在JAX中构建这个似然函数,并将其集成到PyMC模型中。这通常涉及将所有PyTensor代码替换为JAX代码。为自定义Op提供JAX转换规则:这是一种更高级的方法,涉及为自定义pytensor.Op编写一个JAX转换函数,告诉PyTensor的JAX后端如何将该Op转换为JAX的计算图。这通常通过注册JAX调度函数来实现,但实现起来较为复杂,且文档相对较少。

在许多情况下,特别是当自定义似然函数依赖于复杂外部库(如物理模拟器)时,直接将其完全转换为JAX操作可能非常困难或不可能。此时,即使无法利用Blackjax的JAX加速,我们仍然可以通过优化核心计算逻辑来提升采样性能。

3. 提升PyMC模型计算性能的通用策略:数学表达式优化

即使无法直接利用JAX的GPU加速,通过对核心数学计算函数进行细致的优化,也能显著提升PyMC模型的采样速度。这种优化策略侧重于减少冗余计算、避免重复的函数调用以及利用局部变量缓存中间结果。

3.1 识别并消除冗余计算

在复杂的数学表达式中,往往存在重复计算相同子表达式的情况。通过将这些子表达式的结果存储在局部变量中,可以避免多次计算,从而提高效率。

以原始代码中的dH和du函数为例,它们包含大量重复的幂运算和乘法:

(1 + z) 的不同幂次 ((1 + z)**2, (1 + z)**3, (1 + z)**4, (1 + z)**5)math.pi * Rho_mOmega_k * PhiPhi * Phi (即 Phi**2)

3.2 优化示例:dH 和 du 函数

以下是针对 dH 和 du 函数的优化版本,通过引入局部变量来缓存重复计算的中间结果:

import mathimport timeit # 用于性能测试# 假设 Rho_m, Phi, u, omega_BD, Omega_k, z 为示例输入# (为了测试方便,这里使用任意值,实际应是模型参数)Rho_m = -1.0Phi = 0.1u = 3.0omega_BD = 4.0Omega_k = -5.0z = 6.0# 原始 du 函数 (为对比而保留,实际代码中应替换为优化版本)def du_original(Rho_m, Phi, u, omega_BD, Omega_k, z):    return (        24 * math.pi * Rho_m * Phi**3        + (1 + z)        * u        * Phi**2        * (            8 * math.pi * (-3 + omega_BD) * Rho_m            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi        )        - 3        * (1 + z) ** 2        * u**2        * Phi        * (            -4 * math.pi * omega_BD * Rho_m            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi        )        - omega_BD        * u**3        * (            4 * math.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi        )    ) / (        (1 + z) ** 2        * (3 + 2 * omega_BD)        * Phi**2        * (8 * math.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)    )# 优化后的 du 函数def du_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):    # 缓存幂次和重复乘法    Phi_pow2 = Phi * Phi    Phi_pow3 = Phi_pow2 * Phi    one_plus_z = 1 + z    one_plus_z_pow2 = one_plus_z * one_plus_z    one_plus_z_pow3 = one_plus_z_pow2 * one_plus_z    one_plus_z_pow4 = one_plus_z_pow3 * one_plus_z    one_plus_z_pow5 = one_plus_z_pow4 * one_plus_z    # 缓存其他重复子表达式    one_plus_z_pow2_times_3 = 3 * one_plus_z_pow2    pi_times_Rho_m = math.pi * Rho_m    Omega_k_times_Phi = Omega_k * Phi    u_pow2 = u * u    u_pow3 = u_pow2 * u    omg1 = (3 + 2 * omega_BD) # (3 + 2 * omega_BD)    omg = omg1 * Omega_k_times_Phi # (3 + 2 * omega_BD) * Omega_k * Phi    omg2 = omega_BD * pi_times_Rho_m # omega_BD * math.pi * Rho_m    return (        24 * pi_times_Rho_m * Phi_pow3        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3 * omg)        - one_plus_z_pow2_times_3 * u_pow2 * Phi * (-4 * omg2 + one_plus_z_pow4 * omg)        - omega_BD * u_pow3 * (4 * one_plus_z_pow3 * (pi_times_Rho_m + omg2) + one_plus_z_pow5 * omg)    ) / (        one_plus_z_pow2 * omg1 * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3 * Omega_k_times_Phi)    )# 原始 dH 函数 (为对比而保留)def dH_original(Rho_m, Phi, u, omega_BD, Omega_k, z):    val = (-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi    )    if val >= 0:        return -(            (                (1 + z)                * (16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)                * (                    (1 + z) * omega_BD * u**3                    - 2                    * omega_BD                    * u                    * ((1 + z) * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)                    * Phi                    - 6 * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2                )                + (                    6                    * Phi                    * (                        -8 * math.pi * Rho_m                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)                    )                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))                )                / (1 + z)            )            / (                2                * math.sqrt(val)                * (                    (1 + z) ** 2 * omega_BD * u**2                    + 6 * (1 + z) * u * Phi                    - 6 * Phi**2                )                ** 2            )        )    else:        return None# 优化后的 dH 函数def dH_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):    # 缓存常用变量和幂次    Phi_pow2 = Phi * Phi    Phi_pow2_times_6 = Phi_pow2 * 6    Phi_times_6 = Phi * 6    one_plus_z = 1 + z    one_plus_z_pow2 = one_plus_z * one_plus_z    one_plus_z_times_u = one_plus_z * u    pi_times_Rho_m = math.pi * Rho_m    Omega_k_times_Phi = Omega_k * Phi    u_pow2 = u * u    u_pow3 = u_pow2 * u    # 重新计算 duu (如果 duu 在 dH 内部被调用多次,直接内联其计算可进一步优化)    # 此处为简洁起见,仍调用 du_optimized,但注意实际场景可内联    # 或者,如原答案所示,直接将 du_optimized 的计算逻辑复制到此处    # duu 的内联计算部分 (来自 du_optimized)    Phi_pow3_du = Phi_pow2 * Phi    one_plus_z_pow3_du = one_plus_z_pow2 * one_plus_z    one_plus_z_pow4_du = one_plus_z_pow3_du * one_plus_z    one_plus_z_pow5_du = one_plus_z_pow4_du * one_plus_z    one_plus_z_pow2_times_3_du = 3 * one_plus_z_pow2    omg1_du = (3 + 2 * omega_BD)    omg_du = omg1_du * Omega_k_times_Phi    omg2_du = omega_BD * pi_times_Rho_m    duu = (        24 * pi_times_Rho_m * Phi_pow3_du        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3_du * omg_du)        - one_plus_z_pow2_times_3_du * u_pow2 * Phi * (-4 * omg2_du + one_plus_z_pow4_du * omg_du)        - omega_BD * u_pow3 * (4 * one_plus_z_pow3_du * (pi_times_Rho_m + omg2_du) + one_plus_z_pow5_du * omg_du)    ) / (        one_plus_z_pow2 * omg1_du * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3_du * Omega_k_times_Phi)    )    # duu 内联计算结束    val1 = (-16 * pi_times_Rho_m - 6 * one_plus_z_pow2 * Omega_k_times_Phi)    val = val1 / (6 * one_plus_z_times_u + (one_plus_z_pow2 * omega_BD * u_pow2) / Phi - Phi_times_6)    if val >= 0:        Phi_times_2 = Phi + Phi        val2 = (one_plus_z_pow2 * omega_BD * u_pow2 + 6 * (one_plus_z_times_u * Phi - Phi_pow2))        # 优化分子中的复杂项        term1_numerator = one_plus_z_pow2 * val1 * (omega_BD * u * (one_plus_z * u_pow2 - Phi_times_2 * (one_plus_z * duu + u)) - duu * Phi_pow2_times_6)        term2_numerator = Phi_times_6 * (-8 * pi_times_Rho_m + one_plus_z_pow2 * Omega_k * (one_plus_z_times_u + Phi_times_2)) * (Phi_pow2_times_6 - one_plus_z_times_u * (one_plus_z_times_u * omega_BD + Phi_times_6))        return (term1_numerator - term2_numerator) / (2 * one_plus_z * math.sqrt(val) * val2 * val2)    else:        return None# 性能测试t_original = timeit.timeit('dH_original(Rho_m, Phi, u, omega_

以上就是PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1381260.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 22:48:38
下一篇 2025年12月14日 22:48:47

相关推荐

  • 如何解决本地图片在使用 mask JS 库时出现的跨域错误?

    如何跨越localhost使用本地图片? 问题: 在本地使用mask js库时,引入本地图片会报跨域错误。 解决方案: 要解决此问题,需要使用本地服务器启动文件,以http或https协议访问图片,而不是使用file://协议。例如: python -m http.server 8000 然后,可以…

    2025年12月24日
    200
  • 使用 Mask 导入本地图片时,如何解决跨域问题?

    跨域疑难:如何解决 mask 引入本地图片产生的跨域问题? 在使用 mask 导入本地图片时,你可能会遇到令人沮丧的跨域错误。为什么会出现跨域问题呢?让我们深入了解一下: mask 框架假设你以 http(s) 协议加载你的 html 文件,而当使用 file:// 协议打开本地文件时,就会产生跨域…

    2025年12月24日
    200
  • 构建模拟:从头开始的实时交易模拟器

    简介 嘿,开发社区!我很高兴分享我的业余项目 Simul8or – 一个实时日间交易模拟器,旨在为用户提供一个无风险的环境来练习交易策略。该项目 100% 构建在 ASP.NET WebForms、C#、JavaScript、CSS 和 SQL Server 技术堆栈上,没有外部库或框架。从头开始构…

    2025年12月24日
    300
  • 正则表达式在文本验证中的常见问题有哪些?

    正则表达式助力文本输入验证 在文本输入框的验证中,经常遇到需要限定输入内容的情况。例如,输入框只能输入整数,第一位可以为负号。对于不会使用正则表达式的人来说,这可能是个难题。下面我们将提供三种正则表达式,分别满足不同的验证要求。 1. 可选负号,任意数量数字 如果输入框中允许第一位为负号,后面可输入…

    2025年12月24日
    000
  • HTML、CSS 和 JavaScript 项目

    欢迎来到我的 html、css 和 javascript 项目集合!这篇博文全面概述了我创建的各种项目,展示了 web 开发的不同方面。每个项目都可以在自己的存储库中找到,其中包含您需要探索和学习的所有代码。 目录 简介项目概况开始使用贡献作者 介绍 作为一名 web 开发人员,我喜欢从事各种项目,…

    2025年12月24日
    000
  • 为什么多年的经验让我选择全栈而不是平均栈

    在全栈和平均栈开发方面工作了 6 年多,我可以告诉您,虽然这两种方法都是流行且有效的方法,但它们满足不同的需求,并且有自己的优点和缺点。这两个堆栈都可以帮助您创建 Web 应用程序,但它们的实现方式却截然不同。如果您在两者之间难以选择,我希望我在两者之间的经验能给您一些有用的见解。 在这篇文章中,我…

    2025年12月24日
    000
  • 姜戈顺风

    本教程演示如何在新项目中从头开始配置 django 和 tailwindcss。 django 设置 创建一个名为 .venv 的新虚拟环境。 # windows$ python -m venv .venv$ .venvscriptsactivate.ps1(.venv) $# macos/linu…

    2025年12月24日
    000
  • 浏览 CSS 响应式设计

    前端开发人员的一项主要职责是创建响应式设计布局。这也是他们的挑战之一。 您可能和我一样相信,在使用 html/css 和 javascript 进行项目时“是时候开始构建响应式设计了”,或者您可能会发现很难让您的设计响应式。 无论什么情况,让我们开始学习如何导航 css 响应式设计,sailor。 …

    2025年12月24日
    000
  • 花 $o 学习这些编程语言或免费

    → Python → JavaScript → Java → C# → 红宝石 → 斯威夫特 → 科特林 → C++ → PHP → 出发 → R → 打字稿 []https://x.com/e_opore/status/1811567830594388315?t=_j4nncuiy2wfbm7ic…

    2025年12月24日
    000
  • 如何克服响应式布局的不足之处

    如何克服响应式布局的不足之处 随着移动设备的普及和互联网的发展,响应式布局成为了现代网页设计中必不可少的一部分。通过响应式设计,网页可以根据用户所使用的设备自动调整布局,使用户在不同的屏幕尺寸下都能获得良好的浏览体验。 然而,尽管响应式布局在提供多屏幕适应性方面做得相当出色,但仍然存在一些不足之处。…

    2025年12月24日
    000
  • 响应式布局优化移动设备适配的策略与实用技巧

    响应式布局在移动设备上的适配策略与最佳实践 随着移动设备的普及和使用频率的增加,响应式布局逐渐成为网页设计的主流趋势。在移动设备上实现良好的用户体验,需要采用适配策略和最佳实践来确保网页能够在不同尺寸的屏幕上自适应地显示。 一、视口设置为了适应不同尺寸的移动设备屏幕,需要正确设置视口。在网页的头部添…

    2025年12月24日
    000
  • 掌握响应式布局网站的关键要点

    了解响应式布局网站的必备知识 随着移动设备的普及和使用率的增加,人们越来越多地使用手机和平板电脑来浏览网页。为了让网站在不同尺寸的屏幕上都能够有良好的显示效果,响应式布局逐渐成为了现代网页设计的一种重要趋势。本文将介绍响应式布局网站的必备知识,帮助读者更好地了解和运用响应式布局。 一、响应式布局的定…

    2025年12月24日
    200
  • jimdo如何添加html5表单_jimdo表单html5代码嵌入与字段设置【实操】

    可通过嵌入HTML5表单代码、启用字段验证属性、添加CSS样式反馈及替换提交按钮并绑定JS事件四种方式在Jimdo实现自定义表单行为。 如果您在 Jimdo 网站中需要自定义表单行为或字段逻辑,而内置表单编辑器无法满足需求,则可通过嵌入 HTML5 表单代码实现更灵活的控制。以下是具体操作步骤: 一…

    2025年12月23日
    000
  • html5怎么导视频_html5用video标签导出或Canvas转DataURL获视频【导出】

    HTML5无法直接导出video标签内容,需借助Canvas捕获帧并结合MediaRecorder API、FFmpeg.wasm或服务端协同实现。MediaRecorder适用于WebM格式前端录制;FFmpeg.wasm支持MP4等格式及精细编码控制;服务端方案适合高负载场景。 如果您希望在网页…

    2025年12月23日
    300
  • 如何查看编写的html_查看自己编写的HTML文件效果【效果】

    要查看HTML文件的浏览器渲染效果,需确保文件以.html为扩展名保存、用浏览器直接打开、利用开发者工具调试、必要时启用本地HTTP服务器、或使用编辑器实时预览插件。 如果您编写了HTML代码,但无法直观看到其在浏览器中的实际渲染效果,则可能是由于文件未正确保存、未使用浏览器打开或文件扩展名设置错误…

    2025年12月23日
    400
  • html5怎么加php_html5用Ajax与PHP后端交互实现数据传递【交互】

    HTML5不能直接运行PHP,需通过Ajax与PHP通信:前端用fetch发送请求,PHP接收处理并返回JSON,前端解析响应更新DOM;注意跨域、编码、CSRF防护和输入过滤。 HTML5 本身是前端标记语言,不能直接运行 PHP 代码,但可以通过 Ajax(异步 JavaScript)与 PHP…

    2025年12月23日
    300
  • html5 js怎么加_html5用script标签内嵌或外链引入JS代码【添加】

    在HTML5中执行JavaScript需通过script标签:一、内联编写于head或body中;二、外链引入.js文件并建议放body末尾或加defer;三、defer按序执行,async独立执行;四、可动态创建script元素插入执行。 如果您希望在HTML5页面中执行JavaScript代码,…

    2025年12月23日
    000
  • node.js怎么运行html_node.js运行html步骤【指南】

    答案是使用Node.js内置http模块、Express框架或第三方工具serve可快速搭建服务器预览HTML文件。首先通过http模块创建服务器并读取index.html返回响应;其次用Express初始化项目并配置静态文件服务;最后利用serve工具全局安装后一键启动服务器,三种方式均在浏览器访…

    2025年12月23日
    300
  • html5能否插入带表单的文档_html5表单文档嵌入与数据提交【步骤】

    HTML5中无法直接嵌入外部带表单的HTML文档并原生提交;可行方案有四:一、用iframe嵌入,需同源或CORS支持,并用postMessage通信;二、用fetch+DOMParser动态加载表单片段并手动绑定事件;三、在当前页面直接编写表单,最规范且兼容性好;四、用JavaScript+fet…

    2025年12月23日
    000
  • 360怎么装html5_360浏览器默认支持HTML5无需额外安装设置【说明】

    HTML5是网页标准,非独立软件,360浏览器7.0+已原生支持;需确认内核为Blink/Chromium、关闭兼容模式、禁用强制兼容策略、重置Flash插件、清除HTML5本地存储、检查系统Media Foundation组件。 如果您在使用360浏览器时发现HTML5网页功能异常(如视频无法播放…

    2025年12月23日
    000

发表回复

登录后才能评论
关注微信