PVT:可用于密集任务backbone的金字塔视觉transformer!

基于detectron2实现的PVT开源了,欢迎star:https://github.com/xiaohu2015/pvt_detectron2

自从ViT之后,关于vision transformer的研究呈井喷式爆发,从思路上分主要沿着两大个方向,一是提升ViT在图像分类的效果;二就是将ViT应用在其它图像任务中,比如分割和检测任务上,这里介绍的PVT(Pyramid Vision Transformer) 就属于后者。PVT相比ViT引入了和CNN类似的金字塔结构,使得PVT像CNN那样作为backbone应用在dense prediction任务(分割和检测等)。

CNN结构常用的是一种金字塔架构,如上图所示,CNN网络一般可以划分为不同的stage,在每个stage开始时,特征图的长和宽均减半,而特征维度(channel)扩宽2倍。这主要有两个方面的考虑,一是采用stride=2的卷积或者池化层对特征降维可以增大感受野,另外也可以减少计算量,但同时空间上的损失用channel维度的增加来弥补。但是ViT本身就是全局感受野,所以ViT就比较简单直接了,直接将输入图像tokens化后就不断堆积相同的transformer encoders,这应用在图像分类上是没有太大的问题。但是如果应用在密集任务上,会遇到问题:一是分割和检测往往需要较大的分辨率输入,当输入图像增大时,ViT的计算量会急剧上升;二是ViT直接采用较大patchs进行token化,如采用16x16大小那么得到的粗粒度特征,对密集任务来说损失较大。这正是PVT想要解决的问题,PVT采用和CNN类似的架构,将网络分成不同的stages,每个stage相比之前的stage特征图的维度是减半的,这意味着tokens数量减少4倍,具体结构如下:

每个stage的输入都是一个维度的3-D特征图,对于第1个stage,输入就是RGB图像,对于其它stage可以将tokens重新reshape成3-D特征图。在每个stage开始,首先像ViT一样对输入图像进行token化,即进行patch embedding,patch大小均采用2x2大小(第1个stage的patch大小是4x4),这意味着该stage最终得到的特征图维度是减半的,tokens数量对应减少4倍。PVT共4个stage,这和ResNet类似,4个stage得到的特征图相比原图大小分别是1/4,1/8,1/16和1/32。由于不同的stage的tokens数量不一样,所以每个stage采用不同的position embeddings,在patch embed之后加上各自的position embedding,当输入图像大小变化时,position embeddings也可以通过插值来自适应。

不同的stage的tokens数量不同,越靠前的stage的patchs数量越多,我们知道self-attention的计算量与sequence的长度的平方成正比,如果PVT和ViT一样,所有的transformer encoders均采用相同的参数,那么计算量肯定是无法承受的。PVT为了减少计算量,不同的stages采用的网络参数是不同的。PVT不同系列的网络参数设置如下所示,这里为patch的size,为特征维度大小,为MHA(multi-head attention)的heads数量,为FFN的扩展系数,transformer中默认为4。

可以见到随着stage,特征的维度是逐渐增加的,比如stage1的特征维度只有64,而stage4的特征维度为512,这种设置和常规的CNN网络设置是类似的,所以前面stage的patchs数量虽然大,但是特征维度小,所以计算量也不是太大。不同体量的PVT其差异主要体现在各个stage的transformer encoder的数量差异。

PVT为了进一步减少计算量,将常规的multi-head attention (MHA)用spatial-reduction attention (SRA)来替换。SRA的核心是减少attention层的key和value对的数量,常规的MHA在attention层计算时key和value对的数量为sequence的长度,但是SRA将其降低为原来的。SRA的具体结构如下所示:


在实现上,首先将维度为的patch embeddings通过reshape变换到维度为的3-D特征图,然后均分大小为的patchs,每个patchs通过线性变换将得到维度为的patch embeddings(这里实现上其实和patch emb操作类似,等价于一个卷积操作),最后应用一个layer norm层,这样就可以大大降低K和V的数量。具体实现代码如下:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        # 实现上这里等价于一个卷积层
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) # 这里x_.shape = (B, N/R^2, C)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

从PVT的网络设置上,前面的stage的取较大的值,比如stage1的,说明这里直接将Q和V的数量直接减为原来的1/64,这个就大大降低计算量了。

