一、DataLoader是什么
DataLoader
是 PyTorch 中的一个核心组件,用于管理和加载数据集,以便在训练和评估模型时以批次进行高效的数据输入。它提供了处理各类数据输入的便利方式,支持数据的随机化和并行载入,从而提高训练效率。
二、DataLoader
的主要功能
批处理数据(Batching)
DataLoader
能够将数据集中的数据按批次加载,为训练过程提供一个稳定的数据流。通过设置 batch_size
参数,可以定义每次载入的数据样本数量,这对于 GPU 加速训练非常重要。随机化数据(Shuffling)
- 可以通过设置
shuffle=True
来随机化数据顺序,有助于打乱数据集以减少训练过程中的过拟合。并行数据加载(Parallel Data Loading)
- 通过
num_workers
参数,可以指定子进程的数量以并行执行数据加载。并行加载减少了训练过程中 CPU 数据管理的瓶颈。数据迭代器(Iterator)
DataLoader
提供数据迭代器接口,使得数据可以像 Python 的迭代器一样被顺序读取和处理。三、如何使用 DataLoader
下面是一个使用 DataLoader
的基本示例,展示了如何载入一个图像数据集:
import torchfrom torch.utils.data import DataLoaderfrom torchvision import datasets, transforms# 定义数据转换:图片转换为张量并归一化transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 加载数据集,以 MNIST 数据集为例train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 初始化 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=4)# 数据迭代示例for batch_idx, (data, targets) in enumerate(train_loader): # 在这里可以对每个批次的数据进行处理 print(data.size(), targets.size())
四、重要参数
dataset:
- 实例化的 Dataset 对象,从中加载数据。
batch_size:
- 每个批次加载的数据样本数量。
shuffle:
- 是否对数据进行随机化。
num_workers:
- 用于数据加载的子进程数量。注意:在 Windows 系统上,设
num_workers
为 0 或者使用 multiprocessing
方法来处理数据并行问题。drop_last:
- 如果数据集样本数量不能被
batch_size
整除,设置为 True 则丢弃最后那个不完整的批数据。五、高级功能
动态采样:
- 可以通过自定义的
Sampler
来动态调整数据采样方式,例如根据类别平衡数据等。数据集分割:
- 可以使用
SubsetRandomSampler
将数据集分割为训练集、验证集等。六、服务器实操
import torchvisionfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("pytorchstu/dataset",train=False,transform=torchvision.transforms.ToTensor())test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)writer = SummaryWriter("dataloader")step = 0for data in test_loader: imgs , targets = data print(imgs.shape) print(targets) # 使用add_images来处理整个批次,或者循环处理每个图像 writer.add_images("test_data", imgs, step) step += 1writer.close()
截图:
ps:shuttle为True代表每个epoch训练图片的顺序是不同的,一个epoch代表遍历完所有图片。