掘金 人工智能 2024年07月08日
Swin Transformer:深度解析其架构与代码实现
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

Swin Transformer是强大的视觉Transformer模型,探讨其架构并尝试复现网络结构,在多任务中达先进性能水平。

💥Swin Transformer的诞生背景与改进:它由微软亚洲研究院推出,为克服Transformer在CV领域的难题,进行了两项创新改进,构建层次化Transformer并引入locality概念,能灵活应用于多种任务。

🎯Swin Transformer的主要特点:采用层次化结构,逐步降低特征空间维度并增加深度;具有移位窗口自注意力,减少计算量并捕获丰富上下文信息;能够学习多尺度特征表示。

🛠Swin Transformer的具体实现:以图片为起点,经过Patch Partition层、Linear Embedding层、Swin Transformer Block、Patch Merging层、AdaptiveAvgPool1d层和全连接层等处理,最终完成图像分类。

📄Swin Transformer的代码分析:包括ShiftWindowAttentionBlock类实现带有窗口移位的自注意力机制,SwinTransformer类构建完整模型,以及模型组件和前向传播的详细介绍。


theme: smartblue

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

Swin Transformer是一种强大的视觉Transformer模型,它通过引入层次化结构和基于窗口偏移的自注意力机制,有效提升了特征提取的能力。在多个计算机视觉任务中,Swin Transformer已经达到了最先进的性能水平。本文将深入探讨Swin Transformer的架构,并尝试将其网络结构进行复现。

一、Swin Transformer 概述

Swin Transformer通过扩展原始Transformer模型的能力,引入了层次化结构和基于窗口偏移的自注意力机制,使其能够有效处理图像数据,并可应用于图像分类、目标检测和分割等任务。

1.背景介绍

Swin Transformer,由微软亚洲研究院孕育的新星,今年在学术界大放异彩,以其独特的魅力在图像分类、图像分割和目标检测等众多领域中斩获了无数荣誉。

然而,Swin Transformer的诞生并非一帆风顺。在它之前,Transformer在自然语言处理(NLP)领域已经取得了辉煌的成就,但在计算机视觉(CV)的舞台上却未能同样耀眼。Swin Transformer的创造者们深入分析了这一现象,发现主要有两个难题:首先,NLP中的token大小是固定的,而CV中的特征尺度变化莫测,如同变幻莫测的风;其次,CV对于分辨率的要求更高,而使用Transformer的计算复杂度与图像尺寸的平方成正比,这无疑给计算带来了巨大的压力。

为了克服这些挑战,Swin Transformer进行了两项创新性的改进:首先,它借鉴了CNN中常用的层次化构建方式,构建了层次化的Transformer;其次,它引入了locality的概念,对没有重叠的window区域进行self-attention计算,能够精准地聚焦于每一个角落。

Swin Transformer不仅仅是一个技术革新,它更是一个多才多艺的艺术家,能够灵活地应用于图像分类、目标检测和语义分割等任务,成为这些任务的通用骨干网络。有人说,Swin Transformer可能是CNN的完美替代者,但我认为,它更像是一位能够与CNN并肩作战的伙伴,共同推动计算机视觉技术的发展。

2.主要特点

3.对比

下图为Swin Transformer与ViT在处理图片方式上的对比,可以看出,Swin Transformer有着ResNet一样的残差结构和CNN具有的多尺度图片结构。

二、具体实现

首先Swin-Transformer 以一张图片作为起点,这是它的画布,准备在上面绘制出精彩的图案。

1.Patch Partition 层

在 Patch Partition 层,这张图片被巧妙地拆分成众多小块,就像是将一幅大画卷分解为易于管理的小片段。Patch Partition是模型对输入图像进行预处理的一种重要操作。该操作的主要目的是将原始的连续像素图像分割成一系列固定大小的图像块(patches),以便进一步转化为Transformer可以处理的序列数据。

2.Swin Transfomer

随后,Linear Embedding 层赋予了这些小块以特征的维度,让它们不再是静止的图像,而是活跃的数据点,为之后的表演做好准备。这些特征化的小块进入 Swin Transformer Block,这是第一阶段,它们在这里学会了如何与周围的伙伴协作,共同构建起初步的图像理解。

3.Patch Merging层

接下来的第二至第四阶段,每个阶段开始前,小块们会经历 Patch Merging 的过程,这就像是将多个小故事合并为一个更加宏大的叙事,每一次合并都让图像的表示更加深入和丰富。Patch Merging层主要是进行下采样,产生分层表示。 Patch Merging 是一种减少序列长度并增加每个补丁表示中通道数的操作。

4.AdaptiveAvgPool1d 层和全连接层

在第四阶段的末尾,所有的数据汇集到输出模块,这里有一个 LayerNorm 层,它确保了数据的平衡和稳定,就像是在演出中保持舞者的稳定和优雅。最后,AdaptiveAvgPool1d 层和全连接层相继登场,它们共同作用于数据,最终完成图像的分类,为这场演出画上完美的句点。

三、代码分析

1.ShiftWindowAttentionBlock 类

