Deformable CNN 的流程无外乎:为每个位置预测常规 CNN kernel 各采样点的偏移量以及对应的权重(调制系数),然后根据这些偏移量在特征图上采样到相应的位置,再对这些位置上的特征做 kernel 卷积。鉴于 kernel_size 一般很小,这部分可以直接用循环实现,因此很容易地就能用纯 PyTorch 实现 deformable CNN。
import torch
import torch.nn as nn
class MyDCN(nn.Module):
    """Modulated deformable convolution built only from standard PyTorch ops.

    For every output location, ``offset_conv`` predicts 3 channels per kernel
    sampling point: an (x, y) offset and a modulation logit. For each of the
    ``ks * ks`` points, features are bilinearly sampled from the input with
    ``grid_sample`` at (regular kernel position + predicted offset), scaled by
    ``sigmoid(logit)``, concatenated along channels, and mixed by a 1x1 conv.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        ks: side length of the (square) deformable kernel.
        dilation: spacing, in input pixels, between neighbouring kernel points
            before offsets are applied.
        offset_ks: kernel size of the offset-predicting conv (assumed odd).
        offset_stride: stride of the offset-predicting conv; this is also the
            stride of the module's output.
        offset_dilation: dilation of the offset-predicting conv.
        bias: whether the final 1x1 conv has a bias term.
    """

    def __init__(self, in_channel, out_channel, ks=3, dilation=1, offset_ks=3,
                 offset_stride=1, offset_dilation=1, bias=True):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.ks = ks
        self.offset_stride = offset_stride
        # ``dilation`` controls the kernel sampling pattern; ``offset_dilation``
        # only affects the offset-predicting conv.
        self.dilation_patch = dilation
        # Integer half-extent of the regular sampling window.
        self.d = dilation * (ks - 1) // 2
        # "Same" padding for odd offset_ks at any dilation, so that with
        # stride 1 the offset map has the input's spatial size.
        offset_pad = offset_dilation * (offset_ks - 1) // 2
        self.offset_conv = nn.Conv2d(in_channel, 3 * ks ** 2,
                                     kernel_size=offset_ks, stride=offset_stride,
                                     padding=offset_pad, dilation=offset_dilation)
        self.conv = nn.Conv2d(in_channel * ks ** 2, out_channel, kernel_size=1,
                              padding=0, stride=1, bias=bias)
        # Regular (dx, dy) displacement of each kernel point, in the order the
        # offset channels are consumed: dx varies slowest, dy fastest. Centred
        # at 0, so an even ks yields half-pixel positions.
        half = (ks - 1) / 2.0
        self.kernel_offsets = [((a - half) * dilation, (b - half) * dilation)
                               for a in range(ks) for b in range(ks)]
        self.init_parameters()

    def init_parameters(self):
        """Xavier-normal init for both conv weights; zero their biases if present."""
        for c in [self.conv, self.offset_conv]:
            nn.init.xavier_normal_(c.weight)
            if c.bias is not None:
                nn.init.constant_(c.bias, 0)

    def make_grid(self, h, w, device=None, dtype=None):
        """Return integer pixel coordinates of an ``h x w`` map, shape (1, 2, h, w).

        Channel 0 holds the column index (x) and channel 1 the row index (y),
        which is the coordinate order ``grid_sample`` expects. ``dtype``
        defaults to float32.
        """
        dtype = torch.float32 if dtype is None else dtype
        ys = torch.arange(h, device=device).view(h, 1).expand(h, w)
        xs = torch.arange(w, device=device).view(1, w).expand(h, w)
        return torch.stack((xs, ys), dim=0).unsqueeze(0).to(dtype)

    def forward(self, x):
        """Apply the deformable convolution.

        Args:
            x: input of shape (B, in_channel, H, W).

        Returns:
            Tensor of shape (B, out_channel, H_o, W_o), where (H_o, W_o) is the
            spatial size of the offset map; for stride 1 and odd offset_ks this
            equals (H, W).
        """
        offset_preds = self.offset_conv(x)
        h, w = x.size(2), x.size(3)
        out_h, out_w = offset_preds.size(2), offset_preds.size(3)
        # Output location (p, q) is centred on input pixel (p * s, q * s).
        grid = self.make_grid(out_h, out_w, device=x.device, dtype=x.dtype) * self.offset_stride
        self.grid = grid  # kept as an attribute for inspection
        # With align_corners=True, pixel index 0 maps to -1 and size-1 maps to
        # +1. max(..., 1) avoids dividing by zero on size-1 dimensions.
        scale = grid.new_tensor([2.0 / max(w - 1, 1), 2.0 / max(h - 1, 1)]).view(1, 2, 1, 1)

        feat_selections = []
        for k, (dx, dy) in enumerate(self.kernel_offsets):
            offset = offset_preds[:, 3 * k:3 * k + 2]
            mask = torch.sigmoid(offset_preds[:, 3 * k + 2:3 * k + 3])
            base = grid.new_tensor([dx, dy]).view(1, 2, 1, 1)
            # With zero predicted offsets this samples the regular dilated kernel.
            coords = (grid + base + offset) * scale - 1.0
            sampled = nn.functional.grid_sample(
                x, coords.permute(0, 2, 3, 1), mode="bilinear",
                padding_mode="border", align_corners=True)
            feat_selections.append(sampled * mask)

        # Channels are ordered kernel-point-major: [point 0: in_channel ch,
        # point 1: in_channel ch, ...]. For a grouped 1x1 conv the channels of
        # each input channel must be adjacent instead, e.g.:
        # feat_selections = torch.stack(feat_selections, dim=1).permute(0, 2, 1, 3, 4).reshape(x.size(0), -1, out_h, out_w)
        feat_selections = torch.cat(feat_selections, dim=1)
        return self.conv(feat_selections)
if __name__ == "__main__":
    # Smoke test: one forward and backward pass. Uses the GPU when available
    # so the script also runs on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = MyDCN(8, 4).to(device)
    x = torch.rand(2, 8, 20, 14, device=device)
    y = net(x)
    print(y.size())
    z = y.sum()
    z.backward()
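这里没有实现 group 参数。如需分组,可以直接在最后的 1×1 conv 层上设置 groups,但此时需要调整 feat_selections 中通道的排列顺序,使同一输入通道在各采样点上的特征在拼接结果中相邻,见代码中注释掉的部分。(注意:permute 之后需要先 .contiguous() 再 .view,或者直接使用 .reshape。)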