掘金 人工智能 前天 14:18
使用 PyTorch 学习生成式人工智能——CycleGAN:将金发转换为黑发
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文介绍了CycleGAN,一种无需成对样本即可学习图像在不同领域间转换的生成对抗网络。CycleGAN通过循环一致性损失函数,确保转换后的图像能够还原回原始图像,从而保留关键特征。文章详细讲解了CycleGAN的架构,包括两个生成器和两个判别器,并以黑发和金发图像转换为例,说明了CycleGAN的工作原理和训练过程。此外,还介绍了如何使用名人脸部数据集训练CycleGAN模型,实现黑发与金发之间的相互转换。

🎨CycleGAN的核心理念在于无需成对样本即可学习领域之间的映射,实现图像在不同领域之间的转换,例如将马变成斑马,黑发变金发等,为图像处理提供了新的可能性。

🔄循环一致性损失是CycleGAN的关键创新,它确保了图像在经过两次转换后能够尽可能地恢复到原始状态,从而保证了图像内容和结构的一致性,避免了信息丢失。

👩‍🦰CycleGAN由两个生成器和两个判别器组成,生成器负责实现图像在两个领域间的转换,判别器则判断其所属领域图像的真实性,通过对抗训练不断提高生成器的生成能力和判别器的判别能力。

🧑‍💻文章提供了一个使用CycleGAN进行黑发和金发图像相互转换的实例,详细介绍了数据集的准备、模型的构建和训练过程,读者可以通过实践加深对CycleGAN的理解和应用。

本章内容包括:

在前面三章中,我们讨论的生成对抗网络(GAN)模型,都是试图生成与训练集中的图像难以区分的图像。

你可能会想:我们能否实现图像在不同领域之间的转换?比如将马变成斑马,黑发变金发或金发变黑发,图像中戴眼镜与否的转换,将照片转换成绘画风格,或者将冬季场景转换成夏季场景?事实证明,这些都是可以实现的,本章将通过 CycleGAN 教你如何做到这一点!

CycleGAN 于2017年提出,其关键创新在于无需成对样本就能学习领域之间的映射。CycleGAN 有许多有趣且实用的应用,例如模拟人脸的老化或返老还童过程以辅助数字身份验证,或是在设计服装时无需实际制作每种颜色或图案的样衣,就能通过视觉化展示不同款式,从而优化设计流程。

CycleGAN 使用循环一致性损失函数,确保能够从转换后的图像还原回原始图像,从而鼓励保留关键特征。循环一致性损失的理念非常巧妙,值得特别强调。本章中的 CycleGAN 包含两个生成器:分别称为“黑发生成器”和“金发生成器”。黑发生成器以一张金发图像作为输入(不同于之前看到的随机噪声向量),将其转换成黑发图像;而金发生成器以黑发图像为输入,转换成金发图像。

训练过程中,我们先将一张黑发真实图像输入金发生成器,生成一张假金发图像。然后,将该假金发图像输入黑发生成器,转换回黑发图像。如果两个生成器效果良好,原始黑发图像与经过一轮转换后生成的假黑发图像几乎无差异。训练 CycleGAN 时,我们调整模型参数以最小化对抗损失和循环一致性损失之和。与第3章和第4章类似,对抗损失用来衡量生成器欺骗判别器的能力以及判别器区分真假样本的能力。循环一致性损失是 CycleGAN 的独特概念,衡量原始图像与一轮转换后生成图像之间的差异。将循环一致性损失纳入总损失函数,是 CycleGAN 的关键创新。

本章示例中,我们将以黑发和金发图像作为两个领域来训练 CycleGAN,但该模型适用于任意两个图像领域。为了加深理解,我会请你用第5章中“带眼镜”和“不带眼镜”的人脸图像训练同一个 CycleGAN 模型。该解决方案已上传至本书的 GitHub 仓库(github.com/markhliu/DG…),你将看到训练好的模型确实能实现为人脸图像添加或去除眼镜的功能.

6.1 CycleGAN 与循环一致性损失

CycleGAN 扩展了基本的 GAN 架构,包含两个生成器和两个判别器。每对生成器和判别器负责学习两个不同领域之间的映射。它旨在实现图像从一个领域到另一个领域的转换(例如马与斑马、夏季与冬季场景等),同时保留原始图像的关键特征。CycleGAN 使用循环一致性损失,确保能够从转换后的图像重建出原始图像,从而鼓励关键特征的保留。

