PyTorch中神经网络拟合圆形坐标平方和的收敛性优化

PyTorch中神经网络拟合圆形坐标平方和的收敛性优化

本教程旨在解决使用PyTorch神经网络拟合二维坐标 (x, y) 到其平方和 (x^2 + y^2) 时的收敛性问题。文章将深入探讨初始网络结构中存在的非线性表达能力不足、输入数据尺度不一以及超参数配置不当等常见挑战,并提供一套系统的优化策略,包括引入非线性激活函数、进行输入数据标准化以及精细调整训练周期和批处理大小,以显著提升模型的训练效率和拟合精度。

问题阐述:拟合圆形坐标平方和的挑战

我们的目标是构建一个pytorch神经网络,使其能够接收一个包含二维坐标 [x, y, 1] 的输入,并输出 x 和 y 的平方和 (x^2 + y^2)。这是一个典型的回归问题,但由于输出是非线性的二次函数,对神经网络的结构和训练策略提出了要求。

初始尝试的PyTorch代码如下所示:

import torchimport torch.nn as nnimport numpy as npfrom torch.utils.data import TensorDataset, DataLoaderimport torch.optim# 检查CUDA可用性device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 准备数据features = torch.tensor([[8.3572,-11.3008,1],[6.2795,-12.5886,1],[4.0056,-13.4958,1]                         ,[1.6219,-13.9933,1],[-0.8157,-14.0706,1],[-3.2280,-13.7250,1]                         ,[-5.5392,-12.9598,1],[-7.6952,-11.8073,1],[-9.6076,-10.3035,1],                         [-11.2532,-8.4668,1],[-12.5568,-6.3425,1],[-13.4558,-4.0691,1],                         [-13.9484,-1.7293,1],[-14.0218,0.7224,1],[-13.6791,3.1211,1],                         [-12.9064,5.4561,1],[-11.7489,7.6081,1],[-10.2251,9.5447,1],                         [5.4804,12.8044,1],[7.6332,11.6543,1],[9.5543,10.1454,1],                         [11.1890,8.3117,1],[12.4705,6.2460,1],[13.3815,3.9556,1],                         [13.8733,1.5884,1],[13.9509,-0.8663,1],[13.6014,-3.2793,1],                         [12.8572,-5.5526,1],[11.7042,-7.7191,1],[10.1761,-9.6745,1],                         [-8.4301,11.1605,1],[-6.3228,12.4433,1],[-4.0701,13.3401,1],                         [-1.6816,13.8352,1],[0.7599,13.9117,1],[3.1672,13.5653,1]]).to(device)labels = []for i in range(features.shape[0]):    label=(features[i][0])**2+(features[i][1])**2    labels.append(label)labels = torch.tensor(labels).to(device)# 定义网络结构(初始版本)num_input ,num_hidden,num_output = 3,64,1net = nn.Sequential(    nn.Linear(num_input,num_hidden),    nn.Linear(num_hidden,num_output)).to(device)# 权重初始化def init_weights(m):    if type(m) == nn.Linear:        nn.init.xavier_normal_(m.weight)net.apply(init_weights)loss = nn.MSELoss()num_epochs = 10batch_size = 6lr=0.001trainer = torch.optim.RAdam(net.parameters(),lr=lr)dataset = TensorDataset(features,labels)data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)print("初始训练过程中的损失:")for i in range (num_epochs):    for X,y in data_loader:        y_hat = net(X)        l = loss(y_hat,y.reshape(y_hat.shape))        trainer.zero_grad()        l.backward()        trainer.step()    with torch.no_grad():        print(f"Epoch {i+1}, Loss: {l.item():.4f}")

运行上述代码会发现,模型的损失值很高,且几乎无法收敛,这意味着网络未能有效地学习到 x^2 + y^2 这一关系。

收敛性问题分析

导致上述模型收敛困难的原因主要有以下几点:

