『零基础+1』一文看懂LSTM原理-《动手学深度学习》

长短期记忆网络(LSTM)为解决隐变量模型的长期信息保存与短期输入缺失问题而设计,含记忆元及输入门、遗忘门、输出门三个门控机制,通过特定计算控制信息留存更新。文中介绍其数学原理、从零开始及简洁实现,提及变体(如带猫眼连接)、与GRU的区别,并展示了训练和预测示例。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

『零基础+1』一文看懂lstm原理-《动手学深度学习》 - 创想鸟

1 长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) Hochreiter.Schmidhuber.1997。

它有许多与门控循环单元(9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

1.1 门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。

长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。

有些文献认为记忆元是隐状态的一种特殊类型,

它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。

为了控制记忆元,我们需要许多门。

其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。

另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。

我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,

这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

注:

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

Sigmoid 层的输出值在 0 到 1 间,表示每个部分所通过的信息。0 表示「对所有信息关上大门」;1 表示「我家大门常打开」。

一个 LSTM 有三个这样的门,控制 cell 的状态。

门实质上是控制有百分之多少的信息保留下来。门操作由一个 sigmoid 网络层计算得到【0,1】的小数与输入数据流按位乘操作构成。

门的操作是相同的,只是根据不同的设计思想,不同的数据流,叫不同的名字

1.2 输入门、忘记门和输出门

就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,

它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。因此,这三个门的值都在(0,1)(0,1)的范围内。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟 『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        


首先,LSTM 的第一步需要决定我们需要从 cell 中抛弃哪些信息。这个决定是从 sigmoid 中的「遗忘层」来实现的。

它的输入是 ht-1 和 xt,输出为一个 0 到 1 之间的数。Ct−1 就是每个在 cell 中所有在 0 和 1 之间的数值,就像我们刚刚所说的,0 代表全抛弃,1 代表全保留。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

下一步,我们需要决定什么样的信息应该被存储起来。这个过程主要分两步。

首先是 sigmoid 层(输入门)决定我们需要更新哪些值;

随后,tanh 层生成了一个新的候选向量 C`,它能够加入状态中。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

接下来,我们就可以更新 cell 的状态了。

将旧状态与 ft 相乘,忘记此前我们想要忘记的内容,然后加上 C`。此时遗忘门为ftft

得到的结果便是新的候选值,依照itit进行缩放。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

最后,我们需要决定要输出什么。此输出将基于我们处理后的单元状态。

首先,我们会运行一个 sigmoid 层决定 cell 状态输出哪一部分。

随后,我们把 cell 状态通过 tanh 函数,将输出值保持在-1 到 1 间。

之后,我们再乘以 sigmoid 门的输出值,就可以得到结果了。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        


我们来细化一下长短期记忆网络的数学表达。

假设有hh个隐藏单元,批量大小为nn,输入数为dd。

因此,输入为Xt∈Rn×dXt∈Rn×d,

前一时间步的隐状态为Ht−1∈Rn×hHt−1∈Rn×h。

相应地,时间步tt的门被定义如下:

输入门是It∈Rn×hIt∈Rn×h,

遗忘门是Ft∈Rn×hFt∈Rn×h,

输出门是Ot∈Rn×hOt∈Rn×h。

它们的计算方法如下:

It=σ(XtWxi+Ht−1Whi+bi)It=σ(XtWxi+Ht−1Whi+bi)

Ft=σ(XtWxf+Ht−1Whf+bf),Ft=σ(XtWxf+Ht−1Whf+bf),

Ot=σ(XtWxo+Ht−1Who+bo)Ot=σ(XtWxo+Ht−1Who+bo)

其中Wxi,Wxf,Wxo∈Rd×hWxi,Wxf,Wxo∈Rd×h

和Whi,Whf,Who∈Rh×hWhi,Whf,Who∈Rh×h是权重参数,

bi,bf,bo∈R1×hbi,bf,bo∈R1×h是偏置参数。


我们将其中的一些操作集合命名为不同的记忆元名称

1.3 候选记忆元

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C~t∈Rn×hC~t∈Rn×h。 它的计算与上面描述的三个门的计算类似, 但是使用tanh⁡tanh函数作为激活函数,函数的值范围为(−1,1)(−1,1)。 下面导出在时间步tt处的方程:

C~t=tanh(XtWxc+Ht−1Whc+bc),C~t=tanh(XtWxc+Ht−1Whc+bc),

其中Wxc∈Rd×hWxc∈Rd×h和 Whc∈Rh×hWhc∈Rh×h是权重参数, bc∈R1×hbc∈R1×h是偏置参数。


1.4 记忆元

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。 类似地,在长短期记忆网络中,也有两个门用于这样的目的: 输入门ItIt控制采用多少来自C~tC~t的新数据, 而遗忘门FtFt控制保留多少过去的 记忆元Ct−1∈Rn×hCt−1∈Rn×h的内容。 使用按元素乘法,得出:

Ct=Ft⊙Ct−1+It⊙C~t.Ct=Ft⊙Ct−1+It⊙C~t.

如果遗忘门始终为11且输入门始终为00, 则过去的记忆元Ct−1Ct−1 将随时间被保存并传递到当前时间步。 引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。


1.5 隐状态

最后,我们需要定义如何计算隐状态 Ht∈Rn×hHt∈Rn×h, 这就是输出门发挥作用的地方。 在长短期记忆网络中,它仅仅是记忆元的tanh⁡tanh的门控版本。 这就确保了HtHt的值始终在区间(−1,1)(−1,1)内:

Ht=Ot⊙tanh⁡(Ct).          (9.2.4)Ht=Ot⊙tanh(Ct).          (9.2.4)

只要输出门接近11,我们就能够有效地将所有记忆信息传递给预测部分, 而对于输出门接近00,我们只保留记忆元内的所有信息,而不需要更新隐状态。

2 从零开始实现

现在,我们从零开始实现长短期记忆网络。

我们首先加载时光机器数据集。

In [1]

import paddlefrom paddle import nnfrom d2l import paddle as d2limport paddle.nn.functional as Functionbatch_size, num_steps = 32, 35train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

   

2.1 初始化模型参数

接下来,我们需要定义和初始化模型参数。

如前所述,超参数num_hiddens定义隐藏单元的数量。

我们按照标准差0.010.01的高斯分布初始化权重,并将偏置项设为00。

In [2]

def get_lstm_params(vocab_size, num_hiddens):    num_inputs = num_outputs = vocab_size    def normal(shape):        return paddle.randn(shape=shape)*0.01    def three():        return (normal((num_inputs, num_hiddens)),                normal((num_hiddens, num_hiddens)),                paddle.zeros([num_hiddens]))    W_xi, W_hi, b_i = three()  # 输入门参数    W_xf, W_hf, b_f = three()  # 遗忘门参数    W_xo, W_ho, b_o = three()  # 输出门参数    W_xc, W_hc, b_c = three()  # 候选记忆元参数    # 输出层参数    W_hq = normal((num_hiddens, num_outputs))    b_q = paddle.zeros([num_outputs])    # 附加梯度    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,              b_c, W_hq, b_q]    for param in params:        param.stop_gradient = False    return params

   

2.2 定义模型

在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。

因此,我们得到以下的状态初始化。

In [3]

def init_lstm_state(batch_size, num_hiddens):    return (paddle.zeros([batch_size, num_hiddens]),            paddle.zeros([batch_size, num_hiddens]))

   

实际模型的定义与我们前面讨论的一样:提供三个门和一个额外的记忆元。

请注意:只有隐状态才会传递到输出层,而记忆元CtCt不直接参与输出计算。In [4]

def lstm(inputs, state, params):    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,     W_hq, b_q] = params    (H, C) = state    outputs = []    for X in inputs:        I = Function.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)        F = Function.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)        O = Function.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)        C_tilda = paddle.tanh((X @ W_xc) + (H @ W_hc) + b_c)        C = F * C + I * C_tilda        H = O * paddle.tanh(C)        Y = (H @ W_hq) + b_q        outputs.append(Y)    return paddle.concat(outputs, axis=0), (H, C)

   