本节将首先介绍 CycleGAN 的架构,重点强调 CycleGAN 的关键创新——循环一致性损失。

6.1.1 什么是 CycleGAN?

CycleGAN 由两个生成器和两个判别器组成。生成器负责实现图像在两个领域间的转换,判别器则判断其所属领域图像的真实性。这些网络能够将照片转换成模仿著名画家风格或特定艺术流派的艺术作品,从而架起艺术与技术之间的桥梁。它们也可应用于医疗领域,例如将 MRI 图像转换成 CT 扫描图像,或反之。这对于某种成像技术不可用或成本过高的场景非常有帮助。

本章的项目中,我们将实现黑发与金发图像之间的相互转换,因此以此作为 CycleGAN 工作原理的示例。图 6.1 展示了 CycleGAN 的架构示意图。

图 6.1 CycleGAN 架构示意图,用于将黑发图像转换为金发图像,以及将金发图像转换为黑发图像。图中还概述了最小化对抗损失的训练步骤。模型如何最小化循环一致性损失将在图 6.2 中进行说明。

训练 CycleGAN 时,我们使用来自两个不同领域的无配对数据集。这里使用了 48,472 张黑发名人脸部图像和 29,980 张金发图像。我们调整模型参数,以最小化对抗损失和循环一致性损失的总和。为了便于讲解,图 6.1 仅解释了对抗损失,循环一致性损失将在下一小节中详细说明。

在每次训练迭代中,我们将真实的黑发图像(图 6.1 左上角)输入金发生成器,生成假的金发图像。然后将生成的假金发图像与真实金发图像一起输入金发判别器(图中上方中间位置),金发判别器会输出每张图像是真实金发图像的概率。我们将判别器的预测与真实标签(图像是否是真实金发)进行比较,并计算对判别器的损失(Loss_D_Blond)和对生成器的损失(Loss_G_Blond)。

与此同时,在每次训练迭代中,我们将真实的金发图像(图中左中位置)输入黑发生成器(图中左下位置),生成假的黑发图像。我们将生成的假黑发图像与真实黑发图像一起输入黑发判别器(图中底部中间位置),获取其真实概率的预测。我们将黑发判别器的预测与真实标签比较,计算判别器损失(Loss_D_Black)和生成器损失(Loss_G_Black)。生成器和判别器同时训练。训练两个判别器时,我们调整模型参数,最小化判别器损失,该损失是 Loss_D_Black 和 Loss_D_Blond 的总和。

6.1.2 循环一致性损失

训练两个生成器时,我们调整模型参数,以最小化对抗损失和循环一致性损失的总和。对抗损失是上一小节中讨论的 Loss_G_Black 和 Loss_G_Blond 之和。关于循环一致性损失的解释,请参见图 6.2。

图 6.2 CycleGAN 如何最小化原始黑发图像与经过循环转换后的假黑发图像之间的循环一致性损失,以及原始金发图像与经过循环转换后的假金发图像之间的循环一致性损失。

CycleGAN 中生成器的损失函数由两部分组成。第一部分是对抗损失,它确保生成的图像在目标域中与真实图像难以区分。例如,上节定义的 Loss_G_Blond 确保金发生成器产生的假金发图像与训练集中真实的金发图像相似。第二部分是循环一致性损失,它确保从一个域转换到另一个域的图像可以被转换回原始域。

循环一致性损失是 CycleGAN 的关键组成部分,保证经过一次循环转换后可以恢复原始输入图像。具体来说,如果你将真实的黑发图像(图 6.2 左上)转换成假金发图像,再将其转换回假黑发图像(右上),那么最终得到的图像应该与原始黑发图像非常接近。黑发图像的循环一致性损失是在像素级别上,假图像与真实原图之间的平均绝对误差,我们称之为 Loss_Cycle_Black。金发图像在从金发转换为黑发再回到金发的循环中也同理,我们称之为 Loss_Cycle_Blond。总的循环一致性损失是 Loss_Cycle_Black 和 Loss_Cycle_Blond 之和。

6.2 名人脸部数据集

我们将使用带有黑发和金发的名人脸部图像作为两个域。你将在本节中先下载数据,然后处理这些图像,为本章后续的训练做准备。

在本章中,你将使用两个新的Python库:pandas和albumentations。要安装这些库,请在你的Jupyter Notebook中新建一个代码单元,执行以下代码:

!pip install pandas albumentations

然后按照屏幕上的提示完成安装。

6.2.1 下载名人脸部数据集

要下载名人脸部数据集,请登录Kaggle并访问链接 mng.bz/Ompo 。下载后解压,将所有图像文件放入你电脑上的文件夹 /files/img_align_celeba/img_align_celeba/ 中(注意,该文件夹内还有一个同名子文件夹)。该文件夹内约有20万张图片。同时,从Kaggle下载 list_attr_celeba.csv 文件并放入你电脑的 /files/ 文件夹中。该CSV文件列出了每张图像的各种属性。

名人脸部数据集中包含了多种发色的图像:棕色、灰色、黑色、金色等。我们将选取黑发和金发的图像作为训练集,因为这两种发色在数据集中最为丰富。运行下面的代码来筛选所有黑发或金发的图像。

代码示例 6.1 选取黑发或金发图像

import pandas as pdimport os, shutildf = pd.read_csv("files/list_attr_celeba.csv")            # ①加载包含图像属性的CSV文件os.makedirs("files/black", exist_ok=True)                 # ②创建存放黑发图像的文件夹os.makedirs("files/blond", exist_ok=True)                 # ②创建存放金发图像的文件夹folder = "files/img_align_celeba/img_align_celeba"for i in range(len(df)):    dfi = df.iloc[i]    if dfi['Black_Hair'] == 1:                            # ③若属性Black_Hair为1,则移动图像到黑发文件夹        try:            oldpath = f"{folder}/{dfi['image_id']}"            newpath = f"files/black/{dfi['image_id']}"            shutil.move(oldpath, newpath)        except:            pass    elif dfi['Blond_Hair'] == 1:                          # ④若属性Blond_Hair为1,则移动图像到金发文件夹        try:            oldpath = f"{folder}/{dfi['image_id']}"            newpath = f"files/blond/{dfi['image_id']}"            shutil.move(oldpath, newpath)        except:            pass

我们先使用pandas库加载 list_attr_celeba.csv 文件,以判断每张图像是否为黑发或金发。然后在本地分别创建 /files/black//files/blond/ 两个文件夹,分别存放黑发和金发的图像。接着,代码会遍历数据集中的所有图像,如果某张图像的属性 Black_Hair 为1,则将其移动到 /files/black/ 文件夹;如果属性 Blond_Hair 为1,则移动到 /files/blond/ 文件夹。你将会看到黑发图像有48,472张,金发图像有29,980张。图 6.3 展示了部分图像示例。

图 6.3 顶部一排的图像为黑发,底部一排的图像为金发。此外,图像质量很高:所有人脸均正对镜头且位于中心位置,发色清晰可辨。训练数据的数量和质量将有助于CycleGAN模型的训练。

6.2.2 处理黑发和金发图像数据

我们将使CycleGAN模型具有通用性,以便它可以在任意包含两个图像域的数据集上进行训练。同时,我们会定义一个LoadData()类来处理CycleGAN模型的训练数据集。该函数可应用于任意包含两个域的数据集,无论是不同发色的人脸图像、戴眼镜与否的图像,还是夏季与冬季场景的图像。

为此,我们创建了一个本地模块ch06util。请从本书GitHub仓库(github.com/markhliu/DG…)下载文件ch06util.py和__init__.py,并将它们放置于你电脑上的/utils/ 文件夹中。在该本地模块中,我们定义了如下LoadData()类。

代码清单 6.2 CycleGAN训练数据处理的LoadData()类

class LoadData(Dataset):    def __init__(self, root_A, root_B, transform=None):    # ①        super().__init__()        self.root_A = root_A        self.root_B = root_B        self.transform = transform        self.A_images = []        for r in root_A:            files = os.listdir(r)            self.A_images += [r + i for i in files]        self.B_images = []        for r in root_B:                                   # ②            files = os.listdir(r)            self.B_images += [r + i for i in files]        self.len_data = max(len(self.A_images),                            len(self.B_images))        self.A_len = len(self.A_images)        self.B_len = len(self.B_images)    def __len__(self):                                     # ③        return self.len_data    def __getitem__(self, index):                          # ④        A_img = self.A_images[index % self.A_len]        B_img = self.B_images[index % self.B_len]        A_img = np.array(Image.open(A_img).convert("RGB"))        B_img = np.array(Image.open(B_img).convert("RGB"))        if self.transform:            augmentations = self.transform(image=B_img,                                           image0=A_img)            B_img = augmentations["image"]            A_img = augmentations["image0"]        return A_img, B_img

