【论文复现】基于 PaddlePaddle 实现 GreedyHash

本文基于PaddlePaddle复现GreedyHash算法,解决图像检索中NP优化难题。在CIFAR-10 (I)数据集上,12/24/32/48bits模型精度达0.798、0.809、0.817、0.819(最高0.824),优于原论文及PyTorch重跑结果,含完整代码与权重。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【论文复现】基于 paddlepaddle 实现 greedyhash - 创想鸟

【论文复现-图像分类检索】基于 PaddlePaddle 实现 GreedyHash(NeurIPS2018)

原论文:Greedy Hash: Towards Fast Optimization for Accurate Hash Coding in CNN.

官方原版代码(基于PyTorch)GreedyHash.

第三方参考代码(基于PyTorch)DeepHash-pytorch.

本项目GitHub repo paddle_greedyhash

1. 简介

GreedyHash 意在解决图像检索 Deep Hashing 领域中NP优化难的问题,为此,作者在每次迭代中向可能的最优离散解迭代式更新网络参数。具体来说,GreedyHash 在网络模型中加入了一个哈希编码层,在前向传播过程中为了保持离散的限制条件,严格利用sign函数。在反向传播过程中,梯度完整地传向前一层,进而可以避免梯度弥散现象。算法伪代码如下。

【论文复现】基于 PaddlePaddle 实现 GreedyHash - 创想鸟

GreedyHash 算法伪代码

2. 数据集和复现精度

数据集:cifar-1(即CIFAR-10 (I))

CIFAR-10 数据集共10类,由 60,000 个 32×32 的彩色图像组成。

CIFAR-10 (I)中,选择 1000 张图像(每类 100 张图像)作为查询集,其余 59,000 张图像作为数据库, 而从数据库中随机采样 5,000 张图像(每类 500 张图像)作为训练集。数据集处理代码详见 utils/datasets.py。

复现精度

Framework 12bits 24bits 32bits 48bits

论文结果PyTorch0.7740.7950.8100.822重跑结果PyTorch0.7890.7990.8130.824复现结果PaddlePaddle0.7980.8090.8170.819(0.824)

需要注意的是,此处在重跑PyTorch版本代码时发现原论文代码 GreedyHash/cifar1.py 由于PyTorch版本较老,CIFAR-10 数据集处理部分代码无法运行,遂将第三方参考代码 DeepHash-pytorch 中的 CIFAR-10 数据集处理部分代码照搬运行,得以重跑PyTorch版本代码,结果罗列如上。严谨起见,已将修改后的PyTorch版本代码及训练日志放在 pytorch_greedyhash/main.py 和 pytorch_greedyhash/logs 中。因为跑的时候忘记设置随机数种子了,复现的时候可能结果有所偏差,不过应该都在可允许范围内,问题不大。

本项目(基于 PaddlePaddle )依次跑 12/24/32/48 bits 的结果罗列在上表中,且已将训练得到的模型参数与训练日志 log 存放于output文件夹下。由于训练时设置了随机数种子,理论上是可复现的。但在反复重跑几次发现结果还是会有波动,比如有1次 48bits 的模型跑到了 0.824,我把对应的 log 和权重放在 output/bit48_alone 路径下了,说明算法的随机性仍然存在。

3. 准备环境

本人环境配置:

Python: 3.7.11

PaddlePaddle: 2.2.2

硬件:NVIDIA 2080Ti * 1

飞桨PaddlePaddle 飞桨PaddlePaddle

飞桨PaddlePaddle开发者社区与布道,与社区共同进步

飞桨PaddlePaddle 12 查看详情 飞桨PaddlePaddle

p.s. 因为数据集很小,所以放单卡机器上跑了,多卡的代码可能后续补上

4. 快速开始

step1: 下载本项目及训练权重

本项目在AI Studio上,您可以选择fork下来直接运行。首先,cd到paddle_greedyhash项目文件夹下:

In [ ]

cd paddle_greedyhash
/home/aistudio/paddle_greedyhash

或者,您也可以从GitHub上git本repo在本地运行:

git clone https://github.com/hatimwen/paddle_greedyhash.gitcd paddle_greedyhash

权重部分:

由于权重比较多,加起来有 1 个 GB ,因此我放到百度网盘里了,烦请下载后按照 5. 项目结构 排列各个权重文件。或者您也可以按照下载某个bit位数的权重以测试相应性能。

