Pytorch 实现自定义卷积:以 2.5 维卷积(2.5D Convolution)为例

        在用 Pytorch 实现各种卷积神经网络的时候,一般用到的卷积层都是系统自带的 2 维卷积 torch.nn.Conv2d,或者较少用到的 1 维卷积 torch.nn.Conv1d3 维卷积 torch.nn.Conv3d,这些卷积提供了足够的参数,使得实现带洞卷积(Atrous Convolution)深度可分离卷积(Depthwise Seperable Convolution)等特殊卷积都易如反掌。但有时候,为了某些特殊的需求,不能直接使用经典的卷积层,而是要自定义的实现某种新的卷积运算,比如可形变卷积(Deformable Convolution),因此学会从底层实现自定义卷积层是必要且必须的。

        本文试着提供一个自定义卷积层的简单教程,为了有针对性和实用性,以实现 2.5 维卷积RGB-D 图像语义分割论文:2.5D Convolution for RGB-D Semantic Segmentation)为例。本文是在参考了论文 Pixel-Adaptive Convolutional Neural Networks 的开源项目 pacnet 的基础上实现的,在此对作者表示感谢。

        本文的所有代码都见下文,也可以访问 [GitHub:稍后放出]。

一、2.5 维卷积原理

        对于卷积核为 k \times k2 维卷积,计它的感受野大小为 k\Delta p\times k\Delta p,以 (u_i, v_i) 为中心的标准二维卷积计算如下:

标准的二维卷积计算公式,来源:2.5D Convolution for RGB-D Semantic Segmentation(下同)

类似的,标准的 3 维卷积 计算如下(k \times k \times k):

标准的三维卷积计算公式

        二维卷积(三维卷积)处理图像(视频)数据已经非常成熟,应用十分广泛。对于带有深度信息的 RGB-D 图像的语义分割,如果把深度信息当成一个额外的通道,那么直接使用二维卷积来实现语义分割模型即可。然而,这样做会忽视深度信息中隐藏的几何结构特征,因此有必要设计一种新颖的卷积方式来充分使用深度信息中的几何特征,论文(2.5D Convolution for RGB-D Semantic Segmentation)作者们就设计了一种称为 2.5 维卷积的操作:

2.5 维卷积计算公式

其中 s_i=z(u_i, v_i)\Delta p^\prime/fz(u_i, v_i) 为深度信息,w_lk2 维卷积核的参数,x_l 的计算公式为:

掩模操作

        根据以上公式,如果输入的特征通道数为 m,输出通道数为 n,那么容易知道:

  • 2 维卷积核的参数量:m\times n\times k\times k
  • 3 维卷积核的参数量:m\times n\times k\times k\times k
  • 2.5 维卷积核的参数量:(m\times n\times k\times k)\times k

如果输入、输出的分辨率都是 r\times s (或者 r\times s\times t),那么(大约):

  • 2 维卷积的计算量:m\times n\times k\times k\times r\times s
  • 3 维卷积的计算量:m\times n\times k\times k\times k \times r\times s\times t
  • 2.5 维卷积的计算量:(m\times n\times k\times k\times r\times s)\times k

显然,虽然相比于 2 维卷积来说,2.5 维卷积的参数量和计算量都要大,但对比 3 维卷积来说,在参数量一致的情况下,2.5 维卷积的计算量却小得多。因此,从——性能上优于 2 维卷积,计算量上优于 3 维卷积——的角度看,2.5 维卷积是有意义的

