20、PPM模块
论文《Pyramid Scene Parsing Network》
1、作用
金字塔场景解析网络(PSPNet)利用金字塔池化模块聚合不同区域的全局上下文信息,有效提高了场景解析的准确性。PSPNet能够处理开放词汇表中的复杂场景,并对每个像素进行精确的分类预测。
2、机制
PSPNet通过在深度卷积网络(如ResNet)的最后一层卷积特征图上应用金字塔池化模块来捕获不同尺度的全局上下文信息。金字塔池化模块包含不同层级的池化操作,每个层级针对特征图的不同子区域进行池化,以获取从局部到全局的上下文信息。通过这种方式,模型可以整合丰富的场景信息,改善场景解析的性能。
3、独特优势
1、全局上下文整合:
通过金字塔池化模块整合全局和局部上下文信息,有效捕获场景的多尺度特征,提高了对复杂场景的理解能力。
2、高准确性:
PSPNet在多个场景解析基准测试中取得了优异的性能,包括ImageNet场景解析挑战、PASCAL VOC 2012基准和Cityscapes基准,显示出其高准确性和泛化能力。
3、有效的优化策略:
引入深度监督损失(deeply supervised loss),提供了一个有效的优化策略,使得基于ResNet的FCN网络能够实现更好的训练效果。
4、代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class PPM(nn.Module): # pspnet中的金字塔池化模块
def __init__(self, down_dim):
super(PPM, self).__init__()
self.down_conv = nn.Sequential(
nn.Conv2d(2048, down_dim, 3, padding=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
# 使用不同尺度的自适应平均池化,并通过1x1卷积来减少特征维度
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Conv2d(down_dim, down_dim, kernel_size=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
self.conv2 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(2, 2)),
nn.Conv2d(down_dim, down_dim, kernel_size=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
self.conv3 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(3, 3)),
nn.Conv2d(down_dim, down_dim, kernel_size=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
self.conv4 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(6, 6)),
nn.Conv2d(down_dim, down_dim, kernel_size=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
# 融合不同尺度的特征
self.fuse = nn.Sequential(
nn.Conv2d(4 * down_dim, down_dim, kernel_size=1),
nn.BatchNorm2d(down_dim),
nn.PReLU()
)
def forward(self, x):
x = self.down_conv(x) # 降维
conv1 = self.conv1(x) # 1x1尺度
conv2 = self.conv2(x) # 2x2尺度
conv3 = self.conv3(x) # 3x3尺度
conv4 = self.conv4(x) # 6x6尺度
# 将池化后的特征上采样到输入特征相同的尺寸,并进行融合
conv1_up = F.upsample(conv1, size=x.size()[2:], mode='bilinear', align_corners=True)
conv2_up = F.upsample(conv2, size=x.size()[2:], mode='bilinear', align_corners=True)
conv3_up = F.upsample(conv3, size=x.size()[2:], mode='bilinear', align_corners=True)
conv4_up = F.upsample(conv4, size=x.size()[2:], mode='bilinear', align_corners=True)
return self.fuse(torch.cat((conv1_up, conv2_up, conv3_up, conv4_up), 1)) # 在通道维度上进行拼接并通过1x1卷积融合
# 测试用例
if __name__ == "__main__":
# 假设输入特征维度是2048,我们想要降维到512
ppm = PPM(down_dim=512)
input_tensor = torch.randn(64, 2048, 32, 32) # 模拟输入
output_tensor = ppm(input_tensor)
print("Output shape:", output_tensor.shape) # 打印输出维度