缺乏非线性激活函数: 初始网络结构 nn.Sequential(nn.Linear(…), nn.Linear(…)) 仅由两个线性层组成。在没有非线性激活函数的情况下,无论堆叠多少个线性层,整个网络最终都等效于一个单一的线性变换。然而,我们要拟合的目标函数 x^2 + y^2 是一个典型的非线性函数,线性模型无法对其进行有效近似。输入数据尺度不一: 原始的 x 和 y 坐标值范围较大,且没有经过标准化处理。这可能导致梯度在不同维度上大小不一,使得优化器难以有效地找到最优解,容易陷入局部最优或导致训练过程不稳定。超参数配置不当: 初始的训练周期 num_epochs = 10 和批处理大小 batch_size = 6 对于学习这样一个非线性函数可能不足以使模型充分学习或稳定收敛。

优化策略一:引入非线性激活函数

为了使神经网络能够学习非线性关系,我们必须在网络层之间引入非线性激活函数。对于多层感知机(MLP),常用的激活函数包括ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。在这里,我们选择 ReLU,它计算简单且能有效缓解梯度消失问题。

将 nn.ReLU() 添加到第一个线性层之后,网络结构将变为:nn.Sequential(nn.Linear(num_input, num_hidden), nn.ReLU(), nn.Linear(num_hidden, num_output))。

优化策略二:数据预处理——标准化输入

数据标准化是深度学习中的一项关键预处理步骤,它能将不同尺度的特征转换到相似的范围内。常用的方法是Z-score标准化(也称作均值-标准差标准化),即将数据调整为均值为0、标准差为1的分布。这有助于:

加速收敛: 使得损失函数的等高线更接近圆形,优化器(如梯度下降)可以更直接地向最小值移动,而不是在“狭长”的区域内震荡。防止梯度爆炸/消失: 确保所有输入特征对模型权重的更新具有相似的影响,避免某些特征因数值过大而主导梯度,或因数值过小而导致梯度消失。

对于我们的 features 数据,x 和 y 坐标位于前两列,第三列是常数 1。我们只需要对 x 和 y 进行标准化。

# 数据标准化mean = features[:,:2].mean(dim=0)std = features[:,:2].std(dim=0)features[:,:2] = (features[:,:2] - mean) / (std + 1e-5) # 添加一个小的epsilon防止除以零

优化策略三:超参数调优

适当的超参数配置对模型训练的成功至关重要。

增加训练周期 (num_epochs): 初始的10个训练周期对于模型学习复杂的非线性模式通常是不够的。增加训练周期可以让模型有更多机会迭代更新权重,从而更好地拟合数据。我们将 num_epochs 增加到100。

num_epochs = 100 # 增加训练周期

调整批处理大小 (batch_size): 批处理大小会影响梯度估计的稳定性和训练速度。较小的 batch_size(例如2)可以提供更频繁的权重更新,每次更新的梯度估计可能噪声更大,但在某些情况下能帮助模型跳出局部最优,或在数据集较小时表现更好。

batch_size = 2 # 调整批处理大小

整合优化后的完整代码

将上述所有优化策略整合到原始代码中,得到一个更健壮、更易收敛的PyTorch训练脚本。

