Skip to content

Pytorch数据加载 Data Loading

GitHub stars

  1. PyTorch 中的数据处理概述
  2. 使用 torch.utils.data.Datasettorch.utils.data.DataLoader
  3. 利用 torchvision 提供的常用数据集
  4. 数据变换(Transforms)
  5. 创建自定义数据集
  6. 数据加载的实用示例
  7. 常见问题与注意事项

一、PyTorch 中的数据处理概述

在机器学习和深度学习中,数据处理是整个流程的关键步骤。PyTorch 提供了灵活而高效的工具来处理和加载数据,使得训练过程更加便捷和高效。主要组件包括:

  • Dataset:表示一个数据集,负责数据的读取和预处理。
  • DataLoader:用于将 Dataset 封装成可以迭代的批次,支持多线程加载、批次生成、打乱数据等功能。

二、使用 torch.utils.data.Datasettorch.utils.data.DataLoader

1. torch.utils.data.Dataset

Dataset 是一个抽象类,所有自定义的数据集都应继承自该类,并实现以下两个方法:

  • __len__:返回数据集中样本的数量。
  • __getitem__:根据索引返回数据集中的一个样本。

PyTorch 提供了一些现成的 Dataset 类,如 torchvision.datasets.MNIST,但你也可以根据需要创建自定义的数据集。

示例:使用内置的 MNIST 数据集

import torch
from torchvision import datasets, transforms

# 定义数据转换(可选)
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化
])

# 下载并加载训练数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

print(f'训练数据集的大小: {len(train_dataset)}')
print(f'第一个样本的形状: {train_dataset[0][0].shape}')
print(f'第一个样本的标签: {train_dataset[0][1]}')

输出示例:

训练数据集的大小: 60000
第一个样本的形状: torch.Size([1, 28, 28])
第一个样本的标签: 5

2. torch.utils.data.DataLoader

DataLoader 是一个迭代器,主要功能包括:

  • 批量加载数据:将数据集划分为多个批次。
  • 打乱数据:每个epoch打乱数据顺序,有助于提升模型性能。
  • 并行加载:通过多线程加载数据,加快数据读取速度。

示例:创建 DataLoader

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)

# 迭代 DataLoader
for images, labels in train_loader:
    print(f'批次图像的形状: {images.shape}')
    print(f'批次标签的形状: {labels.shape}')
    break  # 仅显示第一个批次

输出示例:

批次图像的形状: torch.Size([64, 1, 28, 28])
批次标签的形状: torch.Size([64])

3. 结合使用

结合 DatasetDataLoader,你可以便捷地管理和加载数据。

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST 数据集的均值和标准差
])

# 加载训练和测试数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=2)

三、利用 torchvision 提供的常用数据集

torchvision 是 PyTorch 的一个子库,专门用于计算机视觉任务。它提供了许多常用的图像数据集和预训练模型。

1. 常用数据集

  • MNIST:手写数字识别数据集。
  • CIFAR-10:10类彩色图像数据集。
  • ImageNet:大规模图像分类数据集。
  • FashionMNIST:类似 MNIST,但包含服装类别。
  • COCOVOC 等:用于目标检测和分割。

2. 示例:加载 CIFAR-10 数据集

from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),  # CIFAR-10 的均值
                         (0.2023, 0.1994, 0.2010))  # CIFAR-10 的标准差
])

# 加载训练和测试数据
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4)

3. 自定义数据集

当内置的数据集不满足需求时,可以创建自定义的数据集,具体方法将在后续章节详细介绍。


四、数据变换(Transforms)

数据预处理和增强对于提高模型性能至关重要。torchvision.transforms 提供了丰富的数据变换功能。

常用的变换操作

  • 基本变换
  • transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为张量。
  • transforms.Normalize(mean, std):对张量进行标准化。

  • 几何变换

  • transforms.Resize(size):调整图像大小。
  • transforms.CenterCrop(size):中心裁剪。
  • transforms.RandomCrop(size):随机裁剪。
  • transforms.RandomHorizontalFlip(p):随机水平翻转。

  • 颜色变换

  • transforms.ColorJitter(brightness, contrast, saturation, hue):随机调整亮度、对比度、饱和度和色调。

  • 其他变换

  • transforms.RandomRotation(degrees):随机旋转图像。
  • transforms.RandomAffine(degrees, translate, scale, shear):随机仿射变换。
  • transforms.Grayscale(num_output_channels):将图像转换为灰度图。

示例:应用数据增强

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# 定义数据转换,包括数据增强
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

# 加载 CIFAR-10 训练数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)

五、创建自定义数据集

当内置的数据集无法满足特定需求时,可以创建自定义数据集。自定义数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

示例:创建一个自定义数据集

假设你有一组图像和对应的标签,存储在特定的文件夹中。

1. 项目结构

custom_dataset/
├── images/
│   ├── img1.jpg
│   ├── img2.jpg
│   └── ...
└── labels.csv

labels.csv 内容示例:

filename,label
img1.jpg,0
img2.jpg,1
...

2. 实现自定义 Dataset

import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """
        Args:
            annotations_file (string): 注释文件路径(如 CSV 文件)。
            img_dir (string): 图像文件夹路径。
            transform (callable, optional): 应用于图像的变换。
            target_transform (callable, optional): 应用于标签的变换。
        """
        self.annotations = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')  # 确保图像是 RGB
        label = self.annotations.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 创建自定义数据集实例
dataset = CustomImageDataset(annotations_file='custom_dataset/labels.csv',
                             img_dir='custom_dataset/images',
                             transform=transform)