① root_A和root_B分别为存放两个图像域图像的文件夹路径
② 加载每个域中的所有图像
③ 定义获取数据集长度的方法
④ 定义访问每个域单个元素的方法

LoadData()类继承自PyTorch的Dataset类。root_A和root_B列表分别包含域A和域B中的图像文件夹路径。该类加载两个域中的图像,并输出一对图像(一张来自域A,一张来自域B),以便后续用这对图像训练CycleGAN模型。

与前几章类似,我们创建批量数据迭代器,以提高训练过程中的计算效率、内存利用和优化效果。

代码清单 6.3 处理黑发和金发图像以便训练

transforms = albumentations.Compose(    [albumentations.Resize(width=256, height=256),        # ①     albumentations.HorizontalFlip(p=0.5),     albumentations.Normalize(mean=[0.5, 0.5, 0.5],                              std=[0.5, 0.5, 0.5], max_pixel_value=255),  # ②     ToTensorV2()],    additional_targets={"image0": "image"})dataset = LoadData(root_A=["files/black/"],                   root_B=["files/blond/"],                   transform=transforms)                                 # ③loader = DataLoader(dataset, batch_size=1,                    shuffle=True, pin_memory=True)                        # ④

① 将图像调整为256×256像素
② 将图像像素值归一化到[-1, 1]范围
③ 对图像应用LoadData()类进行处理
④ 创建训练用的数据迭代器

我们首先创建了albumentations库中Compose()类的一个实例,命名为transforms。该类对图像进行一系列变换:调整大小到256×256像素,归一化像素值至[-1, 1]范围。代码中的HorizontalFlip()参数以50%的概率对训练图像进行水平翻转。水平翻转是一种简单但有效的增强手段,可增加训练数据多样性,提升模型泛化能力和鲁棒性。图像增强及尺寸调整均有助于提高CycleGAN模型的性能,使生成的图像更具真实感。

接着,我们使用LoadData()类加载黑发和金发图像。由于图像文件较大,我们将批量大小设置为1,并在每次迭代中用一对图像训练模型。若设置批量大小大于1,可能会导致计算机内存不足。

6.3 构建 CycleGAN 模型

本节中,我们将从头开始构建一个 CycleGAN 模型。我们会特别注意使该模型具备通用性,以便它能在任意包含两个图像域的数据集上训练。因此,我们用 A 和 B 来表示两个域,而不是例如用黑发和金发图像作为示例。作为练习,你将用第 5 章中使用的眼镜数据集训练同样的 CycleGAN 模型。这将帮助你将本章所学技能应用到其他实际场景中,使用不同的数据集。

6.3.1 创建两个判别器

虽然 CycleGAN 有两个判别器,但它们在设计上是完全相同的。因此,我们只需创建一个 Discriminator() 类,然后实例化两次:一个实例为判别器 A,另一个为判别器 B。CycleGAN 中的两个域是对称的,哪个叫域 A 并不重要——可以是黑发图像,也可以是金发图像。

打开你刚下载的 ch06util.py 文件,里面定义了 Discriminator() 类。

代码清单 6.4 CycleGAN 中 Discriminator() 类的定义

class Discriminator(nn.Module):    def __init__(self, in_channels=3, features=[64,128,256,512]):        super().__init__()        self.initial = nn.Sequential(            nn.Conv2d(in_channels, features[0],                  # ①                      kernel_size=4, stride=2, padding=1,                      padding_mode="reflect"),            nn.LeakyReLU(0.2, inplace=True))        layers = []        in_channels = features[0]        for feature in features[1:]:                            # ②            layers.append(Block(in_channels, feature,                                 stride=1 if feature == features[-1] else 2))            in_channels = feature        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4,  # ③                                stride=1, padding=1, padding_mode="reflect"))        self.model = nn.Sequential(*layers)    def forward(self, x):        out = self.model(self.initial(x))        return torch.sigmoid(out)                               # ④

① 第一层 Conv2d 有 3 个输入通道和 64 个输出通道
② 接下来三层 Conv2d 的输出通道分别为 128、256 和 512
③ 最后一层 Conv2d 有 512 个输入通道和 1 个输出通道
④ 对输出应用 sigmoid 激活函数,使输出可被解释为概率值

上述代码定义了判别器网络。其架构类似于第 4 章中的判别器网络和第 5 章中的判别器(critic)网络。主要由五个 Conv2d 层组成。最后一层使用 sigmoid 激活函数,是因为判别器执行的是二分类任务。判别器以三通道彩色图像作为输入,输出 0 到 1 之间的单一数值,代表该图像为真实图像的概率。

代码中 padding_mode="reflect" 参数表示对输入张量进行反射填充。反射填充有助于保留边缘信息,避免在边界引入人为的零值,令输入张量边界处过渡更加平滑,这有利于模型区分不同域的图像。

接着,我们创建该类的两个实例,分别命名为 disc_A 和 disc_B:

from utils.ch06util import Discriminator, weights_init    # ①import torchdevice = "cuda" if torch.cuda.is_available() else "cpu"disc_A = Discriminator().to(device)disc_B = Discriminator().to(device)                       # ②weights_init(disc_A)weights_init(disc_B)                                      # ③

① 从本地模块导入 Discriminator 类和 weights_init 函数
② 创建两个 Discriminator 类的实例
③ 初始化权重

在本地模块 ch06util 中,还定义了一个 weights_init() 函数用于初始化模型权重,定义方式与第 5 章类似。我们对新建的两个判别器 disc_A 和 disc_B 执行权重初始化。

现在,我们已经拥有两个判别器,接下来将创建两个生成器。

6.3.2 创建两个生成器

同样地,我们在本地模块中定义一个 Generator() 类,并实例化两次:一个实例为生成器 A,另一个为生成器 B。在你刚下载的 ch06util.py 文件中,已经定义了 Generator() 类。

代码清单 6.5 CycleGAN 中的 Generator() 类

class Generator(nn.Module):    def __init__(self, img_channels, num_features=64,                 num_residuals=9):        super().__init__()             self.initial = nn.Sequential(            nn.Conv2d(img_channels, num_features, kernel_size=7,                      stride=1, padding=3, padding_mode="reflect"),            nn.InstanceNorm2d(num_features),            nn.ReLU(inplace=True))        self.down_blocks = nn.ModuleList(            [ConvBlock(num_features, num_features*2, kernel_size=3,                       stride=2, padding=1),             ConvBlock(num_features*2, num_features*4, kernel_size=3, ①                       stride=2, padding=1)])        self.res_blocks = nn.Sequential(                           ②            *[ResidualBlock(num_features * 4)               for _ in range(num_residuals)])        self.up_blocks = nn.ModuleList(            [ConvBlock(num_features * 4, num_features * 2,                       down=False, kernel_size=3, stride=2,                       padding=1, output_padding=1),             ConvBlock(num_features * 2, num_features * 1,      ③                       down=False, kernel_size=3, stride=2,                       padding=1, output_padding=1)])        self.last = nn.Conv2d(num_features * 1, img_channels,                              kernel_size=7, stride=1,                              padding=3, padding_mode="reflect")            def forward(self, x):        x = self.initial(x)        for layer in self.down_blocks:            x = layer(x)        x = self.res_blocks(x)        for layer in self.up_blocks:            x = layer(x)        return torch.tanh(self.last(x))                            ④

① 三个 Conv2d 层
② 九个残差块
③ 两个上采样块
④ 输出层使用 tanh 激活函数

生成器网络包含多个 Conv2d 层,后接九个残差块(稍后详细讲解)。然后,网络有两个上采样块,每个包含一个 ConvTranspose2d 层、一个 InstanceNorm2d 层和一个 ReLU 激活层。和之前章节一样,输出层使用 tanh 激活函数,使输出像素值处于 -1 到 1 范围内,和训练集图像一致。

生成器中的残差块在本地模块中定义如下:

class ConvBlock(nn.Module):    def __init__(self, in_channels, out_channels,                  down=True, use_act=True, **kwargs):        super().__init__()        self.conv = nn.Sequential(            nn.Conv2d(in_channels, out_channels,                       padding_mode="reflect", **kwargs)            if down            else nn.ConvTranspose2d(in_channels,                                    out_channels, **kwargs),            nn.InstanceNorm2d(out_channels),            nn.ReLU(inplace=True) if use_act else nn.Identity())    def forward(self, x):        return self.conv(x)  class ResidualBlock(nn.Module):    def __init__(self, channels):        super().__init__()        self.block = nn.Sequential(            ConvBlock(channels, channels, kernel_size=3, padding=1),            ConvBlock(channels, channels,                      use_act=False, kernel_size=3, padding=1))    def forward(self, x):        return x + self.block(x)

残差连接是深度学习中特别是在设计深层神经网络时常用的概念,你将在本书后续章节中频繁看到。它用于解决梯度消失问题,该问题在非常深的网络中较为常见。在残差块中(残差连接网络的基本单元),输入先经过一系列变换(如卷积、激活和批归一化或实例归一化),然后将该变换的输出与输入相加。图 6.4 展示了上述残差块架构的示意图。

图6.4 残差块的架构示意图。输入 x 会经过一系列变换(两组 Conv2d 层和 InstanceNorm2d 层,中间夹着一个 ReLU 激活层),然后输入 x 会与这些变换的输出 f(x) 相加,因此残差块的输出是 x + f(x)。

每个残差块中的变换不同。以本例为例,输入 x 先经过两组 Conv2d 层和 InstanceNorm2d 层,以及一个 ReLU 激活层,随后将输入 x 与变换输出 f(x) 相加,形成最终输出 x+f(x),这也就是“残差连接”的由来。

接下来,我们创建 Generator() 类的两个实例,分别命名为 gen_A 和 gen_B:

from utils.ch06util import Generator  gen_A = Generator(img_channels=3, num_residuals=9).to(device)gen_B = Generator(img_channels=3, num_residuals=9).to(device)weights_init(gen_A)weights_init(gen_B)

在训练模型时,我们将使用平均绝对误差(L1 损失)来衡量循环一致性损失,使用均方误差(L2 损失)来衡量对抗损失。L1 损失通常用于数据有噪声且存在许多异常值的情况,因为它对极端值的惩罚比 L2 损失小。
因此,我们导入以下损失函数:

import torch.nn as nn  l1 = nn.L1Loss()mse = nn.MSELoss()g_scaler = torch.cuda.amp.GradScaler()d_scaler = torch.cuda.amp.GradScaler()

L1 和 L2 损失都是在像素级别计算的。原始图像和生成的假图像都是形状为 (3, 256, 256)。计算损失时,我们先计算两张图像对应位置像素值的差异(L1 是绝对值,L2 是平方值),然后对 3×256×256=196608 个位置的结果求平均。

我们将使用 PyTorch 的自动混合精度工具 torch.cuda.amp 来加速训练。PyTorch 张量默认的数据类型是 float32(32 位浮点数),其内存占用是 16 位浮点数(float16)的两倍,且操作速度较慢。float16 计算速度更快,但精度较低。混合精度训练通过根据操作需要动态选择 float32 或 float16,实现精度和速度的平衡。

如同第 4 章所做的,我们对判别器和生成器均使用 Adam 优化器:

lr = 0.00001opt_disc = torch.optim.Adam(list(disc_A.parameters()) +   list(disc_B.parameters()), lr=lr, betas=(0.5, 0.999))opt_gen = torch.optim.Adam(list(gen_A.parameters()) +   list(gen_B.parameters()), lr=lr, betas=(0.5, 0.999))

接下来,我们将使用黑发和金发图像来训练 CycleGAN 模型。

6.4 使用 CycleGAN 在黑发和金发之间进行转换

现在我们已经准备好了训练数据和 CycleGAN 模型,将使用黑发和金发的图像来训练模型。与所有 GAN 模型一样,训练完成后我们会丢弃判别器,使用两个训练好的生成器分别将黑发图像转换为金发图像,以及将金发图像转换为黑发图像。

6.4.1 训练 CycleGAN 进行黑发与金发转换

如第4章所述,我们通过视觉检查来决定训练何时停止。为此,我们定义了一个函数用来测试真实图像和对应生成图像的样子,便于我们通过视觉比较模型效果。在本地模块 ch06util 中定义了 test() 函数:

def test(i, A, B, fake_A, fake_B):    save_image(A*0.5+0.5, f"files/A{i}.png")    save_image(B*0.5+0.5, f"files/B{i}.png")               ①    save_image(fake_A*0.5+0.5, f"files/fakeA{i}.png")    save_image(fake_B*0.5+0.5, f"files/fakeB{i}.png")      ②

① 将域 A 和域 B 中的真实图像保存到本地文件夹
② 将生成器在批次 i 中生成的对应假图像保存到本地

每训练 100 个批次,我们就保存四张图像:两个域的真实图像和对应的生成假图。这样可以定期检查生成图像,并与真实图像比较,评估训练进度。该函数设计通用,可适用于任意两个域的图像。

此外,在 ch06util 本地模块中定义了 train_epoch() 函数,用于训练判别器和生成器一个 epoch。下面代码段突出显示了训练两个判别器的部分:

def train_epoch(disc_A, disc_B, gen_A, gen_B, loader, opt_disc,        opt_gen, l1, mse, d_scaler, g_scaler, device):    loop = tqdm(loader, leave=True)    for i, (A, B) in enumerate(loop):                       ①        A = A.to(device)        B = B.to(device)        with torch.cuda.amp.autocast():                    ②            fake_A = gen_A(B)            D_A_real = disc_A(A)            D_A_fake = disc_A(fake_A.detach())            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))            D_A_loss = D_A_real_loss + D_A_fake_loss            fake_B = gen_B(A)            D_B_real = disc_B(B)            D_B_fake = disc_B(fake_B.detach())            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))            D_B_loss = D_B_real_loss + D_B_fake_loss            D_loss = (D_A_loss + D_B_loss) / 2             ③        opt_disc.zero_grad()        d_scaler.scale(D_loss).backward()        d_scaler.step(opt_disc)        d_scaler.update()        …

