R-Drop论文复现

R-Drop是基于Dropout的改进正则化方法,通过对模型输出层施加约束减少过拟合。其让每个样本两次通过带Dropout的同一模型,用KL散度约束两次输出一致,总损失为交叉熵与KL散度之和。代码实现仅增KL项,实验显示能有效提升模型正确率。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

r-drop论文复现 - 创想鸟

R-Drop: Regularized Dropout for Neural Networks

  由于深度神经网络非常容易过拟合,因此 Dropout 方法采用了随机丢弃每层的部分神经元,以此来避免在训练过程中的过拟合问题。正是因为每次随机丢弃部分神经元,导致每次丢弃后产生的子模型都不一样,所以 Dropout 的操作一定程度上使得训练后的模型是一种多个子模型的组合约束。基于 Dropout 的这种特殊方式对网络带来的随机性,研究员们提出了 R-Drop 来进一步对(子模型)网络的输出预测进行了正则约束。论文通过实验得出一种改进的正则化方法R-dropout,简单来说,它通过使用若干次(论文中使用了两次)dropout,定义新的损失函数。实验结果表明,尽管结构非常简单,但是却能很好的防止模型过拟合,进一步提高模型的正确率。模型主体如下图所示。

R-Drop论文复现 - 创想鸟        

论文贡献

  由于深度神经网络非常容易过拟合,因此 Dropout 方法采用了随机丢弃每层的部分神经元,以此来避免在训练过程中的过拟合问题。正是因为每次随机丢弃部分神经元,导致每次丢弃后产生的子模型都不一样,所以 Dropout 的操作一定程度上使得训练后的模型是一种多个子模型的组合约束。基于 Dropout 的这种特殊方式对网络带来的随机性,研究员们提出了 R-Drop 来进一步对(子模型)网络的输出预测进行了正则约束。

实现思路

  与传统作用于神经元(Dropout)或者模型参数(DropConnect)上的约束方法不同,R-Drop 作用于模型的输出层,弥补了 Dropout 在训练和测试时的不一致性。简单来说就是在每个 mini-batch 中,每个数据样本过两次带有 Dropout 的同一个模型,R-Drop 再使用 KL-divergence 约束两次的输出一致。既约束了由于 Dropout 带来的两个随机子模型的输出一致性。

R-Drop论文复现 - 创想鸟        

论文公式

模型的训练目标包含两个部分,一个是两次输出之间的KL散度,如下:

R-Drop论文复现 - 创想鸟        

另一个是模型自有的损失函数交叉熵,如下:

R-Drop论文复现 - 创想鸟        

总损失函数为:

R-Drop论文复现 - 创想鸟        

代码实现

与传统的训练方法相比,R- Drop 只是简单增加了一个 KL-divergence 损失函数项,并没有其他任何改动。其PaddlePaddle版本对应的代码实现如下所示。

散度损失

交叉熵=熵+相对熵(KL散度) 其与交叉熵的关系如下:

R-Drop论文复现 - 创想鸟        

代码实现示意

import paddle.nn.functional as F# define your task model, which outputs the classifier logitsmodel = TaskModel()def compute_kl_loss(self, p, q, pad_mask=None):        p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')    q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')        # pad_mask is for seq-level tasks    if pad_mask is not None:        p_loss.masked_fill_(pad_mask, 0.)        q_loss.masked_fill_(pad_mask, 0.)    # You can choose whether to use function "sum" and "mean" depending on your task    p_loss = p_loss.sum()    q_loss = q_loss.sum()    loss = (p_loss + q_loss) / 2    return loss# keep dropout and forward twicelogits = model(x)logits2 = model(x)# cross entropy loss for classifierce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))kl_loss = compute_kl_loss(logits, logits2)# 论文中对于CV任务的超参数α = 0.6# carefully choose hyper-parametersloss = ce_loss + α * kl_loss

   

代码实现实战

项目说明

本次实验以白菜生长的四个周期为例,进行生长情况识别实验。数据来自于讯飞的比赛。数据展示如下:发芽期、幼苗期、莲座期、结球期。

R-Drop论文复现 - 创想鸟 R-Drop论文复现 - 创想鸟R-Drop论文复现 - 创想鸟 R-Drop论文复现 - 创想鸟        

In [ ]

!cd 'data/data107306' && unzip -q img.zip!cd 'data/data106868' && unzip -q pdweights.zip

   In [ ]

