Introduction
Xception is Google's follow-up improvement to Inception V3, proposed after the Inception line. Its main idea is to replace the convolution operations in Inception V3 with depthwise separable convolutions.
Reflections
What problem is addressed, and how?
- Trace the basic idea behind Inception.
- Understand that idea through the history of Inception's development, and introduce the depthwise separable convolution, a structure similar to Inception.
- Replace the Inception modules in Inception V3 with depthwise separable convolutions.
How well does it work?
- With almost the same number of parameters as Inception V3, performance improves slightly on ImageNet and markedly on JFT.
What problems remain?
- Depthwise separable convolution is not necessarily the optimal structure; similar structures remain unexplored and unverified.
Background
Xception is one member of the Inception family.
Compared with an ordinary convolution operation, the Inception module has stronger representational power.
A review of the Inception series
The Inception structure is shown in Figure 1: the original convolution is replaced by three kinds of convolutions computed in parallel, whose outputs are then concatenated.
Why these kernel sizes?
- To keep the feature maps aligned: with kernel sizes 1, 3, 5 and stride 1, the paddings are set to 0, 1, 2 respectively, so every branch produces feature maps of the same spatial dimensions.
- To connect highly correlated neurons together.
One problem remains: the 5×5 kernels are computationally expensive, so 1×1 kernels are introduced for dimensionality reduction, as shown in Figure 2. Moreover, two 3×3 kernels can replace one 5×5 kernel, so the structure becomes the one in Figure 3.
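A quick parameter count makes the savings concrete (the channel count C below is a made-up example, not a value from the paper):

```python
# Two stacked 3x3 convs cover the same 5x5 receptive field as a single
# 5x5 conv but need fewer weights (C input and C output channels throughout).
C = 64
params_5x5 = 5 * 5 * C * C            # 102,400 weights
params_two_3x3 = 2 * 3 * 3 * C * C    # 73,728 weights, ~28% fewer
print(params_5x5, params_two_3x3)
```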
These modules are the main building blocks of Inception V3, whose basic structure is:
```
input
conv2d(32, 3, 3, s=2)     # conv2d_1a
conv2d(32, 3, 3)          # conv2d_2a
conv2d(64, 3, 3, 'SAME')  # conv2d_2b
max_pool2d(3, 3, s=2)     # maxpool_3a
conv2d(80, 1, 1)          # conv2d_3b
conv2d(192, 3, 3)         # conv2d_4a
max_pool2d(3, 3, s=2)     # maxpool_5a
# Inception module, repeated 9 times (branch widths vary between modules);
# four parallel branches whose outputs are concatenated:
#   branch 1: conv2d(64, 1, 1)
#   branch 2: conv2d(48, 1, 1) -> conv2d(64, 5, 5)
#   branch 3: conv2d(64, 1, 1) -> conv2d(96, 3, 3) -> conv2d(96, 3, 3)
#   branch 4: avg_pool(3, 3)   -> conv2d(32, 1, 1)
#   concat
conv2d(num_class, 1, 1)
```
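A minimal Keras sketch of one such Inception module (an illustrative reconstruction from the listing above; the 35×35×192 input size and the ReLU placement are assumptions, not taken from the original slim source):

```python
from keras.layers import Input, Conv2D, AveragePooling2D, concatenate
from keras.models import Model

inp = Input(shape=(35, 35, 192))  # example input, chosen for illustration

# Branch 1: 1x1 conv
b1 = Conv2D(64, (1, 1), padding='same', activation='relu')(inp)
# Branch 2: 1x1 dimensionality reduction, then 5x5 conv
b2 = Conv2D(48, (1, 1), padding='same', activation='relu')(inp)
b2 = Conv2D(64, (5, 5), padding='same', activation='relu')(b2)
# Branch 3: 1x1 reduction, then two 3x3 convs (replacing a 5x5)
b3 = Conv2D(64, (1, 1), padding='same', activation='relu')(inp)
b3 = Conv2D(96, (3, 3), padding='same', activation='relu')(b3)
b3 = Conv2D(96, (3, 3), padding='same', activation='relu')(b3)
# Branch 4: average pooling, then 1x1 conv
b4 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(inp)
b4 = Conv2D(32, (1, 1), padding='same', activation='relu')(b4)

out = concatenate([b1, b2, b3, b4])  # channel-wise concat of all branches
module = Model(inp, out)
```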
In every conv layer above, what is learned is a 3D kernel with two spatial dimensions and one channel dimension, i.e. w, h, c. The kernel is convolved with the input across all three dimensions at once to produce the output; in pseudocode:
```python
# For the i-th filter, compute the convolution result centered at
# input position (x, y); C is the channel count, K the kernel size.
acc = 0
for c in range(C):
    for h in range(K):
        for w in range(K):
            acc += input[c, y - K // 2 + h, x - K // 2 + w] * filter_i[c, h, w]
out[i, y, x] = acc
```
As the loops show, in this 3D convolution the channel dimension is treated exactly like the two spatial dimensions.
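For reference, here is a runnable NumPy version of the pseudocode (the shapes and the output position are arbitrary, chosen only for illustration):

```python
import numpy as np

C, K = 3, 3
rng = np.random.default_rng(0)
inp = rng.standard_normal((C, 7, 7))   # input: (channels, height, width)
filt = rng.standard_normal((C, K, K))  # one 3D filter

y, x = 3, 3                            # output position to compute
acc = 0.0
for c in range(C):
    for h in range(K):
        for w in range(K):
            acc += inp[c, y - K // 2 + h, x - K // 2 + w] * filt[c, h, w]

# The triple loop is just an elementwise product summed over a 3D patch:
patch = inp[:, y - K // 2:y + K // 2 + 1, x - K // 2:x + K // 2 + 1]
assert np.isclose(acc, np.sum(patch * filt))
```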
Now consider first applying a single shared 1×1 convolution and then attaching three 3×3 convolutions, as shown in Figure 4. Each of the three convolutions takes only part of the preceding output as its input; in the figure, each 3×3 kernel sees 1/3 of the channels.
Then extend the number of 3×3 kernels until it equals the number of output channels of the 1×1 convolution, so that each 3×3 kernel convolves exactly one input channel, as shown in Figure 5.
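The Figure-4 structure can be sketched in Keras as follows (the channel counts are invented for illustration); pushing the partition to its extreme, one kernel per channel, is exactly what keras.layers.DepthwiseConv2D implements:

```python
from keras.layers import Input, Conv2D, Lambda, concatenate
from keras.models import Model

inp = Input(shape=(28, 28, 96))
x = Conv2D(96, (1, 1), padding='same')(inp)  # shared 1x1 conv

# Three 3x3 convs, each seeing a disjoint third (32) of the 96 channels
branches = []
for k in range(3):
    seg = Lambda(lambda t, k=k: t[:, :, :, 32 * k:32 * (k + 1)])(x)
    branches.append(Conv2D(32, (3, 3), padding='same')(seg))

out = concatenate(branches)
module = Model(inp, out)
```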
Xception
Xception is built on the depthwise separable convolution, which splits a traditional convolution into two steps (sketched in code after this list):
- depthwise convolution: M 3×3 kernels convolve the M input feature maps one-to-one; the results are not summed, producing M separate outputs.
- pointwise convolution: N 1×1 kernels then convolve those M outputs in the ordinary way.
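The two steps in Keras, with a parameter-count comparison against a standard convolution (M and N are made-up channel counts; this assumes a Keras version that provides DepthwiseConv2D):

```python
from keras.layers import Input, Conv2D, DepthwiseConv2D
from keras.models import Model

M, N = 64, 128                       # illustrative input/output channel counts
inp = Input(shape=(56, 56, M))

# Step 1: depthwise conv, one 3x3 kernel per input channel (M maps, no summing)
x = DepthwiseConv2D((3, 3), padding='same', use_bias=False)(inp)
# Step 2: pointwise conv, N 1x1 kernels mixing the M maps across channels
x = Conv2D(N, (1, 1), padding='same', use_bias=False)(x)
separable = Model(inp, x)

standard = Model(inp, Conv2D(N, (3, 3), padding='same', use_bias=False)(inp))

print(separable.count_params())  # 3*3*M + M*N = 8,768
print(standard.count_params())   # 3*3*M*N   = 73,728
```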
Depthwise separable convolution differs from the structure above in two ways:
- The order of operations. A depthwise separable conv first applies channel-wise filters over the spatial dimensions only, and then uses 1×1 kernels to fuse information across channels; Inception applies the 1×1 convolution first.
- The absence of a nonlinearity. In Inception, every conv operation is followed by a ReLU; a depthwise separable conv has no nonlinearity between its two steps.
The Xception architecture takes a ResNet-style stack and turns its convolutions into depthwise separable convolutions, as shown in the figure below, where SeparableConv denotes the depthwise separable convolution module. In addition, the original concat is replaced by residual connections.
References
[1] F. Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. CVPR 2017.
Code Analysis
### Xception.py
```python
from keras import backend as K
from keras import layers
from keras.models import Model
from keras.layers import Dense
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs  # location varies across Keras versions
import tensorflow as tf

# These names come from the signature of the original Xception() function in
# keras.applications; the defaults below let the excerpt run stand-alone.
include_top = True
classes = 1000
pooling = None
weights = None        # 'imagenet' would trigger the weight download at the end
old_data_format = None

input_tensor = tf.ones([1, 224, 224, 3])
input_shape = [224, 224, 3]
img_input = Input(tensor=input_tensor, shape=input_shape)
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input) # valid padding: (1, 111, 111, 32)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x) #(1, 109, 109, 64)
x = BatchNormalization(name='block1_conv2_bn')(x)
x = Activation('relu', name='block1_conv2_act')(x)
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual) #(1, 55, 55, 128)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
x = layers.add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
x = layers.add([x, residual])
residual = Conv2D(728, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
x = layers.add([x, residual])
# Middle flow: 8 identical residual blocks of three separable convs each
for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)
    x = Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
    x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
    x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
    x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
    x = layers.add([x, residual])
residual = Conv2D(1024, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
x = layers.add([x, residual])
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
if include_top:
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)
else:
    if pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)

if input_tensor is not None:
    inputs = get_source_inputs(input_tensor)
else:
    inputs = img_input
model = Model(inputs, x, name='xception')

# TF_WEIGHTS_PATH and TF_WEIGHTS_PATH_NO_TOP are the weight-file URLs
# defined in the original source; they are not reproduced here.
if weights == 'imagenet':
    if include_top:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels.h5',
                                TF_WEIGHTS_PATH,
                                cache_subdir='models')
    else:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                TF_WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models')
    model.load_weights(weights_path)

if old_data_format:
    K.set_image_data_format(old_data_format)
# In the original source this is the end of the Xception() function,
# which returns `model` here.
```
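As a quick sanity check, one can instead instantiate the stock implementation that this excerpt is based on and inspect its size (the parameter counts in the comment are from the paper's comparison with Inception V3):

```python
from keras.applications.xception import Xception

model = Xception(weights=None, include_top=True, classes=1000)
model.summary()
print(model.count_params())  # ~22.9M parameters, on par with Inception V3's ~23.6M
```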
[1] Code reference