① 遍历两个域中所有的图像对
② 使用 PyTorch 自动混合精度加速训练
③ 两个判别器的总损失是各自对抗损失的简单平均

这里使用 detach() 方法切断 fake_A 和 fake_B 的梯度,减少内存消耗并加快计算速度。训练两个判别器与第4章类似,但有两点不同:

    这里有两个判别器,分别对应域 A 和域 B 的图像;使用自动混合精度训练,训练时间缩短超过 50%。

两个生成器在同一迭代中同时训练,代码片段如下:

def train_epoch(disc_A, disc_B, gen_A, gen_B, loader, opt_disc,        opt_gen, l1, mse, d_scaler, g_scaler, device):        …        with torch.cuda.amp.autocast():            D_A_fake = disc_A(fake_A)            D_B_fake = disc_B(fake_B)            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))      ①            cycle_B = gen_B(fake_A)            cycle_A = gen_A(fake_B)            cycle_B_loss = l1(B, cycle_B)            cycle_A_loss = l1(A, cycle_A)                            ②            G_loss = loss_G_A + loss_G_B + cycle_A_loss * 10 + cycle_B_loss * 10 ③        opt_gen.zero_grad()        g_scaler.scale(G_loss).backward()        g_scaler.step(opt_gen)        g_scaler.update()        if i % 100 == 0:            test(i, A, B, fake_A, fake_B)                                ④        loop.set_postfix(D_loss=D_loss.item(), G_loss=G_loss.item())