# 导入所需要的库from sklearn.utils import shuffleimport osimport pandas as pdimport numpy as npfrom PIL import Imageimport paddleimport paddle.nn as nnfrom paddle.io import Datasetimport paddle.vision.transforms as Timport paddle.nn.functional as Ffrom paddle.metric import Accuracyimport warningswarnings.filterwarnings("ignore")# 读取数据train_images = pd.read_csv('data/data107306/img/df_all.csv')train_images = shuffle(train_images)# 划分训练集和校验集all_size = len(train_images)# print(all_size)train_size = int(all_size * 0.9)train_image_list = train_images[:train_size]val_image_list = train_images[train_size:]train_image_path_list = train_image_list['image'].valueslabel_list = train_image_list['label'].valuestrain_label_list = paddle.to_tensor(label_list, dtype='int64')val_image_path_list = val_image_list['image'].valuesval_label_list1 = val_image_list['label'].valuesval_label_list = paddle.to_tensor(val_label_list1, dtype='int64')# 定义数据预处理data_transforms = T.Compose([    T.Resize(size=(256, 256)),       T.Transpose(),    # HWC -> CHW    T.Normalize(               mean = [0, 0, 0],        std = [255, 255, 255],        to_rgb=True)    ])# 构建Datasetclass MyDataset(paddle.io.Dataset):    """    步骤一:继承paddle.io.Dataset类    """    def __init__(self, train_img_list, val_img_list,train_label_list,val_label_list, mode='train'):        """        步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集        """        super(MyDataset, self).__init__()        self.img = []        self.label = []        self.valimg = []        self.vallabel = []        # 借助pandas读csv的库        self.train_images = train_img_list        self.test_images = val_img_list        self.train_label = train_label_list        self.test_label = val_label_list        # self.mode = mode        if mode == 'train':             # 读train_images的数据            for img,la in zip(self.train_images, self.train_label):                self.img.append('data/data107306/img/imgV/'+img)                self.label.append(la)        else :            # 读test_images的数据            for img,la in zip(self.test_images, self.test_label):                self.img.append('data/data107306/img/imgV/'+img)                self.label.append(la)    def load_img(self, image_path):        # 实际使用时使用Pillow相关库进行图片读取即可,这里我们对数据先做个模拟        image = Image.open(image_path).convert('RGB')        image = np.array(image).astype('float32')        return image    def __getitem__(self, index):        """        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)        """        image = self.load_img(self.img[index])        label = self.label[index]               return data_transforms(image), label    def __len__(self):        """        步骤四:实现__len__方法,返回数据集总数目        """        return len(self.img)

   In [ ]

#train_loadertrain_dataset = MyDataset(train_img_list=train_image_path_list, val_img_list=val_image_path_list, train_label_list=train_label_list, val_label_list=val_label_list, mode='train')train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=8, shuffle=True, num_workers=0)#val_loaderval_dataset = MyDataset(train_img_list=train_image_path_list, val_img_list=val_image_path_list, train_label_list=train_label_list, val_label_list=val_label_list, mode='test')val_loader = paddle.io.DataLoader(val_dataset, places=paddle.CPUPlace(), batch_size=8, shuffle=True, num_workers=0)

   In [ ]

from work.senet154 import SE_ResNeXt50_vd_32x4dfrom work.res2net import Res2Net50_vd_26w_4sfrom work.se_resnet import SE_ResNet50_vd# 模型封装# model_re2 = SE_ResNeXt50_vd_32x4d(class_num=4)model_re2 = Res2Net50_vd_26w_4s(class_dim=4)model_ss = SE_ResNet50_vd(class_num=4)model_ss.train()model_re2.train()epochs = 2optim1 = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model_re2.parameters())optim2 = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model_ss.parameters())

   In [ ]

import paddle.nn.functional as Fdef compute_kl_loss(p, q, pad_mask=None):    p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')    q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')        # pad_mask is for seq-level tasks    if pad_mask is not None:        p_loss.masked_fill_(pad_mask, 0.)        q_loss.masked_fill_(pad_mask, 0.)    # You can choose whether to use function "sum" and "mean" depending on your task    p_loss = p_loss.sum()    q_loss = q_loss.sum()    loss = (p_loss + q_loss) / 2    return loss

   In [7]

# 用Adam作为优化函数for epoch in range(epochs):    for batch_id, data in enumerate(train_loader()):        x_data = data[0]        y_data = data[1]        predicts1 = model_re2(x_data)        predicts2 = model_ss(x_data)                loss1 = F.cross_entropy(predicts1, y_data, soft_label=False)        loss2 = F.cross_entropy(predicts2, y_data, soft_label=False)                # cross entropy loss for classifier        ce_loss = 0.5 * (loss1 + loss2)        kl_loss = compute_kl_loss(predicts1, predicts2)        # 论文中对于CV任务的超参数        α = 0.6        # carefully choose hyper-parameters        loss = ce_loss + α * kl_loss        # 计算损失        acc1 = paddle.metric.accuracy(predicts1, y_data)        acc2 = paddle.metric.accuracy(predicts2, y_data)                loss.backward()        if batch_id % 50 == 0:            print("epoch: {}, batch_id: {}, loss1 is: {}".format(epoch, batch_id, loss.numpy()))        optim1.step()        optim1.clear_grad()            optim2.step()        optim2.clear_grad()

   