PVT具体到图像分类任务上,和ViT一样也通过引入一个class token来实现最后的分类,不过PVT是在最后的一个stage才引入:

    def forward_features(self, x):
        B = x.shape[0]

        # stage 1
        x, (H, W) = self.patch_embed1(x)
        x = x + self.pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # stage 2
        x, (H, W) = self.patch_embed2(x)
        x = x + self.pos_embed2
        x = self.pos_drop2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # stage 3
        x, (H, W) = self.patch_embed3(x)
        x = x + self.pos_embed3
        x = self.pos_drop3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # stage 4
        x, (H, W) = self.patch_embed4(x)
        cls_tokens = self.cls_token.expand(B, -1, -1) # 引入class token
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed4
        x = self.pos_drop4(x)
        for blk in self.block4:
            x = blk(x, H, W)

        x = self.norm(x)

        return x[:, 0]

具体到分类任务上,PVT在ImageNet上的Top-1 Acc其实是和ViT差不多的。其实PVT最重要的应用是作为dense任务如分割和检测的backbone,一方面PVT通过一些巧妙的设计使得对于分辨率较大的输入图像,其模型计算量不像ViT那么大,论文中比较了ViT-Small/16 ,ViT-Small,PVT-Small和ResNet50四种网络在不同的输入scale下的GFLOPs,可以看到PVT相比ViT要好不少,当输入scale=640时,PVT-Small和ResNet50的计算量是类似的,但是如果到更大的scale,PVT的增长速度就远超过ResNet50了。

PVT的另外一个相比ViT的优势就是其可以输出不同scale的特征图,这对于分割和检测都是非常重要的。因为目前大部分的分割和检测模型都是采用FPN结构,而PVT这个特性可以使其作为替代CNN的backbone而无缝对接分割和检测的heads。论文中做了大量的关于检测,语义分割以及实例分割的实验,可以看到PVT在dense任务的优势。比如,在更少的推理时间内,基于PVT-Small的RetinaNet比基于R50的RetinaNet在COCO上的AP值更高(38.7 vs. 36.3),虽然继续增加scale可以提升效果,但是就需要额外的推理时间:


所以虽然PVT可以解决一部分问题,但是如果输入图像分辨率特别大,可能基于CNN的方案还是最优的。另外旷视最新的一篇论文YOLOF指出其实ResNet一个C5特征加上一些增大感受野的模块就可以在检测上实现类似的效果,这不得不让人思考多尺度特征是不是必须的,而且transformer encoder本身就是全局感受野的。近期Intel提出的DPT直接在ViT模型的基础上通过Reassembles operation来得到不同scale的特征图以用于dense任务,并在ADE20K语义分割数据集上达到新的SOTA(mIoU 49.02)。而在近日,微软提出的Swin Transformer和PVT的网络架构和很类似,但其性能在各个检测和分割数据集上效果达到SOTA(在ADE20K语义分割数据集mIoU 53.5),其核心提出了一种shifted window方法来减少self-attention的计算量。

相信未来会有更好的work!期待!

参考

  1. Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
  2. whai362/PVT
  3. 大白话Pyramid Vision Transformer
  4. You Only Look One-level Feature
  5. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  6. Vision Transformers for Dense Prediction

推荐阅读

谷歌提出Meta Pseudo Labels,刷新ImageNet上的SOTA!

"未来"的经典之作ViT:transformer is all you need!

大道至简!深度解读CVPR2021论文RepVGG!

PyTorch 源码解读之 torch.autograd

涨点神器FixRes:两次超越ImageNet数据集上的SOTA

Transformer为何能闯入CV界秒杀CNN?

SWA:让你的目标检测模型无痛涨点1% AP

CondInst:性能和速度均超越Mask RCNN的实例分割模型

centerX: 用新的视角的方式打开CenterNet

mmdetection最小复刻版(十一):概率Anchor分配机制PAA深入分析

MMDetection新版本V2.7发布,支持DETR,还有YOLOV4在路上!

CNN:我不是你想的那样

TF Object Detection 终于支持TF2了!

无需tricks,知识蒸馏提升ResNet50在ImageNet上准确度至80%+

不妨试试MoCo,来替换ImageNet上pretrain模型!

重磅!一文深入深度学习模型压缩和加速

从源码学习Transformer!

mmdetection最小复刻版(七):anchor-base和anchor-free差异分析

mmdetection最小复刻版(四):独家yolo转化内幕

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,565评论 6 479
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,021评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,003评论 0 341
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,015评论 1 278
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,020评论 5 370
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,856评论 1 283
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,178评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,824评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,264评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,788评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,913评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,535评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,130评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,102评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,334评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,298评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,622评论 2 343

推荐阅读更多精彩内容