① 两个生成器的对抗损失
② 两个生成器的循环一致性损失
③ 两个生成器的总损失是对抗损失和循环一致性损失的加权和,循环一致性损失权重为10倍
④ 每训练 100 个批次生成一次图像供视觉检查

训练两个生成器与第4章相比有两点不同:

    这里同时训练两个生成器;总损失是对抗损失和循环一致性损失的加权和,循环一致性损失权重大约是对抗损失的10倍,但改变权重(如9或12)结果相似。

循环一致性损失是原始图像与经过两次转换后重构图像的平均绝对误差。

一切准备就绪,开始训练循环:

from utils.ch06util import train_epoch  for epoch in range(1):    train_epoch(disc_A, disc_B, gen_A, gen_B, loader, opt_disc,    opt_gen, l1, mse, d_scaler, g_scaler, device)                   ①torch.save(gen_A.state_dict(), "files/gen_black.pth")torch.save(gen_B.state_dict(), "files/gen_blond.pth")               ②

① 使用黑发和金发图像训练 CycleGAN 一个 epoch
② 保存训练好的模型权重

如果使用 GPU 训练,训练过程需数小时,CPU 训练则可能需一天。如果没有训练资源,可以从我的网站下载预训练生成器:gattonweb.uky.edu/faculty/liu…。解压后将gen_black.pth 和 gen_blond.pth 文件放入本地 /files/ 文件夹。这样你就能在下一节实现黑发与金发图像的互相转换。