2.3 训练 和 预测

让我们通过实例化8.5节中,引入的RNNModelScratch类来训练一个长短期记忆网络。

此外,我们还加入了额外的模型测试。

In [6]

##  训练vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()num_epochs, lr = 500, 1.0model = d2l.RNNModelScratch(len(vocab), num_hiddens, device,get_lstm_params,                            init_lstm_state, lstm)d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

   In [10]

##  预测# 自定义 prefix , num_preds 进行预测prefix = 'tr'num_preds = 5net = modeld2l.predict_ch8(prefix, num_preds, net, vocab, device)

       

'treasth'

               

2.4 简洁实现

使用高级API,我们可以直接实例化LSTM模型。

高级API封装了前文介绍的所有配置细节。

这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

In [7]

num_inputs = vocab_sizelstm_layer = nn.LSTM(num_inputs, num_hiddens, time_major=True)model = d2l.RNNModel(lstm_layer, len(vocab))d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

   

2.5 结构拓展

比较流行的 LSTM 变体就是 Gers & Schmidhuber (2000) 提出的「猫眼连接」(peephole connections)的神经网络,也就是说,门连接层能够接收到 cell 的状态。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

上图展示了全加上「猫眼连接」的效果,但实际上论文中并不会加这么多。

另一种变体就是采用一对门,分别叫遗忘门(forget)及输入门(input)。