下载链接:BaiduNetdisk, 提取码: tl1i 。

注意:在AI Studio上,已上传了 bit_48.pdparams 权重文件在 output 路径下,方便体验。

step2: 修改参数

请根据实际情况,修改main.py中的 arguments 配置内容(如:batch_size等)。

step3: 验证模型

需要提前下载并排列好 BaiduNetdisk 中的各个预训练模型。

注意:在AI Studio上,由于已预先上传bit_48.pdparams 权重文件,因此可以直接运行:

In [ ]

# 验证模型! python eval.py --batch-size 32 --bit 48
W0427 21:33:47.931723   449 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0427 21:33:47.935976   449 device_context.cc:465] device: 0, cuDNN Version: 7.6.Loading AlexNet state from path: /home/aistudio/paddle_greedyhash/models/AlexNet_pretrained.pdparams0427 09:33:53 PM Namespace(batch_size=32, bit=48, crop_size=224, dataset='cifar10-1', log_path='logs/', model='GreedyHash', n_class=10, pretrained=None, seed=2000, topK=-1)0427 09:33:53 PM ----- Pretrained: Load model state from output/bit_48.pdparams--- Calculating Acc : 100%|█████████████████████| 32/32 [00:02<00:00, 13.36it/s]--- Compressing(train) : 100%|██████████████| 1844/1844 [01:42<00:00, 17.97it/s]--- Compressing(test) : 100%|███████████████████| 32/32 [00:02<00:00, 13.89it/s]--- Calculating mAP : 100%|█████████████████| 1000/1000 [01:23<00:00, 11.94it/s]0427 09:37:06 PM EVAL-GreedyHash, bit:48, dataset:cifar10-1, MAP:0.819

step4: 训练模型

例如要训练 12bits 的模型,可以运行:In [4]

# 训练模型! python train.py --batch-size 32 --learning_rate 1e-3 --seed 2000 --bit 12# 这里记录是看运行没问题就中断了。
W0427 21:38:07.032394   780 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0427 21:38:07.036984   780 device_context.cc:465] device: 0, cuDNN Version: 7.6.Loading AlexNet state from path: /home/aistudio/paddle_greedyhash/models/AlexNet_pretrained.pdparams0427 09:38:12 PM Namespace(alpha=0.1, batch_size=32, bit=12, crop_size=224, dataset='cifar10-1', epoch=50, epoch_lr_decrease=30, eval_epoch=2, learning_rate=0.001, log_path='logs/', model='GreedyHash', momentum=0.9, n_class=10, num_train=5000, optimizer='SGD', output_dir='checkpoints/', seed=2000, topK=-1, weight_decay=0.0005)0427 09:38:22 PM GreedyHash[ 1/50][21:38:22] bit:12, lr:0.001000000, dataset:cifar10-1, train loss:1.9040427 09:38:31 PM GreedyHash[ 2/50][21:38:31] bit:12, lr:0.001000000, dataset:cifar10-1, train loss:1.574--- Calculating Acc : 100%|█████████████████████| 32/32 [00:02<00:00, 13.48it/s]--- Compressing(train) : 100%|██████████████| 1844/1844 [01:46<00:00, 17.28it/s]--- Compressing(test) : 100%|███████████████████| 32/32 [00:02<00:00, 13.81it/s]--- Calculating mAP : 100%|█████████████████| 1000/1000 [01:14<00:00, 13.39it/s]0427 09:41:39 PM save in checkpoints/model_best_120427 09:41:40 PM GreedyHash epoch:2, bit:12, dataset:cifar10-1, MAP:0.614, Best MAP: 0.614, Acc: 77.0000427 09:41:51 PM GreedyHash[ 3/50][21:41:51] bit:12, lr:0.001000000, dataset:cifar10-1, train loss:1.3160427 09:42:00 PM GreedyHash[ 4/50][21:42:00] bit:12, lr:0.001000000, dataset:cifar10-1, train loss:1.120--- Calculating Acc : 100%|█████████████████████| 32/32 [00:02<00:00, 13.93it/s]--- Compressing(train) :  46%|██████▊        | 841/1844 [00:49<00:58, 17.28it/s]^CTraceback (most recent call last):  File "train.py", line 183, in     main()  File "train.py", line 180, in main    database_loader)  File "train.py", line 136, in train_val    mAP, acc = val(model, test_loader, database_loader)  File "train.py", line 81, in val    retrievalB, retrievalL, queryB, queryL = compress(database_loader, test_loader, model)  File "/home/aistudio/paddle_greedyhash/utils/tools.py", line 31, in compress    _,_, code = model(data)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__    return self._dygraph_call_func(*inputs, **kwargs)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func    outputs = self.forward(*inputs, **kwargs)  File "/home/aistudio/paddle_greedyhash/models/greedyhash.py", line 67, in forward    x = self.features(x)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__    return self._dygraph_call_func(*inputs, **kwargs)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func    outputs = self.forward(*inputs, **kwargs)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/container.py", line 98, in forward    input = layer(input)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 917, in __call__    return self._dygraph_call_func(*inputs, **kwargs)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 907, in _dygraph_call_func    outputs = self.forward(*inputs, **kwargs)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/conv.py", line 677, in forward    use_cudnn=self._use_cudnn)  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/functional/conv.py", line 123, in _conv_nd    pre_bias = getattr(_C_ops, op_type)(x, weight, *attrs)KeyboardInterrupt--- Compressing(train) :  46%|██████▊        | 841/1844 [00:49<00:58, 17.00it/s]