以上就是R-Drop论文复现的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/42659.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月6日 20:50:11
下一篇 2025年11月6日 20:53:11

相关推荐

  • Go语言中切片迭代与元素修改的正确姿势

    在go语言中,使用`for…range`循环遍历切片时,理解其迭代行为对于正确访问和修改元素至关重要。本文将深入探讨`for…range`在单变量和双变量模式下的行为差异,特别是当尝试修改切片元素时可能遇到的常见陷阱,并提供通过索引进行修改的正确方法,以避免“undefine…

    2025年12月16日
    000
  • 如何在Go语言中高效读取文本文件:整文件与逐行处理

    本文详细介绍了在go语言中读取文本文件的两种主要方法:一次性读取整个文件和逐行扫描。我们将探讨`ioutil.readfile`与`strings.split`的组合,适用于小型文件,以及`bufio.scanner`的逐行处理机制,更适合大型文件以优化内存使用。文章将提供清晰的代码示例、错误处理实…

    2025年12月16日
    000
  • 如何在Go语言中正确解组带有命名空间的XML属性

    本文详细阐述了在go语言中使用`encoding/xml`包解组带有命名空间前缀(如`xlink:href`)的xml属性的正确方法。核心在于理解xml命名空间的作用,并确保xml源数据中包含完整的命名空间声明。在go结构体标签中,应使用命名空间url与属性名相结合的方式来准确映射这些属性,从而成功…

    2025年12月16日
    000
  • 深入理解Go语言中字符串与字节切片的比较及用户输入处理

    本文旨在探讨go语言中`string`类型与`[]byte`(字节切片)之间的核心差异,并针对用户输入场景下常见的比较问题提供解决方案。我们将详细分析`bufio.readbytes`等函数如何处理换行符,并提供实用的代码示例,以确保在比较用户输入时能够准确无误地进行。 在Go语言开发中,处理用户输…

    2025年12月16日
    000
  • Go语言中高效读取文本文件:按行处理的实践指南

    本文详细介绍了在go语言中读取文本文件并按行处理的多种方法。重点讲解了如何使用`ioutil.readfile`结合`strings.split`函数,将文件内容一次性读入内存并分割成字符串切片,适用于中小型文件。同时,也简要提及了`bufio.scanner`在处理大型文件时的优势,帮助开发者根据…

    2025年12月16日
    000
  • 如何在Golang中优化多文件并发读写_Golang多文件并发读写性能优化方法汇总

    使用sync.Pool复用缓冲区降低GC压力,通过信号量或缓冲channel限制并发数防止资源耗尽,结合io.Copy、bufio等工具减少系统调用,合理设置文件打开模式并复用文件句柄,避免频繁读写导致性能下降。 在Golang中处理多文件并发读写时,性能和资源管理是关键。不当的并发控制可能导致文件…

    2025年12月16日
    000
  • 使用约束条件创建自定义类型:Go 语言实战教程

    本文将介绍如何在 Go 语言中创建具有约束条件的自定义类型,以确保类型只能接受预定义的一组有效值。我们将通过示例代码演示如何实现这一目标,并讨论不同实现方式的优缺点,帮助你选择最适合自己场景的方案。 在 Go 语言中,虽然没有像其他一些语言那样直接支持枚举或受限类型,但我们可以通过一些技巧来模拟实现…

    2025年12月16日
    000
  • Go语言高流量UDP服务内存泄漏排查与解决:defer闭包与版本升级

    本文探讨go语言在高流量udp日志处理服务中遇到的内存暴涨问题。通过`pprof`分析发现`newdefer`函数占用大量内存,根源在于go早期版本中`defer`闭包的内存泄漏。文章提供了通过升级go版本解决该问题的方案,并强调了编写健壮代码、避免不必要的`panic`以减少`defer`开销的重…

    2025年12月16日
    000
  • Go语言中将exec.Cmd标准输出重定向到文件的最佳实践

    本文将介绍在go语言中如何将`os/exec`包执行的外部命令的标准输出(stdout)高效地重定向并写入到文件中。通过将目标文件直接赋值给`exec.cmd`的`stdout`字段,可以实现简洁且可靠的输出捕获,避免了手动管理管道和协程的复杂性,确保命令执行结果准确地保存到指定文件。 在Go语言中…

    2025年12月16日
    000
  • MySQL INSERT 语句:提升可读性的 SET 语法

    在mysql中,传统的`insert … values`语法在处理大量列时,其值与列名的对应关系不易辨识,导致语句可读性下降。本文将介绍如何利用mysql特有的`insert … set`语法,通过明确的`列名 = 值`对,显著提升`insert`语句的清晰度和维护性,使数据…

    2025年12月16日
    000
  • 创建带约束的自定义类型:Go语言实践指南

    本文介绍了如何在 Go 语言中创建自定义类型,并限制其可接受的值。通过示例代码,展示了两种实现方式:使用结构体和使用类型别名,并讨论了各自的优缺点。帮助开发者构建更健壮、更安全的代码。 Go 语言允许开发者创建自定义类型,以增强代码的可读性和类型安全性。然而,有时我们需要更进一步,限制自定义类型可以…

    2025年12月16日
    000
  • Golang如何实现并发文件写入

    使用互斥锁可确保多goroutine安全写入同一文件,通过sync.Mutex实现原子操作;采用channel结合生产者-消费者模型能提升效率与可扩展性,由单一goroutine集中写入;若无需共用文件,可让每个goroutine写独立文件最后合并,避免竞争。 Go语言通过goroutine和cha…

    2025年12月16日
    000
  • 在Go语言中为macOS创建OpenGL 3.2上下文的指南

    本文旨在解决在macos系统上使用go语言和glfw库创建opengl 3.2上下文时遇到的常见问题。核心在于,除了设置主次版本号和核心配置文件外,还需要明确启用opengl前向兼容性,并确保glfw库的初始化顺序正确无误,才能成功获取到高于opengl 2.1的上下文版本。 理解macOS上的Op…

    2025年12月16日
    000
  • Go语言中获取命令行用户输入的实用指南

    本教程将详细介绍在go语言中如何高效、安全地从命令行获取用户输入。我们将重点探讨使用`bufio`包配合`os.stdin`实现交互式输入的方法,并提供清晰的代码示例,帮助开发者轻松处理用户在终端中键入的数据,包括输入提示、读取整行内容及错误处理机制。 在开发命令行工具或交互式程序时,从用户那里获取…

    2025年12月16日
    000
  • Go 语言中 readUInt16BE 的等效实现与字节序处理

    本文详细介绍了如何在 go 语言中实现 node.js `buffer.readuint16be` 的功能。通过 `encoding/binary` 包,我们可以高效地处理字节序,实现从字节切片中读取和写入无符号16位整数。文章将演示如何使用 `binary.bigendian.uint16` 和 …

    2025年12月16日
    000
  • Go语言中实现JSON字段的单向序列化与反序列化:结构体分离策略

    本文探讨了在go语言中如何实现json字段的单向处理,即允许字段从json反序列化(读取)但阻止其在序列化(写入)时出现。针对 `json:”-“` 标签无法满足此需求的问题,文章提出了一种有效的结构体分离策略。通过定义两个语义不同的结构体——一个用于内部完整数据表示,另一个…

    2025年12月16日
    000
  • 如何识别Go二进制文件编译时使用的Go版本

    本文介绍了一种简单有效的方法,通过命令行工具`strings`和`grep`来检测go二进制文件编译时所使用的go版本。这对于验证编译环境、排查版本兼容性问题或确认特定go版本的使用情况非常有用,尤其是在多go环境共存的工作站上。 识别Go二进制文件编译版本的方法 在开发和部署Go应用程序时,尤其是…

    2025年12月16日
    000
  • Golang如何在虚拟机中搭建开发环境_Golang虚拟环境配置完整方案

    在虚拟机中搭建Go开发环境可隔离依赖并便于测试。使用VirtualBox等工具创建Ubuntu 22.04或CentOS Stream 9虚拟机,配置至少2GB内存、20GB硬盘及桥接/NAT网络;安装后更新系统包,下载Go 1.21.5并解压至/usr/local,配置PATH、GOPATH环境变…

    2025年12月16日
    000
  • 掌握Go语言中命令行参数与用户输入处理技巧

    本文深入探讨go语言中处理命令行参数和用户输入的实用技巧。我们将学习如何检查并解析可选的命令行参数,为程序提供灵活的启动配置;同时,也将解决使用`fmt.scanf`时无法识别空行输入的问题,并介绍更健壮的行输入方法,确保程序能准确响应用户的回车操作,提升交互体验。 在Go语言中开发命令行工具或交互…

    2025年12月16日
    000
  • Go语言中带值约束的自定义类型实现指南

    本教程探讨go语言中如何创建具有特定值约束的自定义类型。针对go语言缺乏操作符重载的特点,文章详细介绍了两种主要实现方法:一是通过带有验证逻辑的构造函数配合结构体类型,确保类型实例在创建时即满足条件;二是通过为基础类型定义验证方法,实现对值的运行时检查。这两种方法各有侧重,旨在帮助开发者根据实际需求…

    2025年12月16日
    000

发表回复

登录后才能评论
关注微信