PyTorch Transforms
一、什么是 transforms
?
transforms
是 PyTorch 中用来处理和预处理图像数据的一组工具。它帮助我们在训练机器学习模型之前,对数据进行规范化和增强处理。能将数据从原始格式转换为更适合模型训练的格式。常见的图像数据变换操作包括:
ToTensor:
- 将图像从普通格式(例如像素值在 0-255)转换为 PyTorch 能处理的
Tensor
格式,并归一化到 0 到 1 之间。例如,彩色图像被转换为一个 [3, H, W] 的张量,其中 H 和 W 是图像的高和宽,3 表示 RGB 三通道。Normalize:
- 调整张量图像中的像素值,使其均值为零,标准差为一,更易于模型处理。通常需要提供均值和标准差参数,对每个图像通道分别应用。
RandomCrop:
- 随机裁剪图像到指定大小,帮助模型学习不同视图,防止过拟合。
RandomHorizontalFlip:
- 随机水平翻转图像,提高模型对图像左右变换后的识别能力。
Resize:
- 将图像调整为指定尺寸,确保所有输入图像具有统一大小。
Compose:
- 用于组合多个变换操作,可以连贯地对每张图像进行一系列处理。
举例说明
假设有一个猫和狗的图片数据集,准备训练一个图像识别模型。原始图片大小不一,就可以使用 transforms
来规范化这些图片:
import torchvision.transforms as transforms# 定义一个变换序列transform = transforms.Compose([ transforms.Resize((128, 128)), # 将图像调整到 128x128 像素 transforms.ToTensor(), # 转换成 Tensor 格式并归一化 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 对像素值进行归一化])# 应用于你的数据集载入器
在这个例子里,transform 先调整图片大小,接着转换成 Tensor 格式,最后进行标准化处理。
总结
• transforms 是处理和增强图像数据的强大工具,尤其在模型训练阶段。
• 通过数据增强(如裁剪、翻转),可以提高模型的泛化能力和鲁棒性。
• 组合变换简化数据预处理过程,确保模型接收到统一标准化的数据。
二、简单的使用
import torchvision.transforms as transformsfrom PIL import Image#使用ToTensor()将PIL图像转换为张量#为什么要使用张量类型?img_path = "hymenoptera_data/train/ants/0013035.jpg"img = Image.open(img_path)print(img)tensor_tran = transforms.ToTensor()tentor_img = tensor_tran(img)
截图:
解释 transforms.ToTensor()
的实例及 __call__
方法
tensor_tran是transforms.ToTensor()类的一个实例在定义 tensor_tran = transforms.ToTensor() 后,它具有(继承) call 方法,使其可以像函数一样调用
更详细的解释
类
transforms.ToTensor
是一个类,负责将图像转换为张量。类通常定义了一组属性和方法,包括特殊的 __call__
方法。实例
tensor_tran = transforms.ToTensor()
创建了 transforms.ToTensor
类的一个实例。实例 tensor_tran
继承了类所定义的所有功能,包括 __call__
方法。调用行为
tensor_tran(img)
涉及调用 tensor_tran
的 __call__
方法。这表示 tensor_tran
可以像函数一样使用,用来处理输入数据 img
。Ps:在 Python 中,call 方法是一个特殊的方法,它允许对象像函数一样被调用。在定义了一个类并为这个类实现 call 方法之后,创建的对象实例就可以直接使用括号 () 来调用,就像调用函数一样。这是 Python 中一种让对象更具灵活性和可操作性的技术。
class MyCallable: def __call__(self, x): return x * xobj = MyCallable()result = obj(5) # instance is called like a functionprint(result) # outputs: 25
• 在这个例子中,MyCallable 类定义了一个 call 方法。通过实例化 MyCallable 后,obj(5) 实际上调用的是 call 方法。
读取Tensor图片
from PIL import Imagefrom torch.utils.tensorboard import SummaryWriter img_path = "hymenoptera_data/train/ants/0013035.jpg"# 打开图像img = Image.open(img_path)# 创建一个 SummaryWriter 实例writer = SummaryWriter("logs")# 创建 ToTensor 变换tensor_tran = transforms.ToTensor()# 将图像转换为张量tensor_img = tensor_tran(img)# 将转换后的张量写入 TensorBoardwriter.add_image("Tensor_img", tensor_img) # 关闭写入器writer.close()