import torchimport torch.nn as nnimport numpy as npfrom torch.utils.data import TensorDataset, DataLoaderimport torch.optim# 检查CUDA可用性device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 原始数据features_raw = torch.tensor([[8.3572,-11.3008,1],[6.2795,-12.5886,1],[4.0056,-13.4958,1]                         ,[1.6219,-13.9933,1],[-0.8157,-14.0706,1],[-3.2280,-13.7250,1]                         ,[-5.5392,-12.9598,1],[-7.6952,-11.8073,1],[-9.6076,-10.3035,1],                         [-11.2532,-8.4668,1],[-12.5568,-6.3425,1],[-13.4558,-4.0691,1],                         [-13.9484,-1.7293,1],[-14.0218,0.7224,1],[-13.6791,3.1211,1],                         [-12.9064,5.4561,1],[-11.7489,7.6081,1],[-10.2251,9.5447,1],                         [5.4804,12.8044,1],[7.6332,11.6543,1],[9.5543,10.1454,1],                         [11.1890,8.3117,1],[12.4705,6.2460,1],[13.3815,3.9556,1],                         [13.8733,1.5884,1],[13.9509,-0.8663,1],[13.6014,-3.2793,1],                         [12.8572,-5.5526,1],[11.7042,-7.7191,1],[10.1761,-9.6745,1],                         [-8.4301,11.1605,1],[-6.3228,12.4433,1],[-4.0701,13.3401,1],                         [-1.6816,13.8352,1],[0.7599,13.9117,1],[3.1672,13.5653,1]]).to(device)# 创建可变副本进行标准化features = features_raw.clone().detach()# 数据标准化:对x, y坐标进行均值-标准差标准化mean = features[:,:2].mean(dim=0)std = features[:,:2].std(dim=0)features[:,:2] = (features[:,:2] - mean) / (std + 1e-5) # 添加一个小的epsilon防止除以零# 计算标签 (labels不需要标准化,因为它们是目标输出)labels = []for i in range(features_raw.shape[0]): # 注意:标签基于原始数据计算    label=(features_raw[i][0])**2+(features_raw[i][1])**2    labels.append(label)labels = torch.tensor(labels).to(device)# 定义网络结构(优化版本:添加ReLU激活函数)num_input ,num_hidden,num_output = 3,64,1net = nn.Sequential(    nn.Linear(num_input,num_hidden),    nn.ReLU(), # 引入非线性激活函数    nn.Linear(num_hidden,num_output)).to(device)# 权重初始化def init_weights(m):    if type(m) == nn.Linear:        nn.init.xavier_normal_(m.weight)        # 偏置项通常初始化为0或小常数,nn.Linear默认已处理net.apply(init_weights)loss = nn.MSELoss()# 超参数调优num_epochs = 100 # 增加训练周期batch_size = 2   # 调整批处理大小lr=0.001trainer = torch.optim.RAdam(net.parameters(),lr=lr)dataset = TensorDataset(features,labels)data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)print("n优化后的训练过程中的损失:")for epoch in range (num_epochs):    for X,y in data_loader:        y_hat = net(X)        l = loss(y_hat,y.reshape(y_hat.shape))        trainer.zero_grad()        l.backward()        trainer.step()    with torch.no_grad():        if (epoch + 1) % 10 == 0 or epoch == 0: # 每10个epoch打印一次,以及第1个epoch            print(f"Epoch {epoch+1}, Loss: {l.item():.4f}")# 训练结束后,可以进行模型评估或预测# 示例:使用训练后的模型对部分数据进行预测net.eval() # 将模型

以上就是PyTorch中神经网络拟合圆形坐标平方和的收敛性优化的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1374318.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年12月14日 14:02:52
下一篇 2025年12月14日 14:03:04