练习 6.1
在训练 CycleGAN 时,我们默认域 A 是黑发图像,域 B 是金发图像。请修改清单6.2中的代码,使域 A 包含金发图像,域 B 包含黑发图像。

6.4.2 黑发图像和金发图像的往返转换

由于训练数据集的高质量和丰富数量,我们成功地训练了 CycleGAN。我们不仅实现了黑发图像与金发图像之间的转换,还进行了往返转换。比如,我们先将黑发图像转换成金发图像,再将其转换回黑发图像。这样,我们可以将同一域中往返转换后的生成图像与原始图像进行对比,观察差异。

以下代码实现了两个域之间的图像转换,以及每个域中图像的往返转换。
清单 6.8 黑发与金发图像的往返转换

gen_A.load_state_dict(torch.load("files/gen_black.pth",    map_location=device))gen_B.load_state_dict(torch.load("files/gen_blond.pth",    map_location=device))i=1for black, blond in loader:    fake_blond = gen_B(black.to(device))    save_image(black*0.5+0.5, f"files/black{i}.png")             ①    save_image(fake_blond*0.5+0.5, f"files/fakeblond{i}.png")     fake2black = gen_A(fake_blond)    save_image(fake2black*0.5+0.5,        f"files/fake2black{i}.png")                             ②    fake_black = gen_A(blond.to(device))    save_image(blond*0.5+0.5, f"files/blond{i}.png")             ③    save_image(fake_black*0.5+0.5, f"files/fakeblack{i}.png")    fake2blond = gen_B(fake_black)    save_image(fake2blond*0.5+0.5,        f"files/fake2blond{i}.png")                             ④    i = i + 1    if i > 10:        break