# 创建 DataLoader
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

# 迭代 DataLoader
for images, labels in data_loader:
    print(f'批次图像的形状: {images.shape}')
    print(f'批次标签的形状: {labels.shape}')
    break

输出示例:

批次图像的形状: torch.Size([32, 3, 128, 128])
批次标签的形状: torch.Size([32])

注意事项

  • 数据索引:确保在 __getitem__ 方法中正确索引数据。
  • 数据类型:图像应转换为 RGB 或合适的格式,避免加载灰度图引发的错误。
  • 错误处理:处理缺失或损坏的图像文件,避免在训练时崩溃。

六、数据加载的实用示例

下面我们将结合前面的内容,展示一个完整的数据加载流程,包括使用内置数据集、应用数据变换以及自定义数据集。

示例 1:加载并迭代 MNIST 数据集

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=2)

# 迭代训练数据
for epoch in range(1):
    for batch_idx, (data, target) in enumerate(train_loader):
        print(f'批次 {batch_idx+1}')
        print(f'数据形状: {data.shape}')
        print(f'标签形状: {target.shape}')
        if batch_idx == 0:
            break  # 仅显示第一个批次

输出示例:

批次 1
数据形状: torch.Size([64, 1, 28, 28])
标签形状: torch.Size([64])

示例 2:使用 CIFAR-10 和数据增强

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据转换,包括数据增强
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4)

# 迭代训练数据
for epoch in range(1):
    for batch_idx, (data, target) in enumerate(train_loader):
        print(f'批次 {batch_idx+1}')
        print(f'数据形状: {data.shape}')
        print(f'标签形状: {target.shape}')
        if batch_idx == 0:
            break  # 仅显示第一个批次

输出示例:

批次 1
数据形状: torch.Size([128, 3, 32, 32])
标签形状: torch.Size([128])

示例 3:使用自定义数据集

假设你有一个自定义的图像数据集,结构如下:

custom_dataset/
├── images/
│   ├── cat001.jpg
│   ├── cat002.jpg
│   ├── dog001.jpg
│   └── dog002.jpg
└── labels.csv

labels.csv 内容示例:

filename,label
cat001.jpg,0
cat002.jpg,0
dog001.jpg,1
dog002.jpg,1

实现并加载自定义数据集

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.annotations = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.annotations.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 创建自定义数据集实例
dataset = CustomImageDataset(annotations_file='custom_dataset/labels.csv',
                             img_dir='custom_dataset/images',
                             transform=transform)

# 创建 DataLoader
data_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, num_workers=1)

# 迭代数据
for images, labels in data_loader:
    print(f'数据形状: {images.shape}')
    print(f'标签: {labels}')
    break  # 仅显示第一个批次

输出示例:

数据形状: torch.Size([2, 3, 128, 128])
标签: tensor([1, 0])

七、常见问题与注意事项

1. num_workers 参数

  • 含义num_workers 指定了用于数据加载的子进程数量。
  • 选择策略
  • 小数据集或计算能力有限的系统可以设置较低的值(如 2)。
  • 大数据集和多核 CPU 系统可以增大此值(如 4 或更高)。
  • 如果遇到 too many open files 错误,可以减少 num_workers

2. pin_memory 参数

对于使用 GPU 的情况,设置 pin_memory=True 可以加快数据传输速度。

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)

3. 数据不平衡

如果数据集中某些类别的数据较少,可能导致模型偏向于数量多的类别。可通过以下方法缓解:

  • 过采样:增加少数类的数据量。
  • 欠采样:减少多数类的数据量。
  • 使用权重:在损失函数中为不同类别设置不同的权重。

4. 数据增强

适当的数据增强可以提升模型的泛化能力,但过度的数据增强可能引入噪声,影响模型性能。

5. 自定义数据集的效率

确保自定义数据集的 __getitem__ 方法高效,避免在数据加载过程中进行过多的计算或 I/O 操作,以免成为训练的瓶颈。


八、完整的数据加载实战示例

为了帮助你更好地理解,下面是一个结合内置数据集、数据变换以及 DataLoader 的完整示例:

目标

  • 使用 CIFAR-10 数据集。
  • 应用数据增强技术。
  • 创建训练和测试的 DataLoader。
  • 迭代数据并查看每个批次的形状。

实现步骤

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义训练数据的变换(包含数据增强)
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

# 定义测试数据的变换(不包含数据增强)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

# 加载 CIFAR-10 训练和测试数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

# 查看一个训练批次
for batch_idx, (data, targets) in enumerate(train_loader):
    print(f'批次 {batch_idx+1}')
    print(f'数据形状: {data.shape}')  # 应为 [128, 3, 32, 32]
    print(f'标签形状: {targets.shape}')  # 应为 [128]
    break  # 仅显示第一个批次

# 输出示例:
# 批次 1
# 数据形状: torch.Size([128, 3, 32, 32])
# 标签形状: torch.Size([128])

九、总结

在 PyTorch 中,数据加载和预处理是训练模型的重要部分。通过合理使用 DatasetDataLoader,结合数据变换和增强技术,可以高效地管理和处理各种类型的数据。此外,灵活地创建自定义数据集,使得 PyTorch 能够适应几乎所有的数据来源和格式。

关键点回顾

  • Dataset:表示一个数据集,负责数据的读取和预处理。
  • DataLoader:将 Dataset 封装成可迭代的批次,支持并行加载、打乱数据等。
  • Transforms:用于数据预处理和增强,提升模型的泛化能力。
  • 自定义 Dataset:根据需求创建符合特定要求的数据集。