相关推荐

  • YOLOv8 视频帧级对象分类与结果解析教程

    本教程详细阐述了如何使用YOLOv8模型对视频帧进行逐帧对象分类,并着重解决了在处理模型预测结果时常见的类名提取错误。文章将指导读者正确解析YOLOv8的预测输出,通过迭代每个检测到的边界框来获取其对应的类别ID和名称,从而实现准确的帧分类和后续处理,如根据类别堆叠视频帧,确保数据处理的准确性和逻辑…

    好文分享 2025年12月14日
    000
  • PyTorch 神经网络拟合 x^2+y^2 函数的实践与优化

    本文探讨了如何使用 PyTorch 神经网络拟合圆周坐标的平方和函数 x^2+y^2。针对初始模型训练过程中遇到的高损失和难以收敛的问题,文章提供了详细的优化策略,包括对输入数据进行标准化处理、调整训练轮次(epochs)以及优化批次大小(batch_size)。通过这些方法,显著提升了模型的收敛性…

    2025年12月14日
    000
  • Python pyheif库在Windows上的安装挑战与解决方案

    在Windows系统上安装pyheif库时常遭遇F%ignore_a_1%led building wheel错误,根本原因在于其底层依赖libheif库在Windows环境下缺乏便捷的编译与安装途径。本文深入分析了此问题,并提供了多种实用的解决方案,包括利用Windows Subsystem fo…

    2025年12月14日
    000
  • 使用PyTorch训练神经网络计算坐标平方和

    本文详细阐述了如何使用PyTorch构建并训练一个神经网络,使其能够根据输入的二维坐标[x, y, 1]计算并输出x^2 + y^2。文章首先分析了初始实现中遇到的收敛困难,随后深入探讨了通过输入数据标准化、增加训练周期以及调整批量大小等关键优化策略来显著提升模型性能和收敛速度,并提供了完整的优化代…

    2025年12月14日
    000
  • 解决Python虚拟环境下WebSocket回调不执行的问题:主线程阻塞策略

    本文探讨并解决了Python虚拟环境下WebSocket回调函数(如on_ticks)不执行的问题。核心原因是主线程在异步操作完成前过早退出,导致回调机制无法被触发。解决方案是通过阻塞主线程,确保程序有足够时间接收并处理来自WebSocket的异步数据,从而使回调函数正常工作。 问题现象分析 在使用…

    2025年12月14日
    000
  • 解决GridSearchCV中n_splits与类别成员数冲突的策略

    在使用sklearn的GridSearchCV进行模型调优时,当cv参数设置为整数且用于分类任务时,默认会执行分层K折交叉验证。如果数据集中最小类别的样本数量小于指定的n_splits值,将抛出ValueError。本文将深入解析此错误的原因,并提供两种有效的解决方案:调整折叠数或显式使用非分层K折…

    2025年12月14日
    000
  • Python程序打包后进程无限复制的解决方案

    问题描述 在使用 PyInstaller 将 Python 脚本打包成可执行文件后,可能会遇到一个令人头疼的问题:程序在运行时会不断地复制自身进程,最终导致系统资源耗尽并崩溃。这种现象通常发生在涉及到屏幕截图等操作的程序中。 原因分析 该问题通常与特定的第三方库在打包后的行为有关。在本例中,问题出在…

    2025年12月14日
    000
  • 无限进程克隆:PyInstaller打包Python截图脚本的解决方案

    本文将针对使用PyInstaller打包Python截图脚本时可能遇到的无限进程克隆问题提供解决方案。这类问题通常表现为程序在打包成可执行文件后,运行时会不断产生新的进程,最终导致系统资源耗尽并崩溃。我们将分析可能的原因,并提供一种可行的替代方案,帮助你成功打包并运行截图脚本。 问题分析 使用PyI…

    2025年12月14日
    000
  • 解决FastAPI服务器因长时间请求而冻结的问题

    第一段引用上面的摘要: 本文旨在解决FastAPI应用在高并发场景下,由于同步阻塞操作导致服务器冻结的问题。通过分析同步代码阻塞事件循环的原理,提供了使用异步替代方案或将阻塞操作迁移至线程池的解决方案,以提升FastAPI应用的并发处理能力和响应速度。 FastAPI 作为一个现代化的 Web 框架…

    2025年12月14日
    000
  • Pydantic V2 ValidationError 警告的解决与迁移指南

    在升级到 Pydantic V2 (例如 2.5.2) 或更高版本后,你可能会在日志中看到如下警告: /usr/local/lib/python3.12/site-packages/pydantic/_migration.py:283: UserWarning: `pydantic.error_wr…

    2025年12月14日
    000
  • YOLOv8视频帧目标分类:正确提取预测类别与帧处理实践

    本文详细阐述了在使用YOLOv8进行视频帧目标分类时,如何准确提取每个检测框的预测类别信息。针对常见的错误,即误用模型整体类别列表的第一个元素,文章提供了正确的迭代方法,通过访问每个检测框的cls属性来获取其对应的类别ID,并据此从模型类别字典中检索正确的类别名称。同时,文章结合视频处理场景,给出了…

    2025年12月14日
    000
  • YOLOv8视频帧目标检测:精确类别提取与处理指南

    本文旨在解决YOLOv8模型在视频帧处理中常见的类别识别错误问题。通过深入解析YOLOv8的预测结果结构,特别是result.boxes和result.names属性,文章将指导读者如何正确提取每个检测对象的实际类别名称,而非误用固定索引。教程提供了详细的代码示例,确保视频帧能被准确地分类和处理,从…

    2025年12月14日
    000
  • YOLOv8视频帧多类别检测:正确提取预测类别名称的实践指南

    本文详细阐述了在使用YOLOv8模型对视频帧进行多类别目标检测时,如何准确地从预测结果中提取每个检测到的对象的类别名称。文章纠正了常见的results.names[0]误用,并通过示例代码演示了正确的迭代boxes并利用box.cls获取精确类别ID的方法,确保在视频处理流程中正确分类和处理每一帧的…

    2025年12月14日
    000
  • YOLOv8视频帧多类别目标检测:正确解析与处理预测结果

    本教程详细阐述了在使用YOLOv8模型对视频帧进行多类别目标检测时,如何正确解析模型预测结果,避免将不同类别的检测混淆。我们将重点解决从results对象中准确提取每个检测框的类别名称,并根据类别对视频帧进行分类存储和可视化,确保数据处理的准确性和一致性。 YOLOv8预测结果解析的常见误区 在使用…

    2025年12月14日
    000
  • 如何在文本游戏中将物品从房间放入背包

    本文旨在解决在文本冒险游戏中,玩家无法将房间内的物品添加到背包的问题。通过分析常见错误,例如字典访问方式不正确,以及物品判断逻辑的缺失,提供清晰的代码示例和步骤,帮助开发者构建一个可用的物品收集系统,从而提升游戏体验。 在开发文本冒险游戏时,一个核心功能就是允许玩家从房间中拾取物品并将它们放入背包。…

    2025年12月14日
    000
  • 如何在文本冒险游戏中将物品从房间放入背包

    本文档旨在解决在文本冒险游戏中,玩家无法将房间内的物品放入背包的问题。通过分析游戏代码,找出错误原因,并提供正确的代码示例,帮助开发者实现物品拾取功能,完善游戏逻辑。 理解游戏逻辑 在文本冒险游戏中,玩家通常通过输入指令与游戏世界互动。其中一个常见的功能就是拾取物品。实现这一功能需要以下几个关键步骤…

    2025年12月14日
    000
  • Python 错误与异常处理从入门到精通

    答案:Python通过try-except处理异常,支持自定义异常类、多异常捕获及traceback、pdb和logging等调试方法,提升程序健壮性。 Python 错误与异常处理,简单来说,就是让你的代码在出错时不要直接崩溃,而是优雅地处理问题,甚至继续运行下去。这不仅能提升用户体验,也是保证程…

    2025年12月14日
    000
  • Arduino与Raspberry Pi CM4串口通信速度慢的解决方案

    在Arduino项目中,经常需要使用串口进行设备间的通信,例如Raspberry Pi与ESP8266之间的通信。然而,有时会遇到串口通信速度慢的问题,导致数据传输延迟。本文将针对这种问题进行分析,并提供解决方案。 问题分析 在提供的代码中,Raspberry Pi通过串口向ESP8266发送PWM…

    2025年12月14日
    000
  • 解决 PySpark 查询中的 Column Ambiguous 错误

    正如摘要所述,本文旨在帮助读者理解并解决在使用 PySpark 进行 DataFrame 连接操作时遇到的 “Column Ambiguous” 错误。我们将深入探讨该错误的原因,并提供明确的解决方案,包括使用别名和限定列名等方法,确保你的 PySpark 代码能够高效且准确…

    2025年12月14日
    000
  • 解决PySpark查询中的列名歧义错误:一份详细指南

    正如摘要所述,本文旨在帮助读者理解和解决在使用PySpark进行数据帧(DataFrame)连接操作时可能遇到的“列名歧义”错误。通过分析错误原因,提供详细的解决方案,并给出示例代码,帮助读者避免和解决类似问题,提升PySpark数据处理能力。 在PySpark中,当多个数据帧包含相同名称的列,并且…

    2025年12月14日
    000

发表回复

登录后才能评论
关注微信