class ShiftWindowAttentionBlock(nn.Module):    def init(self, ...):        ...    def forward(self, x):    # patch_num补成能够被window_size整除    if x.size(-2) % self.window_size:        x = nn.ZeroPad2d((0, 0, 0, self.window_size - x.size(-2) % self.window_size))(x)    batch, modal_leng, patch_num, input_dim = x.size()    short_cut = x # resdual    # 窗口偏移    if self.shift_size:        x = torch.roll(x, shifts=-self.shift_size, dims=2) # 只在 patch_num 上 roll   [batch, modal_leng, patch_num, input_dim]    # 窗口化     window_num = patch_num // self.window_size    window_x = x.reshape(batch, modal_leng, window_num, self.window_size, input_dim) # [batch, modal_leng, window_num, window_size, input_dim]    # 基于窗口的多头自注意力    q = self.query(window_x).reshape(batch, modal_leng, window_num, self.window_size, self.head_num, self.att_size).permute(0, 1, 2, 4, 3, 5)     ....

ShiftWindowAttentionBlock 类实现了带有窗口移位的自注意力机制。它接收一个输入张量 x,对其进行自注意力操作,并根据是否启用移位来调整窗口的覆盖范围。

2.SwinTransformer 类

train_shape: 总体训练样本的shape

category: 类别数

embedding_dim: embedding 维度

patch_size: 一个patch长度

head_num: 多头自注意力

att_size: QKV矩阵维度

window_size: 一个窗口包含多少patchs

对于传感窗口数据来讲,在每个单独的模态轴上对时序轴进行patch切分,例如 uci-har 数据集窗口尺寸为 [128, 9],一个patch包含4个数据,那么每个模态轴上的patch_num为32, 总patch数为 32 9:

class SwinTransformer(nn.Module):    def init(self, train_shape, category, embedding_dim=256, patch_size=4, head_num=4, att_size=64, window_size=8):        super().init()        self.series_leng = train_shape[-2]        self.modal_leng = train_shape[-1]        self.patch_num = self.series_leng // patch_size                self.patch_conv = nn.Conv2d(            in_channels=1,            out_channels=embedding_dim,            kernel_size=(patch_size, 1),            stride=(patch_size, 1),            padding=0        )        # 位置信息        self.position_embedding = nn.Parameter(torch.zeros(1, self.modal_leng, self.patch_num, embedding_dim))        # patch_num维度降采样一次后的计算方式        swin_transformer_block1_input_patch_num = math.ceil(self.patch_num / window_size)  window_size        swin_transformer_block2_input_patch_num = math.ceil(math.ceil(swin_transformer_block1_input_patch_num / 2) / window_size)  window_size        swin_transformer_block3_input_patch_num = math.ceil(math.ceil(swin_transformer_block2_input_patch_num / 2) / window_size)  window_size        # Shift_Window_Attention_Layer        # 共3个swin_transformer_block,每个swin_transformer_block对时序维降采样1/2,共降采样1/8        self.swa = nn.Sequential(            # swin_transformer_block 1            nn.Sequential(                 ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),                ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)            ),            # swin_transformer_block 2            nn.Sequential(                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)            ),            # swin_transformer_block 3            nn.Sequential(                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)            )        )        # classification tower        self.dense_tower = nn.Sequential(            nn.Linear(self.modal_leng  math.ceil(swin_transformer_block3_input_patch_num / 2)  embedding_dim, 1024),            nn.LayerNorm(1024),            nn.ReLU(),            nn.Linear(1024, category)        )

SwinTransformer 类构建了 Swin Transformer 的完整模型。它接收输入数据的形状 train_shape 和类别数 category,以及其他配置参数。

3.模型组件

4.前向传播

def forward(self, x):    x = self.patch_conv(x) # [batch, embedding_dim, patch_num, modal_leng]    x = self.position_embedding + x.permute(0, 3, 2, 1) # [batch, modal_leng, patch_num, embedding_dim]    x = self.swa(x)    x = nn.Flatten()(x)    x = self.dense_tower(x)    return x

forward 方法定义了模型的前向传播过程:

    块卷积:输入数据通过卷积操作转换成嵌入维度。位置嵌入:将位置信息添加到块特征中。Swin Transformer 块:通过多个 Swin Transformer 块进行特征提取。分类塔:在模型顶部,将特征展平并通过线性层进行分类。

Swin Transformer 是一种创新的模型,它将 Transformer 架构的优势引入到计算机视觉领域。通过层次化处理和高效的自注意力机制,Swin Transformer 在多个视觉任务上展现出卓越的性能。提供的代码实现了 Swin Transformer 的核心功能,为进一步的研究和应用提供了基础。

四、致谢

本文和代码实现基于 Swin Transformer 的原始论文和相关研究工作。感谢所有为深度学习和计算机视觉领域做出贡献的研究人员和开发者。

注意:具体的代码实现和模型细节可以联系作者获取,以便进一步的研究和应用。本文首发于稀土掘金,未经允许禁止转发和二次创作,侵权必究。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Swin Transformer 视觉模型 代码分析 层次化结构
相关文章