
提升PyTorch DataLoader效率:避免重复实例化
在PyTorch深度学习训练中,高效的数据加载至关重要。 反复创建DataLoader实例会导致进程池的重复创建和销毁,严重影响训练速度。本文介绍如何复用DataLoader,避免这种低效的重复实例化操作。
问题:许多代码在每次迭代中都重新创建DataLoader:DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)。 这会造成性能瓶颈,因为DataLoader初始化需要创建进程池,频繁地创建和销毁进程池会消耗大量资源。
解决方案:将DataLoader的创建移至训练循环之外。 只需在训练开始前创建一次DataLoader实例,并在训练循环中重复使用它即可。 以下代码演示了改进后的方法:
import torchfrom torch.utils.data import DataLoader, Datasetfrom math import sqrtfrom typing import List, Tuple, Unionfrom numpy import ndarrayfrom PIL import Imagefrom torchvision import transformspreprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )])class PreprocessImageDataset(Dataset): def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]): self.images = images def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] image = Image.fromarray(image) preprocessed_image: torch.Tensor = preprocess(image) unsqueezed_image = preprocessed_image return unsqueezed_imageif __name__=='__main__': data = list(range(10000000)) batch_size = 10 num_workers = 16 dataset = PreprocessImageDataset(data) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) for epoch in range(5): print(f"Epoch {epoch + 1}:") for batch_data in dataloader: batch_data print("Batch data:", batch_data) print("Batch data type :", type(batch_data)) print("Batch data shape:", batch_data.shape)
通过将DataLoader的实例化放在循环外,并在多个epoch中复用同一个实例,我们避免了重复创建进程池,显著提高了数据加载效率,减少了系统开销,从而提升了训练性能。
以上就是PyTorch DataLoader 如何避免重复实例化以提升训练效率?的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/1359304.html
微信扫一扫
支付宝扫一扫