step5: 验证预测

【论文复现】基于 PaddlePaddle 实现 GreedyHash - 创想鸟

验证图片(类别:飞机 airplane, id: 0)

对于上面的图片,直接运行 predict.py 即可,这里拿 bit_48.pdparams 预测一下看看:In [5]

! python predict.py --bit 48 --pic_id 1949
W0427 21:43:31.814743  1416 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1W0427 21:43:31.819936  1416 device_context.cc:465] device: 0, cuDNN Version: 7.6.Loading AlexNet state from path: /home/aistudio/paddle_greedyhash/models/AlexNet_pretrained.pdparams----- Pretrained: Load model state from output/bit_48.pdparams----- Predicted Class_ID: 0, Prob: 0.9965014457702637, Real Label_ID: 0----- Predicted Class_NAME: 飞机 airplane, Real Class_NAME: 飞机 airplane

显然,预测结果正确。

七、代码结构与详细说明

|-- paddle_greedyhash    |-- output              # 日志及模型文件        |-- bit48_alone         # 偶然把bit48跑到了0.824,日志和权重存于此            |-- bit_48.pdparams     # bit48_alone的模型权重            |-- log_48.txt          # bit48_alone的训练日志        |-- bit_12.pdparams     # 12bits的模型权重        |-- bit_24.pdparams     # 24bits的模型权重        |-- bit_32.pdparams     # 32bits的模型权重        |-- bit_48.pdparams     # 48bits的模型权重        |-- log_eval.txt        # 用训练好的模型测试日志(包含bit48_alone)        |-- log_train.txt       # 依次训练 12/24/32/48 bits(不包含bit48_alone)    |-- models        |-- __init__.py        |-- alexnet.py      # AlexNet 定义,注意这里有略微有别于 paddle 集成的 AlexNet        |-- greedyhash.py   # GreedyHash 算法定义    |-- utils        |-- datasets.py         # dataset, dataloader, transforms        |-- lr_scheduler.py     # 学习率策略定义        |-- tools.py            # mAP, acc计算;随机数种子固定函数    |-- eval.py             # 单卡测试代码    |-- predict.py          # 预测演示代码    |-- train.py            # 单卡训练代码    |-- README.md    |-- pytorch_greedyhash        |-- datasets.py         # PyTorch 定义dataset, dataloader, transforms        |-- cal_map.py          # PyTorch mAP计算;        |-- main.py             # PyTorch 单卡训练代码        |-- output              # PyTorch 重跑日志

八、模型信息

关于模型的其他信息,可以参考下表:

信息 说明

发布者文洪涛Emailhatimwen@163.com时间2022.04框架版本Paddle 2.2.2应用场景图像检索支持硬件GPU、CPU下载链接预训练模型 提取码: tl1i在线运行AI StudioLicenseApache 2.0 license

九、参考及引用

@article{su2018greedy,  title={Greedy hash: Towards fast optimization for accurate hash coding in cnn},  author={Su, Shupeng and Zhang, Chao and Han, Kai and Tian, Yonghong},  year={2018},  journal={Advances in Neural Information Processing Systems},  volume={31},  year={2018}}

以上就是【论文复现】基于 PaddlePaddle 实现 GreedyHash的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/318541.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月5日 08:54:25
下一篇 2025年11月5日 08:56:32