① 原始黑发图像
② 往返转换后的假黑发图像
③ 原始金发图像
④ 往返转换后的假金发图像

我们在本地文件夹 /files/ 中保存了六组图像。第一组是原始的黑发图像。第二组是由训练好的金发生成器产生的假金发图像,文件名为 fakeblond0.png、fakeblond1.png 等。第三组是往返转换后生成的假黑发图像:我们将刚生成的假金发图像输入训练好的黑发生成器,得到假黑发图像,保存为 fake2black0.png、fake2black1.png 等。图 6.5 显示了这三组图像。

图 6.5 黑发图像的往返转换。图中共有三排图像:
顶排是训练集中原始的黑发图像;
中排是经过训练好的金发生成器生成的对应假金发图像;
底排是经过往返转换后的假黑发图像:我们将中排的假金发图像输入训练好的黑发生成器,生成假黑发图像。

图 6.5 中的三排图像展示了训练效果。顶排展示了训练集中的原始黑发图像,中排是由训练好的金发生成器生成的假金发图像,底排是经过往返转换后生成的假黑发图像——它们几乎与顶排的原始黑发图像一模一样!这说明我们的 CycleGAN 模型效果非常好。

本地文件夹 /files/ 中的第四组图像是原始的金发图像,第五组是由训练好的黑发生成器生成的假黑发图像,最后一组是往返转换后的假金发图像。图 6.6 展示了这三组图像的对比效果。

图 6.6 金发图像的往返转换。图中共有三排图像:
顶排是训练集中原始的金发图像;
中排是由训练好的黑发生成器生成的对应假黑发图像,这些假黑发图像对应顶排中的同一张人脸;
底排是经过往返转换后的假金发图像:我们将中排的假黑发图像输入训练好的金发生成器,生成假金发图像。

图 6.6 中,中排展示了训练好的黑发生成器生成的假黑发图像,它们与顶排中对应的人脸相同。底排是经过往返转换后生成的假金发图像,与顶排的原始金发图像几乎一模一样。

练习 6.2
CycleGAN 模型是通用的,可以应用于任何包含两个图像域的训练数据集。请使用你在第 5 章下载的眼镜图像训练 CycleGAN 模型,将带眼镜的图像作为域 A,无眼镜的图像作为域 B。然后使用训练好的 CycleGAN 实现眼镜的添加和去除(即实现两个域之间的图像转换)。示例实现和结果见书籍的 GitHub 仓库。

至此,我们主要聚焦于一种生成模型——GAN。在下一章,你将学习另一种生成模型——变分自编码器(VAE),用以生成高分辨率图像。你将了解 VAE 相较于 GAN 的优缺点,更重要的是,学习 VAE 中的编码器-解码器架构。这种架构在包括后续章节学习的 Transformer 等生成模型中被广泛使用。

总结

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

CycleGAN 图像转换 循环一致性损失 生成对抗网络
相关文章