与分开决定遗忘及输入的内容不同,现在的变体会将这两个流程一同实现。

我们只有在将要输入新信息时才会遗忘,而也只会在忘记信息的同时才会有新的信息输入。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

一个比较知名的变体为 GRU(Gated Recurrent),由 Cho, et al. (2014) 提出。他将遗忘门与输入门结合在一起,名为**「更新门」**(update gate),并将 cell 状态与隐藏层状态合并在一起,此外还有一些小的改动。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - 创想鸟        

GRU和LSTM的区别:

LSTM有三个门,而GRU有两个门去掉了细胞单元C输出的时候取消了二阶的非线性函数

这个模型比起标准 LSTM 模型简单一些,因此也变得更加流行了。

当然,这里所列举的只是一管窥豹,还有很多其它的变体,

比如 Yao, et al. (2015) 提出的 Depth Gated RNNs;或是另辟蹊径处理长期依赖问题的 Clockwork RNNs,由 Koutnik, et al. (2014) 提出。

哪个是最好的呢?而这些变化是否真的意义深远?

Greff, et al. (2015) 曾经对比较流行的几种变种做过对比,发现它们基本上都差不多;

Jozefowicz, et al. (2015) 测试了超过一万种 RNN 结构,发现有一些能够在特定任务上超过 LSTMs。

以上就是『零基础+1』一文看懂LSTM原理-《动手学深度学习》的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/46256.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
利用Webman实现网站的响应式图片显示
上一篇 2025年11月7日 17:34:39
APP系统软件开发的总体流程分析
下一篇 2025年11月7日 17:34:40