相关推荐

  • C++机器人开发 ROS框架环境配置

    答案:配置ROS环境需选择匹配的ROS与Ubuntu版本,添加软件源和密钥,安装ros-desktop-full,初始化rosdep并配置环境变量,创建catkin工作空间,最后通过roscore测试;常见问题包括依赖、网络、环境变量和权限问题,可通过rosdep命令、网络代理、检查$ROS_PAC…

    好文分享 2025年12月18日
    000
  • C++函数返回指针 局部变量地址问题

    返回局部变量指针会导致未定义行为,因局部变量在函数结束时被销毁,指针指向已释放内存;正确做法包括返回堆内存指针(需手动释放)、静态变量地址或传入的有效指针,现代C++推荐使用智能指针或值返回避免内存问题。 在C++中,函数返回指针时,如果返回的是局部变量的地址,会引发严重的运行时错误或未定义行为。这…

    2025年12月18日
    000
  • C++金融回测环境 历史数据高速读取优化

    最优解是采用自定义二进制格式结合内存映射文件(mmap)和连续内存数据结构。首先,将历史数据以固定大小结构体(如包含时间戳、OHLCV的BarData)存储为二进制文件,避免文本解析开销;其次,使用mmap实现文件到虚拟地址空间的映射,利用操作系统预读和页缓存提升I/O效率;最后,在内存中通过std…

    2025年12月18日
    000
  • C++结构化绑定进阶 多返回值处理

    结构化绑定通过auto [var1, var2, …] = func();语法,直接解包pair、tuple或聚合类型,使多返回值处理更清晰;它提升代码可读性,简化错误处理与自定义类型协同,支持从标准库到私有封装类的灵活应用,显著优化函数调用表达力与维护性。 C++的结构化绑定(Stru…

    2025年12月18日
    000
  • C++计算机视觉 OpenCV库编译安装

    答案:编译安装OpenCV需先搭建环境,安装C++编译器、CMake及依赖库,Ubuntu下用apt-get安装必要组件;接着配置CMake生成Makefile,指定编译类型和安装路径;然后通过make -j4编译,sudo make install安装;之后配置环境变量,更新ldconfig并添加…

    2025年12月18日
    000
  • Linux Ubuntu系统下安装C++ build-essential工具包的命令是什么

    安装C++开发环境需先更新包列表并安装build-essential,该工具包包含gcc、g++、make等核心组件,用于编译和链接C++程序。通过编译Hello World程序可验证环境是否正常。若遇问题可更换软件源、修复依赖或重装;需特定GCC版本时可用apt安装指定版本并用update-alt…

    2025年12月18日
    000
  • C++程序的内存是如何分区的 比如栈、堆、全局区

    C++程序内存分为栈、堆、全局/静态区和代码区。栈用于函数调用和局部变量,由编译器自动管理,速度快但容量有限,过深递归或大局部数组易导致栈溢出。堆用于动态内存分配,通过new和delete手动管理,灵活性高但管理不当易引发内存泄漏或悬挂指针。全局/静态存储区存放全局变量和静态变量,程序启动时分配,结…

    2025年12月18日
    000
  • C++中如何使用指针实现多态和虚函数调用

    多态通过基类指针调用虚函数实现,需将基类函数声明为virtual,派生类重写该函数,运行时根据实际对象类型动态调用对应函数,实现多态;若使用纯虚函数则形成抽象基类,强制派生类实现该函数,且基类不可实例化;注意虚函数须通过指针或引用调用,析构函数应为虚以避免内存泄漏,且虚函数有轻微性能开销。 在C++…

    2025年12月18日
    000
  • C++异常安全拷贝 拷贝构造异常处理

    拷贝构造函数应提供强异常安全保证,确保操作全成功或全回滚;2. 使用“拷贝再交换”技术,将可能抛出的操作置于局部对象,成功后通过无抛出swap提交;3. 优先采用RAII容器如std::string,其默认拷贝构造已具强保证,减少资源管理风险。 在C++中,实现异常安全的拷贝构造函数是编写强异常安全…

    2025年12月18日
    000
  • C++ multiset容器 允许重复元素集合

    C++ multiset与set的核心区别在于multiset允许重复元素而set不允许,multiset适用于需自动排序且容纳重复值的场景,如统计频次或维护有序序列。 C++ std::multiset 容器是一个有序集合,它允许你存储重复的元素。它本质上是一个关联容器,所有元素都会根据其值自动排…

    2025年12月18日
    000
  • C++云开发 Docker容器环境配置

    配置C++云开发Docker容器需选择轻量基础镜像如Alpine或Ubuntu,安装g++、make等构建工具及云服务SDK(如AWS SDK for C++),通过多阶段构建优化镜像大小,使用.dockerignore减少冗余文件,合并RUN命令并清理缓存;为保障云服务凭证安全,应避免硬编码,推荐…

    2025年12月18日
    000
  • C++中new一个数组为什么要用delete[]来释放

    C++中new和new[]的核心区别在于:new用于单个对象的分配与构造,delete用于其释放;new[]用于对象数组的分配,会调用多个构造函数并存储元素数量,必须用delete[]释放以正确调用每个对象的析构函数并释放内存。若用delete释放new[]分配的数组,将导致未定义行为,可能引发内存…

    2025年12月18日
    000
  • C++关联容器性能 map和unordered_map对比

    map基于红黑树实现,元素有序,查找、插入、删除时间复杂度稳定为O(log n);unordered_map基于哈希表,元素无序,平均操作时间复杂度O(1),但最坏可达O(n)。unordered_map通常更快但内存开销大且性能受哈希影响,map更稳定且支持有序遍历,选择应根据是否需要顺序访问、性…

    2025年12月18日
    000
  • C++机器学习入门 线性回归实现示例

    首先实现线性回归模型,通过梯度下降最小化均方误差,代码包含数据准备、训练和预测,最终参数接近真实关系,适用于高性能场景。 想用C++实现线性回归,其实并不复杂。虽然Python在机器学习领域更常见,但C++凭借其高性能,在对效率要求高的场景中非常适用。下面是一个简单的线性回归实现示例,帮助你入门C+…

    2025年12月18日
    000
  • 如何在Docker容器中构建一个隔离的C++开发环境

    使用Docker构建C++开发环境可实现隔离、标准化和团队协作一致性。1. 选择基础镜像如ubuntu:latest并安装g++、cmake等工具链;2. 设置WORKDIR /app并复制源码;3. 构建项目并定义CMD运行可执行文件;4. 通过docker build和run创建容器;5. 利用…

    2025年12月18日
    000
  • 如何安全地使用C++指针来避免数组越界访问

    使用指针时应明确数组边界并检查索引,优先采用std::vector或std::array等标准库容器,利用其边界检查和大小管理特性避免越界访问,确保内存安全。 使用C++指针时,数组越界访问是常见且危险的问题,可能导致程序崩溃、数据损坏甚至安全漏洞。要安全地使用指针并避免越界,关键在于明确边界控制、…

    2025年12月18日
    000
  • C++内存管理原则 资源获取即初始化

    RAII通过对象生命周期管理资源,确保构造时获取、析构时释放,结合智能指针与自定义类,实现内存安全与异常安全,避免资源泄漏。 在C++中,内存管理是程序稳定性和性能的关键。一个核心原则是“资源获取即初始化”(Resource Acquisition Is Initialization,简称RAII)…

    2025年12月18日
    000
  • 如果C++程序忘记delete new出来的内存会发生什么

    内存泄漏指程序未释放不再使用的内存,导致内存占用持续增长,最终引发性能下降或崩溃。C++不自动回收内存是为了避免垃圾回收机制带来的性能开销,赋予程序员更高控制权。解决内存泄漏的核心是遵循RAII原则,优先使用智能指针(如std::unique_ptr、std::shared_ptr)管理资源,结合现…

    2025年12月18日
    000
  • C++自定义分配器 重载new运算符实例

    通过重载new和delete可实现自定义内存管理,如内存池。示例中MyClass重载类内new和delete,使用静态内存池分配对象,优先复用已释放空间,提升小对象频繁创建销毁时的性能,并通过静态数组管理内存使用状态。 在C++中,通过重载 new 和 delete 运算符,可以实现自定义内存管理策…

    2025年12月18日
    000
  • 不使用IDE如何用命令行编译和运行一个C++程序

    答案是使用命令行编译和运行C++程序需调用编译器(如g++)将源码编译为可执行文件并运行,例如g++ hello.cpp -o hello生成可执行文件,./hello运行程序;对于多文件项目需包含所有.cpp文件,使用-I指定头文件路径,-L和-l链接库;通过Makefile或CMake自动化管理…

    2025年12月18日
    000

发表回复

登录后才能评论
关注微信