Pytorch数据加载 Data Loading
- PyTorch 中的数据处理概述
- 使用
torch.utils.data.Dataset
和torch.utils.data.DataLoader
- 利用
torchvision
提供的常用数据集 - 数据变换(Transforms)
- 创建自定义数据集
- 数据加载的实用示例
- 常见问题与注意事项
一、PyTorch 中的数据处理概述¶
在机器学习和深度学习中,数据处理是整个流程的关键步骤。PyTorch 提供了灵活而高效的工具来处理和加载数据,使得训练过程更加便捷和高效。主要组件包括:
- Dataset:表示一个数据集,负责数据的读取和预处理。
- DataLoader:用于将 Dataset 封装成可以迭代的批次,支持多线程加载、批次生成、打乱数据等功能。
二、使用 torch.utils.data.Dataset
和 torch.utils.data.DataLoader
¶
1. torch.utils.data.Dataset
¶
Dataset
是一个抽象类,所有自定义的数据集都应继承自该类,并实现以下两个方法:
__len__
:返回数据集中样本的数量。__getitem__
:根据索引返回数据集中的一个样本。
PyTorch 提供了一些现成的 Dataset
类,如 torchvision.datasets.MNIST
,但你也可以根据需要创建自定义的数据集。
示例:使用内置的 MNIST 数据集¶
import torch
from torchvision import datasets, transforms
# 定义数据转换(可选)
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为张量
transforms.Normalize((0.5,), (0.5,)) # 标准化
])
# 下载并加载训练数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
print(f'训练数据集的大小: {len(train_dataset)}')
print(f'第一个样本的形状: {train_dataset[0][0].shape}')
print(f'第一个样本的标签: {train_dataset[0][1]}')
输出示例:¶
训练数据集的大小: 60000
第一个样本的形状: torch.Size([1, 28, 28])
第一个样本的标签: 5
2. torch.utils.data.DataLoader
¶
DataLoader
是一个迭代器,主要功能包括:
- 批量加载数据:将数据集划分为多个批次。
- 打乱数据:每个epoch打乱数据顺序,有助于提升模型性能。
- 并行加载:通过多线程加载数据,加快数据读取速度。
示例:创建 DataLoader¶
from torch.utils.data import DataLoader
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)
# 迭代 DataLoader
for images, labels in train_loader:
print(f'批次图像的形状: {images.shape}')
print(f'批次标签的形状: {labels.shape}')
break # 仅显示第一个批次
输出示例:¶
批次图像的形状: torch.Size([64, 1, 28, 28])
批次标签的形状: torch.Size([64])
3. 结合使用¶
结合 Dataset
和 DataLoader
,你可以便捷地管理和加载数据。
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST 数据集的均值和标准差
])
# 加载训练和测试数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=2)
三、利用 torchvision
提供的常用数据集¶
torchvision
是 PyTorch 的一个子库,专门用于计算机视觉任务。它提供了许多常用的图像数据集和预训练模型。
1. 常用数据集¶
- MNIST:手写数字识别数据集。
- CIFAR-10:10类彩色图像数据集。
- ImageNet:大规模图像分类数据集。
- FashionMNIST:类似 MNIST,但包含服装类别。
- COCO、VOC 等:用于目标检测和分割。
2. 示例:加载 CIFAR-10 数据集¶
from torchvision import datasets, transforms
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), # CIFAR-10 的均值
(0.2023, 0.1994, 0.2010)) # CIFAR-10 的标准差
])
# 加载训练和测试数据
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4)
3. 自定义数据集¶
当内置的数据集不满足需求时,可以创建自定义的数据集,具体方法将在后续章节详细介绍。
四、数据变换(Transforms)¶
数据预处理和增强对于提高模型性能至关重要。torchvision.transforms
提供了丰富的数据变换功能。
常用的变换操作¶
- 基本变换:
transforms.ToTensor()
:将 PIL 图像或 NumPy 数组转换为张量。-
transforms.Normalize(mean, std)
:对张量进行标准化。 -
几何变换:
transforms.Resize(size)
:调整图像大小。transforms.CenterCrop(size)
:中心裁剪。transforms.RandomCrop(size)
:随机裁剪。-
transforms.RandomHorizontalFlip(p)
:随机水平翻转。 -
颜色变换:
-
transforms.ColorJitter(brightness, contrast, saturation, hue)
:随机调整亮度、对比度、饱和度和色调。 -
其他变换:
transforms.RandomRotation(degrees)
:随机旋转图像。transforms.RandomAffine(degrees, translate, scale, shear)
:随机仿射变换。transforms.Grayscale(num_output_channels)
:将图像转换为灰度图。
示例:应用数据增强¶
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
# 定义数据转换,包括数据增强
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
# 加载 CIFAR-10 训练数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
五、创建自定义数据集¶
当内置的数据集无法满足特定需求时,可以创建自定义数据集。自定义数据集需要继承 torch.utils.data.Dataset
并实现 __len__
和 __getitem__
方法。
示例:创建一个自定义数据集¶
假设你有一组图像和对应的标签,存储在特定的文件夹中。
1. 项目结构¶
custom_dataset/
├── images/
│ ├── img1.jpg
│ ├── img2.jpg
│ └── ...
└── labels.csv
labels.csv
内容示例:
filename,label
img1.jpg,0
img2.jpg,1
...
2. 实现自定义 Dataset
¶
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
"""
Args:
annotations_file (string): 注释文件路径(如 CSV 文件)。
img_dir (string): 图像文件夹路径。
transform (callable, optional): 应用于图像的变换。
target_transform (callable, optional): 应用于标签的变换。
"""
self.annotations = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
image = Image.open(img_name).convert('RGB') # 确保图像是 RGB
label = self.annotations.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
# 定义数据转换
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 创建自定义数据集实例
dataset = CustomImageDataset(annotations_file='custom_dataset/labels.csv',
img_dir='custom_dataset/images',
transform=transform)
# 创建 DataLoader
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
# 迭代 DataLoader
for images, labels in data_loader:
print(f'批次图像的形状: {images.shape}')
print(f'批次标签的形状: {labels.shape}')
break
输出示例:¶
批次图像的形状: torch.Size([32, 3, 128, 128])
批次标签的形状: torch.Size([32])
注意事项¶
- 数据索引:确保在
__getitem__
方法中正确索引数据。 - 数据类型:图像应转换为
RGB
或合适的格式,避免加载灰度图引发的错误。 - 错误处理:处理缺失或损坏的图像文件,避免在训练时崩溃。
六、数据加载的实用示例¶
下面我们将结合前面的内容,展示一个完整的数据加载流程,包括使用内置数据集、应用数据变换以及自定义数据集。
示例 1:加载并迭代 MNIST 数据集¶
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=2)
# 迭代训练数据
for epoch in range(1):
for batch_idx, (data, target) in enumerate(train_loader):
print(f'批次 {batch_idx+1}')
print(f'数据形状: {data.shape}')
print(f'标签形状: {target.shape}')
if batch_idx == 0:
break # 仅显示第一个批次
输出示例:¶
批次 1
数据形状: torch.Size([64, 1, 28, 28])
标签形状: torch.Size([64])
示例 2:使用 CIFAR-10 和数据增强¶
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据转换,包括数据增强
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4)
# 迭代训练数据
for epoch in range(1):
for batch_idx, (data, target) in enumerate(train_loader):
print(f'批次 {batch_idx+1}')
print(f'数据形状: {data.shape}')
print(f'标签形状: {target.shape}')
if batch_idx == 0:
break # 仅显示第一个批次
输出示例:¶
批次 1
数据形状: torch.Size([128, 3, 32, 32])
标签形状: torch.Size([128])
示例 3:使用自定义数据集¶
假设你有一个自定义的图像数据集,结构如下:
custom_dataset/
├── images/
│ ├── cat001.jpg
│ ├── cat002.jpg
│ ├── dog001.jpg
│ └── dog002.jpg
└── labels.csv
labels.csv
内容示例:
filename,label
cat001.jpg,0
cat002.jpg,0
dog001.jpg,1
dog002.jpg,1
实现并加载自定义数据集¶
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.annotations = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
image = Image.open(img_name).convert('RGB')
label = self.annotations.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
# 定义数据转换
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 创建自定义数据集实例
dataset = CustomImageDataset(annotations_file='custom_dataset/labels.csv',
img_dir='custom_dataset/images',
transform=transform)
# 创建 DataLoader
data_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, num_workers=1)
# 迭代数据
for images, labels in data_loader:
print(f'数据形状: {images.shape}')
print(f'标签: {labels}')
break # 仅显示第一个批次
输出示例:¶
数据形状: torch.Size([2, 3, 128, 128])
标签: tensor([1, 0])
七、常见问题与注意事项¶
1. num_workers
参数¶
- 含义:
num_workers
指定了用于数据加载的子进程数量。 - 选择策略:
- 小数据集或计算能力有限的系统可以设置较低的值(如 2)。
- 大数据集和多核 CPU 系统可以增大此值(如 4 或更高)。
- 如果遇到
too many open files
错误,可以减少num_workers
。
2. pin_memory
参数¶
对于使用 GPU 的情况,设置 pin_memory=True
可以加快数据传输速度。
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
3. 数据不平衡¶
如果数据集中某些类别的数据较少,可能导致模型偏向于数量多的类别。可通过以下方法缓解:
- 过采样:增加少数类的数据量。
- 欠采样:减少多数类的数据量。
- 使用权重:在损失函数中为不同类别设置不同的权重。
4. 数据增强¶
适当的数据增强可以提升模型的泛化能力,但过度的数据增强可能引入噪声,影响模型性能。
5. 自定义数据集的效率¶
确保自定义数据集的 __getitem__
方法高效,避免在数据加载过程中进行过多的计算或 I/O 操作,以免成为训练的瓶颈。
八、完整的数据加载实战示例¶
为了帮助你更好地理解,下面是一个结合内置数据集、数据变换以及 DataLoader 的完整示例:
目标¶
- 使用 CIFAR-10 数据集。
- 应用数据增强技术。
- 创建训练和测试的 DataLoader。
- 迭代数据并查看每个批次的形状。
实现步骤¶
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义训练数据的变换(包含数据增强)
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
# 定义测试数据的变换(不包含数据增强)
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
# 加载 CIFAR-10 训练和测试数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
# 查看一个训练批次
for batch_idx, (data, targets) in enumerate(train_loader):
print(f'批次 {batch_idx+1}')
print(f'数据形状: {data.shape}') # 应为 [128, 3, 32, 32]
print(f'标签形状: {targets.shape}') # 应为 [128]
break # 仅显示第一个批次
# 输出示例:
# 批次 1
# 数据形状: torch.Size([128, 3, 32, 32])
# 标签形状: torch.Size([128])
九、总结¶
在 PyTorch 中,数据加载和预处理是训练模型的重要部分。通过合理使用 Dataset
和 DataLoader
,结合数据变换和增强技术,可以高效地管理和处理各种类型的数据。此外,灵活地创建自定义数据集,使得 PyTorch 能够适应几乎所有的数据来源和格式。
关键点回顾¶
- Dataset:表示一个数据集,负责数据的读取和预处理。
- DataLoader:将
Dataset
封装成可迭代的批次,支持并行加载、打乱数据等。 - Transforms:用于数据预处理和增强,提升模型的泛化能力。
- 自定义 Dataset:根据需求创建符合特定要求的数据集。