相关推荐

  • python collections.Counter的计数

    Counter是Python中用于统计元素频次的高效工具,支持列表、字符串等可迭代对象;其以字典形式返回结果,键为元素,值为出现次数;可进行访问计数、获取最常见元素、更新或减去数据及数学运算;适用于词频统计、判断异位词和算法题等场景。 Python 的 collections.Counter 是一个…

    2026年5月10日
    000
  • js 怎样用defaults为对象数组添加默认值

    为 javascript 对象数组添加默认值的核心方法有三种:1. 使用 object.assign() 将默认值合并到每个对象的副本中,确保原始数据不变;2. 使用扩展运算符 ({ …defaults, …item }) 实现更简洁的浅层合并;3. 使用 lodash 的 …

    2026年5月10日
    000
  • c语言如何写脚本

    C 语言虽然不适合传统脚本编写,但通过模块化和库集成,可以创建强大的脚本。它可以通过以下步骤实现:模块化代码集成第三方库(如 Lua、Python、GNU Guile)创建脚本解释器实现脚本函数脚本文件格式设计优点:访问 C 语言的低级功能高性能可移植性缺点:学习曲线陡峭缺乏对动态类型的支持语法复杂…

    2026年5月10日
    000
  • 如何使用CSS Flexbox将导航栏精确地定位到右侧

    本教程详细介绍了如何利用CSS Flexbox技术,将网页导航栏(Nav Bar)精准地定位到容器的右侧,同时保持其背景透明。文章通过分析常见的布局问题,提供了基于Flexbox的优化解决方案,并深入解析了display: flex、flex-direction和align-items等关键CSS属…

    2026年5月10日
    000
  • Google TV 配对协议中的 SSL 握手失败与 Go 语言客户端证书处理

    本文旨在解决使用 Go 语言连接 Google TV 配对协议时遇到的 SSL 握手失败问题。核心在于 Google TV 要求客户端提供特定格式的客户端证书进行身份验证。文章将详细解释为何会发生握手失败,并提供解决方案,包括客户端证书的生成要求(特别是通用名称 CN 的格式),以及如何在 Go 语…

    2026年5月10日
    000
  • HTML行内样式怎么应用_HTML行内样式应用实例解析

    行内样式通过HTML元素的style属性定义CSS,优先级高于外部和内部样式表,适用于个别元素的快速调试与特殊设置。其语法为在标签内使用style属性,值为“属性: 值”形式的CSS声明,以分号分隔多个声明,如红色文字。典型应用包括文字样式调整、背景边框设置及尺寸布局控制,如蓝色加粗文本、带边框区块…

    2026年5月10日
    000
  • 网页标题怎么设置?title标签应该放在哪里?

    网页标题由html中 区域内的标签定义,必须且只能出现在该位置;2. 设置标题需在内插入标签并填入文本,如“我的个人博客”;3. 撰写标题时应包含核心关键词但避免堆砌,控制在50-60字符内,确保独特性与吸引力,并与内容高度相关;4. 未设置或设置不当会导致用户体验差、seo效果差、社交媒体分享效果…

    2026年5月10日
    000
  • 优化Django DetailView浏览量计数:避免重复递增与实现原子更新

    本文旨在解决Django DetailView中浏览量(views_count)重复递增的问题,特别是当使用get_object()方法进行计数时可能出现多次递增的现象。我们将深入探讨问题根源,并提供一种健壮的解决方案,通过将计数逻辑迁移至render_to_response()方法,并结合Djan…

    2026年5月10日
    000
  • 在Go语言Web应用中安全有效地检索HTTP Cookie

    本教程详细讲解了在go语言web应用中如何正确检索http cookie。我们将探讨`http.request.cookie()`方法的使用,重点关注常见的变量作用域问题及其解决方案,并提供一个健壮的代码示例,演示如何在处理cookie不存在的情况,以及如何将cookie值安全地传递给html模板进…

    2026年5月10日
    100
  • PHP多维数组怎么遍历_PHP多维数组遍历方法与代码示例

    遍历PHP多维数组需根据结构选择方法:固定层级用嵌套foreach,未知深度用递归函数或array_walk_recursive;常见陷阱包括深度不确定、非数组元素未检查、引用副作用及性能问题;筛选或修改数据可在遍历中加条件判断,结合引用修改原数组;扁平化常用递归+array_merge或array…

    2026年5月10日
    100
  • Go 性能剖析文件图形化可视化教程:使用 pprof 及 Graphviz

    本教程详细介绍了如何利用 Go 语言内置的 go tool pprof 工具对性能剖析文件进行图形化可视化。我们将解决常见的函数名显示问题,并通过 web 命令结合 Graphviz 生成直观的调用图,从而帮助开发者高效分析程序性能瓶颈。 1. 理解 Go 性能剖析与 pprof Go 语言提供了一…

    2026年5月10日
    000
  • 如何使用JavaScript高效筛选对象数组中具有重复name属性值的对象?

    javascript对象数组去重:筛选重复name属性值的对象 本文介绍如何使用JavaScript高效地从对象数组中筛选出具有重复name属性值的对象。 如果某个对象的name属性值在数组中出现多次,则保留所有具有该name值的对象;如果name属性值唯一,则将其删除。 示例数据: const a…

    2026年5月10日
    000
  • 在Laravel中计算JSON字段中数值的总和

    本教程详细介绍了如何在laravel应用中处理存储在数据库字段中的json字符串,并计算其中所有数值的总和。通过迭代eloquent模型集合,解析json数据,并对解析后的数值进行累加,为每个记录动态添加一个总和字段。 在现代Web应用开发中,将结构化数据以JSON格式存储在数据库的文本字段中是一种…

    2026年5月10日
    000
  • 如何在C++中实现单例模式?

    在c++++中实现单例模式可以通过静态成员变量和静态成员函数来确保类只有一个实例。具体步骤包括:1. 使用私有构造函数和删除拷贝构造函数及赋值操作符,防止外部直接实例化。2. 通过静态方法getinstance提供全局访问点,确保只创建一个实例。3. 为了线程安全,可以使用双重检查锁定模式。4. 使…

    2026年5月10日
    000
  • 优化Tkinter主题性能:解决UI卡顿与提升响应速度

    本文旨在探讨Tkinter应用中主题性能下降的问题,尤其是在Windows和macOS平台上使用图像密集型主题时。我们将分析导致UI卡顿的常见原因,并提供优化策略,包括选择高性能主题(如sv-ttk)、减少图像依赖,以及在必要时考虑其他现代GUI框架,以帮助开发者构建更流畅、响应更快的用户界面。 T…

    2026年5月10日
    000
  • python如何解决初始化执行次数

    初始化执行多次通常因对象重复创建或继承调用不当。1. 避免频繁实例化,复用对象可减少__init__调用;2. 使用单例模式通过__new__控制实例唯一性,并用标记确保__init__仅执行一次;3. 多重继承中应正确使用super(),依赖MRO机制避免父类__init__被重复调用;4. 可采…

    2026年5月10日
    000
  • JavaScript中的迭代器与生成器详解_js ES6+

    迭代器是遵循迭代器协议的对象,提供next()方法返回{value, done};2. 生成器函数用function*定义,通过yield暂停并返回值,自动实现迭代器接口。 在JavaScript ES6+中,迭代器(Iterator)和生成器(Generator)是处理数据序列的重要机制。它们让开…

    2026年5月10日
    100
  • 如何用Golang实现第一个CLI工具 详解cobra库创建命令行应用

    如何用Golang实现第一个CLI工具 详解cobra库创建命令行应用如何用Golang实现第一个CLI工具 详解cobra库创建命令行应用如何用Golang实现第一个CLI工具 详解cobra库创建命令行应用如何用Golang实现第一个CLI工具 详解cobra库创建命令行应用

    用golang实现cli工具可借助cobra库快速完成。1. 安装cobra:使用go install github.com/spf13/cobra-cli@latest;2. 初始化项目结构:运行cobra init –pkg-name mycli生成基础代码;3. 添加子命令:执行c…

    2026年5月10日 用户投稿
    000
  • Go语言中指针赋值的原子性与并发安全

    在go语言中,指针赋值操作并非天然原子性。在并发环境下,若不采取额外同步措施,对共享指针的读写可能导致数据竞争和不一致状态。本文将深入探讨go语言中确保指针赋值并发安全的方法,包括使用`sync.mutex`进行互斥保护,以及在特定场景下利用`sync/atomic`包实现原子操作。同时,也将提及通…

    2026年5月10日
    100
  • Golang Docker容器网络调试与问题排查实践

    首先检查容器网络模式与端口映射是否正确,确认使用-p参数暴露端口或host模式下服务绑定到0.0.0.0;接着验证Golang服务监听地址为0.0.0.0:8080而非127.0.0.1,并检查宿主机防火墙或安全组规则;然后通过自定义bridge网络实现容器间通信,利用curl测试连通性;最后借助n…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信