二、2.5 维卷积实现

        严格按照公式 (4-7)来实现,2.5 维卷积的实现代码为(命名为:conv2_5d.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 20 18:58:19 2019

@author: lijingxiong

Implementation of 2.5D convolution:
    paper: 2.5D Convolution for RGB-D Semantic Segmentation.

Reference: https://github.com/NVlabs/pacnet/blob/master/pac.py
"""

import math
import torch
        
        
class RepeatKernelConvFn(torch.autograd.function.Function):
    """2.5D convolution with kernel.
    """
        
    @staticmethod
    def forward(ctx, inputs, kernel, weight, bias=None, stride=1, padding=0, 
                dilation=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, channels, height, width] 
                representing a batch of images.
            kernel: A tensor with shape [k, batch, channels, N, N, k, k],
                where k = kernel_size and N = number of slide windows.
            weight: A tensor with shape [k, out_channels, in_channels, 
                kernel_size, kernel_size].
            bias: None or a tensor with shape [out_channels].
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        (batch_size, channels), input_size = inputs.shape[:2], inputs.shape[2:]
        ctx.in_channels = channels
        ctx.input_size = input_size
        ctx.kernel_size = tuple(weight.shape[-2:])
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        
        needs_input_grad = ctx.needs_input_grad
        ctx.save_for_backward(
            inputs if (needs_input_grad[1] or needs_input_grad[2]) else None,
            kernel if (needs_input_grad[0] or needs_input_grad[2]) else None,
            weight if (needs_input_grad[0] or needs_input_grad[1]) else None)
        ctx._backend = torch._thnn.type2backend[inputs.type()]
        
        # Slide windows, [batch, channels x kernel_size x kernel_size, N x N],
        # where N is the number of slide windows.
        inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                 ctx.dilation, ctx.padding,
                                                 ctx.stride)

        inputs_wins = inputs_wins.view(
            1, batch_size, channels, *kernel.shape[3:])
        inputs_mul_kernel = inputs_wins * kernel
                
        # Matrix multiplication
        outputs = torch.einsum(
            'hijklmn,hojmn->iokl', (inputs_mul_kernel, weight))
        
        if bias is not None:
            outputs += bias.view(1, -1, 1, 1)
        return outputs
        
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
        grad_inputs = grad_kernel = grad_weight = grad_bias = None
        batch_size, out_channels = grad_outputs.shape[:2]
        output_size = grad_outputs.shape[2:]
        in_channels = ctx.in_channels
        
        # Compute gradients
        inputs, kernel, weight = ctx.saved_tensors
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grad_inputs_mul_kernel = torch.einsum('iokl,hojmn->hijklmn',
                                                  (grad_outputs, weight))
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                     ctx.dilation, ctx.padding,
                                                     ctx.stride)
            inputs_wins = inputs_wins.view(1, batch_size, in_channels,
                                           output_size[0], output_size[1],
                                           ctx.kernel_size[0], 
                                           ctx.kernel_size[1])
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.new()
            grad_inputs_wins = grad_inputs_mul_kernel * kernel
            grad_inputs_wins = grad_inputs_wins.view(
                ctx.kernel_size[0], batch_size, -1, output_size[0], output_size[1])
            ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
                                                grad_inputs_wins,
                                                grad_inputs,
                                                ctx.input_size[0],
                                                ctx.input_size[1],
                                                ctx.kernel_size[0],
                                                ctx.kernel_size[1],
                                                ctx.dilation[0], 
                                                ctx.dilation[1],
                                                ctx.padding[0], 
                                                ctx.padding[1],
                                                ctx.stride[0],
                                                ctx.stride[1])
        if ctx.needs_input_grad[1]:
            grad_kernel = inputs_wins * grad_inputs_mul_kernel
            grad_kernel = grad_kernel.sum(dim=1, keepdim=True)
        if ctx.needs_input_grad[2]:
            inputs_mul_kernel = inputs_wins * kernel
            grad_weight = torch.einsum('iokl,hijklmn->hojmn',
                                       (grad_outputs, inputs_mul_kernel))
        if ctx.needs_input_grad[3]:
            grad_bias = torch.einsum('iokl->o', (grad_outputs,))
        return (grad_inputs, grad_kernel, grad_weight, grad_bias, None, None,
                None)
        
        
class DepthKernelFn(torch.autograd.function.Function):
    """Compute mask in paper: 
        2.5D convolution for rgb-d semantic segmentation.
    """
    
    @staticmethod
    def forward(ctx, depth, f, kernel_size, stride, padding, dilation):
        """Forward computation.
        
        Args:
            depth: A tensor with shape [batch, 1, height, width] representing
                a batch of depth maps.
            f: A constant.
            
        Returns:
            A tensor with shape [k, batch, 1, N, N, k, k], where 
            k = kernel_size and N = number of slide windows.
        """
        ctx.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        
        batch_size, channels, in_height, in_width = depth.shape
        out_height = (in_height + 2 * ctx.padding[0] - 
                      ctx.dilation[0] * (ctx.kernel_size[0] - 1)
                      -1) // ctx.stride[0] + 1
        out_width = (in_width + 2 * ctx.padding[1] - 
                     ctx.dilation[1] * (ctx.kernel_size[1] - 1)
                     -1) // ctx.stride[1] + 1
        
        depth_wins = torch.nn.functional.unfold(depth, ctx.kernel_size,
                                                ctx.dilation, ctx.padding,
                                                ctx.stride)
        depth_wins = depth_wins.view(batch_size, channels, out_height, 
                                     out_width, ctx.kernel_size[0],
                                     ctx.kernel_size[1])
        s_wins = depth_wins / f
        
        kernels = []
        center_y, center_x = ctx.kernel_size[0] // 2, ctx.kernel_size[1] // 2
        for l in range(ctx.kernel_size[0]):
            z_l = depth_wins + (l - (ctx.kernel_size[0] - 1) / 2) * s_wins
            z_l_0 = z_l.contiguous()[:, :, :, :, center_y:center_y + 1,
                                     center_x:center_x + 1]
            s_0 = s_wins.contiguous()[:, :, :, :, center_y:center_y + 1,
                                      center_x:center_x + 1]
            mask_l_ge = torch.where(depth_wins >= z_l_0 - s_0 / 2,
                                    torch.full_like(depth_wins, 1),
                                    torch.full_like(depth_wins, 0))
            mask_l_lt = torch.where(depth_wins < z_l_0 + s_0 / 2,
                                    torch.full_like(depth_wins, 1),
                                    torch.full_like(depth_wins, 0))
            mask_l = torch.where(mask_l_ge == mask_l_lt,
                                 mask_l_ge,
                                 torch.full_like(depth_wins, 0))
            kernels.append(mask_l.unsqueeze(dim=0))
        return torch.cat(kernels, dim=0)
    
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
        return 0, None, None, None, None, None
    
    
class Conv2_5d(torch.nn.Module):
    """Implementation of 2.5D convolution."""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        """Constructor."""
        super(Conv2_5d, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        self.stride = torch.nn.modules.utils._pair(stride)
        self.padding = torch.nn.modules.utils._pair(padding)
        self.dilation = torch.nn.modules.utils._pair(dilation)
        
        # Parameters: weight, bias
        self.weight = torch.nn.parameter.Parameter(
            torch.Tensor(kernel_size, out_channels, in_channels, kernel_size,
                         kernel_size))
        if bias:
            self.bias = torch.nn.parameter.Parameter(
                torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            
        # Initialization
        self.reset_parameters()
        
    def forward(self, inputs, depth, f=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, in_channels, height, width] 
                representing a batch of images.
            depth: A tensor with shape [batch, 1, height, width] representing
                    a batch of depth maps.
            f: A constant.
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        kernel = DepthKernelFn.apply(depth, f, self.kernel_size, self.stride,
                                     self.padding, self.dilation)
        
        outputs = RepeatKernelConvFn.apply(inputs, kernel, self.weight,
                                           self.bias, self.stride,
                                           self.padding, self.dilation)
        return outputs
    
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        return s.format(**self.__dict__)
    
    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

        实现自定义卷积层的要点是:

  • 底层计算需要继承 torch.autograd.function.Function

        定义该类的子类时,需要重载它的两个函数:forwardbackward 函数,分别用于前向传播和反向传播的计算。另外,forward 函数返回值的个数等于 backward 参数的个数(不计算 ctx),而 backward 返回值的个数则等于 forward 函数的参数个数(不计算 ctx),即两者的输入输出是一一对应的。顾名思义,backward 函数是利用链式法则forward 函数的所有输入求梯度,如果某个输入不需要求梯度,那么直接给该参数的梯度赋值为 None 即可。

  • 使用 torch.nn.functional.unfold 函数将数据按照滑动窗口分块:

        对于批量 b、通道数m、分辨率 R\times S 的输入 x,形状为: b\times m\times r\times s,如果卷积核大小(kernel size)、填充大小(padding)、步幅(stride)、空洞率(dilation)分别为 k,p,s,d,那么该函数的输出大小为:b\times(m\times k\times k)\times (R^\prime\times S^\prime),是一个 3 维张量,其中:
R^\prime=\lfloor[R+2p-d(k-1)-1]/s\rfloor+1,\\ S^\prime=\lfloor[S+2p-d(k-1)-1]/s\rfloor+1

  • 使用 torch.einsum 函数对张量按照卷积运算求和

        根据爱因斯坦和式约定,上下标一致的数据可以省略求和号,如:
s=\sum^N_{i=1} a^ib_i=a^ib_i\\
把这一约定用符合表示并计算出来就是 einsum 函数:

torch.einsum('i,i->', (a, b))

比如:

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
s=torch.einsum('i,i->', (a, b))
# s = tensor(14)

        结合以上两个函数,对于形状为 [b, 1, R, S] 的深度信息 z,对于固定的 l,根据公式 (7),用 torch.nn.functional.unfold 函数得到形状为 [b, 1, R^\prime, S^\prime, k, k]m_l(已通过 view 函数改变形状),对所有的 l 拼接 m_l,得到形状为 [k, b, 1, R^\prime, S^\prime, k, k] 的张量。同理,对于 x,经过滑动窗口操作之后形状为 [1, b, m, R^\prime, S^\prime, k, k](额外加第 1 维)。这两个 7 维张量经过元素级的乘法得到新的 7 维张量,形状为 [k, b, m, R^\prime, S^\prime, k, k],然后按照公式 (4)用 torch.einsum 函数对这个 7 维张量和形状为 [k, m, n, k, k] 的权重参数张量求和:torch.einsum('hijklmn,hojmn->iokl', (·, ·)),得到形状为 [b, n, R^\prime, S^\prime] 的输出,其中 n 为输出通道数。

        以上过程就是继承了 torch.autograd.function.Function 类的两个类: DepthKernelFnRepeatKernelConvFnforward 函数的内容。调用这些类时直接使用 .apply() 函数即可。而 backward 函数就是要对 forward 函数的计算利用链式法则求梯度,因此无需赘言。

        当前向传播和反向传播的计算都定义清楚了之后,还需要将它们封装成一个自定义卷积层,这就是类 Conv2_5d,即最终用来调用的 2.5 维卷积层。它的定义跟平时卷积网络的定义类似,都是直接继承 torch.nn.Module 类和重载 forward 函数,不同的是需要定义权重参数和偏置参数(如果需要的话):

self.weight = torch.nn.parameter.Parameter(torch.Tensor(shape))
self.bias = torch.nn.parameter.Parameter(torch.Tensor(shape))

并适当的初始化它们(见 reset_parameters 函数)。

三、2.5 维卷积实现代码的验证

        为了验证以上实现的代码在反向传播时不会报错,定义一个两层的简单网络来验证如下(命名为:conv2_5d_test.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 27 13:41:23 2019

@author: lijingxiong
"""

import torch

import conv2_5d


class ConvTest(torch.nn.Module):
    """A mini networt to test Conv2_5d in forward and backword computation."""
    
    def __init__(self, num_classes=2):
        super(ConvTest, self).__init__()
        
        self._head_conv = conv2_5d.Conv2_5d(in_channels=3, 
                                            out_channels=32, 
                                            kernel_size=5, 
                                            padding=2, 
                                            bias=False)
        self._pred_conv = torch.nn.Conv2d(in_channels=32,
                                          out_channels=num_classes,
                                          kernel_size=3,
                                          padding=1,
                                          bias=False)
        self._batch_norm = torch.nn.BatchNorm2d(num_features=num_classes,
                                                momentum=0.995)
        
    def forward(self, x, z, f=1):
        x = self._head_conv(x, z, f)
        x = self._pred_conv(x)
        x = self._batch_norm(x)
        return x
    
    
if __name__ == '__main__':
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    model = ConvTest().to(device)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    
    num_steps = 100
    for i in range(num_steps):
        images = torch.rand((2, 3, 64, 64)).to(device)
        depth = torch.rand((2, 1, 64, 64)).to(device)
        labels = torch.LongTensor(
            torch.full((2, 64, 64), 0, dtype=torch.int64)).to(device)
        
        # Forward pass
        outputs = model(images, depth)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Step: {}/{}, Loss: {:.4f}'.format(i+1, num_steps, loss.item()))

        直接执行 :

python3 conv2_5d_test.py

代码正常结束,且损失逐渐减小,(暂时)认为代码是正确的。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,542评论 6 504
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,822评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,912评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,449评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,500评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,370评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,193评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,074评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,505评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,722评论 3 335
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,841评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,569评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,168评论 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,783评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,918评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,962评论 2 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,781评论 2 354