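I recently spent some time studying video classification code, and I plan to write a series of blog posts on video classification papers and their code. This post covers Temporal Shift Module for Efficient Video Understanding.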
paper: TSM: Temporal Shift Module for Efficient Video Understanding
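code: ① temporal-shift-module (the official implementation, recommended) ② mmaction2 (the code discussed in this post; it relies on many registry and hook modules, which can make it hard for beginners to follow)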
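Before TSM, mainstream video classification relied on 3D CNNs. As the paper points out, however, their computational cost is high, which makes them ill-suited to online video classification. To address this, the authors proposed TSM. The model comes in two versions: an offline version with bi-directional shift, which mixes in features from both past and future frames, and an online version with uni-directional shift, which only uses past frames and can therefore run on a live video stream.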
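The online variant can be pictured with a minimal NumPy sketch. It assumes per-frame features of shape [C, H*W] arriving one at a time and, roughly following the paper's description of online inference, caches the first C // 8 channels of each frame and substitutes them into the next frame. The function name `online_shift_stream` is hypothetical, and a real model applies this inside every residual block rather than once on the input.

```python
import numpy as np

def online_shift_stream(frames, shift_div=8):
    """Uni-directional (online) temporal shift over a stream of frame features.

    frames: iterable of arrays shaped [C, HW], arriving one frame at a time.
    The first C // shift_div channels of each frame are replaced by the cached
    channels of the previous frame (zeros for the first frame), so no future
    frame is ever needed.
    """
    cache = None
    outputs = []
    for feat in frames:
        fold = feat.shape[0] // shift_div
        out = feat.copy()
        if cache is None:
            out[:fold] = 0.0          # no past frame yet: zero padding
        else:
            out[:fold] = cache        # channels borrowed from frame t-1
        cache = feat[:fold].copy()    # keep this frame's slice for frame t+1
        outputs.append(out)
    return outputs

rng = np.random.default_rng(0)
frames = [rng.standard_normal((16, 4)) for _ in range(5)]  # 16 channels -> fold = 2
outs = online_shift_stream(frames)
```

Only a small cache per block has to be kept between frames, which is what makes the uni-directional version suitable for streaming.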
I. Model Principles
1.1 The shift mechanism
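The core idea of the paper, shown in the figure below, is to shift part of the channels along the temporal dimension so that each frame's features also carry information from its neighboring frames.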
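First, how to read the figure: the vertical axis is the temporal dimension Temporal T (frames are sampled from the video, typically 8 or 16), the horizontal axis is the Channel dimension C, and the depth axis is the spatial extent of the feature map (width x height). Moving some channels up and down along the time axis enlarges the temporal receptive field; the blank positions left behind are zero-padded.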
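A tiny NumPy example of the bi-directional (offline) shift, on a [T, C] array with the spatial dimensions dropped. Which channel slice moves in which direction is chosen to mirror the mmaction2 code discussed later: the first `fold` channels take values from frame t+1, the next `fold` channels from frame t-1, and the rest stay put, with vacated positions zero-filled. The helper name `toy_temporal_shift` is hypothetical.

```python
import numpy as np

def toy_temporal_shift(x, fold):
    """Bi-directional temporal shift on an array of shape [T, C] (spatial dims dropped)."""
    out = np.zeros_like(x)
    out[:-1, :fold] = x[1:, :fold]                  # take from the next frame
    out[1:, fold:2 * fold] = x[:-1, fold:2 * fold]  # take from the previous frame
    out[:, 2 * fold:] = x[:, 2 * fold:]             # untouched channels
    return out

T, C = 4, 8
x = np.arange(T * C).reshape(T, C) + 1  # +1 so that zero padding is visible
y = toy_temporal_shift(x, fold=1)
# y[0] is [9, 0, 3, 4, 5, 6, 7, 8]: channel 0 came from frame 1, channel 1 is padding
```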
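So how many channels should be shifted? The authors' choice: at each time step, 1/8 of the channels hold features shifted in from the previous frame, another 1/8 hold features shifted in from the next frame, and the remaining 3/4 are left unchanged. The paper explains that this proportion balances spatial and temporal modeling, which is why 1/4 of the channels in total (1/8 per direction) are shifted. Although the shift itself adds 0 FLOPs, moving data in memory still costs latency, and shifting too many channels destroys spatial information and hurts accuracy. The experiment below shows the effect of shifting all channels, which severely hurts inference efficiency: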
1.2 The residual shift mechanism
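The paper finds that Residual TSM, which places the shift inside the residual branch, still fuses temporal information and performs better than In-place TSM, because In-place shift impairs the backbone's ability to learn spatial features.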
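The difference between the two placements can be sketched in NumPy, with a 1x1 convolution (a per-frame linear map over channels) standing in for the convolution block; that stand-in and all function names are assumptions for illustration. In mmaction2, the 'blockres' option shown later wraps the conv1 of each Bottleneck, which corresponds to the residual placement.

```python
import numpy as np

def shift_tc(x, fold):
    """Bi-directional temporal shift on [T, C] features (zero padded)."""
    out = np.zeros_like(x)
    out[:-1, :fold] = x[1:, :fold]
    out[1:, fold:2 * fold] = x[:-1, fold:2 * fold]
    out[:, 2 * fold:] = x[:, 2 * fold:]
    return out

def conv1x1(x, weight):
    """A 1x1 convolution acts on each frame independently: [T, C] @ [C, C_out]."""
    return x @ weight

def in_place_block(x, weight, fold):
    # Shift on the main path: the block only ever sees shifted features.
    return conv1x1(shift_tc(x, fold), weight)

def residual_block(x, weight, fold):
    # Shift on the residual branch: the identity path keeps every frame's
    # original (unshifted) features.
    return x + conv1x1(shift_tc(x, fold), weight)

x = np.arange(32, dtype=float).reshape(4, 8)
```

With a zero weight the residual block returns its input unchanged, while the in-place block has already lost the original values of the shifted channels before any weight is applied.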
1.3 The overall model
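With the figure above, it is easy to see how the model classifies a video. Each frame goes through a shared 2D backbone in which the shift operation described above is applied before the convolution blocks (the code later makes this concrete). Note that the output side uses global average pooling: the pooled features pass through an fc (fully connected) layer to produce per-frame class scores, which are then averaged over the frames to give the class scores of the whole video.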
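A NumPy sketch of that output side, assuming frames of one video are stored contiguously (as in the mmaction2 code below) and using small sizes instead of the real 2048 channels. It is not mmaction2's TSMHead: the dropout used during training (dropout_ratio=0.5 in the config) is omitted, and the softmax is only applied to turn scores into probabilities.

```python
import numpy as np

def tsm_head_sketch(feats, num_segs, fc_weight, fc_bias):
    """Spatial average pooling -> fc -> average over frames.

    feats: [N * num_segs, C, H, W] backbone output; fc_weight: [C, K]; fc_bias: [K].
    Returns per-video class scores of shape [N, K].
    """
    pooled = feats.mean(axis=(2, 3))                       # [N*T, C] global average pooling
    frame_scores = pooled @ fc_weight + fc_bias            # [N*T, K] per-frame scores
    frame_scores = frame_scores.reshape(-1, num_segs, frame_scores.shape[-1])  # [N, T, K]
    return frame_scores.mean(axis=1)                       # average consensus over frames

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
feats = rng.standard_normal((2 * 8, 32, 7, 7))  # 2 videos x 8 frames, 32 channels
fc_w = rng.standard_normal((32, 7)) * 0.01       # 7 classes, as in the config below
fc_b = np.zeros(7)
video_scores = tsm_head_sketch(feats, 8, fc_w, fc_b)
probs = softmax(video_scores)
```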
II. Code Walkthrough
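The code here uses the mmaction2 framework for action classification. The TSM training code is covered in three parts: ① data processing ② model structure ③ loss computation.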
2.1 Data processing
First, the dataset config file:
_base_ = [
'../../_base_/models/tsm_r50.py', '../../_base_/schedules/sgd_tsm_50e.py',
'../../_base_/default_runtime.py'
]
dataset_type = 'VideoDataset'
data_root = ""
data_root_val = ""
ann_file_train = '/data/humaocheng/action_classification/ava_datasets/train_data.txt'
ann_file_val = '/data/humaocheng/action_classification/ava_datasets/test_data.txt'
ann_file_test = '/data/humaocheng/action_classification/ava_datasets/test_data.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=2,
test_dataloader=dict(videos_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
# runtime settings
checkpoint_config = dict(interval=5)
work_dir = './work_dirs/tsm_r50_1x1x8_100e_kinetics400_rgb/'
Note: when training on your own dataset, change num_classes in the model config below to the number of classes in your dataset.
# tsm_r50.py
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTSM',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=8),
cls_head=dict(
type='TSMHead',
num_classes=7,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True),
# model training and testing settings
train_cfg=None,
test_cfg=dict(average_clips='prob'))
# optimizer
optimizer = dict(
type='SGD',
lr=0.01, # this lr is used for 8 gpus
momentum=0.9,
weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50
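Instead of editing the base files, the changes can live in a small child config. This is a sketch assuming mmcv's config inheritance, where `_base_` files are loaded first and dicts defined in the child are merged into the inherited ones key by key. The file path and class count are hypothetical, and the lr rescaling follows the common linear-scaling heuristic, which is an assumption rather than something this config prescribes.

```python
# Hypothetical child config: inherits the settings above and overrides only what
# differs for a custom dataset. mmcv's Config merges these dicts into the
# inherited ones when the file is loaded with Config.fromfile.
_base_ = ['./tsm_r50_custom_base.py']  # hypothetical file holding the config above

# Only cls_head.num_classes changes; every other model key is inherited.
model = dict(cls_head=dict(num_classes=10))  # hypothetical class count

# Assumption (linear scaling): lr=0.01 is annotated for 8 GPUs, i.e.
# 8 GPUs x videos_per_gpu=8 = 64 videos per iteration. With 1 GPU x 8 videos:
optimizer = dict(lr=0.01 * (1 * 8) / (8 * 8))
```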
First, let's see how train_pipeline processes the video data:
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
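Each dict in train_pipeline names a registered transform class by `type`; the remaining keys become constructor arguments, and every transform receives and returns the same `results` dict. mmaction2 implements this with mmcv's Registry and a Compose class; the stripped-down sketch below only illustrates the pattern, with hypothetical names and dummy transforms standing in for DecordInit and SampleFrames.

```python
TRANSFORMS = {}

def register(cls):
    """Record a transform class under its class name (hypothetical registry)."""
    TRANSFORMS[cls.__name__] = cls
    return cls

def build_pipeline(cfgs):
    """Instantiate each transform from its config dict and chain them."""
    transforms = []
    for cfg in cfgs:
        cfg = dict(cfg)                    # copy so the config itself is not mutated
        cls = TRANSFORMS[cfg.pop('type')]  # 'type' selects the class
        transforms.append(cls(**cfg))      # remaining keys are constructor args

    def run(results):
        for transform in transforms:
            results = transform(results)   # every step reads/writes the same dict
        return results
    return run

@register
class FakeInit:
    """Dummy stand-in for DecordInit: pretends the video has 100 frames."""
    def __call__(self, results):
        results['total_frames'] = 100
        return results

@register
class EvenFrames:
    """Dummy stand-in for SampleFrames: picks evenly spaced frame indices."""
    def __init__(self, num_clips):
        self.num_clips = num_clips

    def __call__(self, results):
        step = results['total_frames'] // self.num_clips
        results['frame_inds'] = [i * step for i in range(self.num_clips)]
        return results

cfgs = [dict(type='FakeInit'), dict(type='EvenFrames', num_clips=8)]
pipeline = build_pipeline(cfgs)
out = pipeline({'filename': 'demo.mp4'})
```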
1. DecordInit, DecordDecode
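DecordInit and DecordDecode handle reading and decoding the video. The processing goes through decord's Python interface (.pyi) into C++ code that accelerates video reading and related operations.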
# Module-level imports this excerpt relies on (as in mmaction2 0.x
# mmaction/datasets/pipelines/loading.py; the last import is package-relative)
import io
import warnings
import numpy as np
from mmcv.fileio import FileClient
from ..builder import PIPELINES

@PIPELINES.register_module()
class DecordInit:
    """Using decord to initialize the video_reader.
    Decord: https://github.com/dmlc/decord
    Required keys are "filename",
    added or modified keys are "video_reader" and "total_frames".
    Args:
        io_backend (str): io backend where frames are store.
            Default: 'disk'.
        num_threads (int): Number of thread to decode the video. Default: 1.
        kwargs (dict): Args for file client.
    """
    def __init__(self, io_backend='disk', num_threads=1, **kwargs):
        self.io_backend = io_backend
        self.num_threads = num_threads
        self.kwargs = kwargs
        self.file_client = None
    def __call__(self, results):
        """Perform the Decord initialization.
        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        try:
            import decord
        except ImportError:
            raise ImportError(
                'Please run "pip install decord" to install Decord first.')
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        # Read the raw bytes of the video and open them with decord
        file_obj = io.BytesIO(self.file_client.get(results['filename']))
        container = decord.VideoReader(file_obj, num_threads=self.num_threads)
        results['video_reader'] = container
        results['total_frames'] = len(container)
        return results
    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'io_backend={self.io_backend}, '
                    f'num_threads={self.num_threads})')
        return repr_str

@PIPELINES.register_module()
class DecordDecode:
    """Using decord to decode the video.
    Decord: https://github.com/dmlc/decord
    Required keys are "video_reader", "filename" and "frame_inds",
    added or modified keys are "imgs" and "original_shape".
    Args:
        mode (str): Decoding mode. Options are 'accurate' and 'efficient'.
            If set to 'accurate', it will decode videos into accurate frames.
            If set to 'efficient', it will adopt fast seeking but only return
            key frames, which may be duplicated and inaccurate, and more
            suitable for large scene-based video datasets. Default: 'accurate'.
    """
    def __init__(self, mode='accurate'):
        self.mode = mode
        assert mode in ['accurate', 'efficient']
    def __call__(self, results):
        """Perform the Decord decoding.
        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        container = results['video_reader']
        if results['frame_inds'].ndim != 1:
            results['frame_inds'] = np.squeeze(results['frame_inds'])
        frame_inds = results['frame_inds']
        if self.mode == 'accurate':
            # Decode exactly the requested frames in one batch
            imgs = container.get_batch(frame_inds).asnumpy()
            imgs = list(imgs)
        elif self.mode == 'efficient':
            # This mode is faster, however it always returns I-FRAME
            container.seek(0)
            imgs = list()
            for idx in frame_inds:
                container.seek(idx)
                frame = container.next()
                imgs.append(frame.asnumpy())
        results['video_reader'] = None
        del container
        results['imgs'] = imgs
        results['original_shape'] = imgs[0].shape[:2]
        results['img_shape'] = imgs[0].shape[:2]
        return results
    def __repr__(self):
        repr_str = f'{self.__class__.__name__}(mode={self.mode})'
        return repr_str
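2. SampleFrames
Here SampleFrames is configured as dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8).
This splits the video into 8 equal segments and, during training, randomly picks one frame from each segment, giving 8 frames in total. In test mode (test_mode=True in val_pipeline and test_pipeline), the middle frame of each segment is taken instead.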
@PIPELINES.register_module()
class SampleFrames:
    """Sample frames from the video.
    Required keys are "total_frames", "start_index" , added or modified keys
    are "frame_inds", "frame_interval" and "num_clips".
    Args:
        clip_len (int): Frames of each sampled output clip.
        frame_interval (int): Temporal interval of adjacent sampled frames.
            Default: 1.
        num_clips (int): Number of clips to be sampled. Default: 1.
        temporal_jitter (bool): Whether to apply temporal jittering.
            Default: False.
        twice_sample (bool): Whether to use twice sample when testing.
            If set to True, it will sample frames with and without fixed shift,
            which is commonly used for testing in TSM model. Default: False.
        out_of_bound_opt (str): The way to deal with out of bounds frame
            indexes. Available options are 'loop', 'repeat_last'.
            Default: 'loop'.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        start_index (None): This argument is deprecated and moved to dataset
            class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),
            see this: https://github.com/open-mmlab/mmaction2/pull/89.
        keep_tail_frames (bool): Whether to keep tail frames when sampling.
            Default: False.
    """
    # NOTE: the original code uses np.int, an alias of the builtin int that was
    # removed in NumPy 1.24; plain int is used below and behaves the same.
    def __init__(self,
                 clip_len,
                 frame_interval=1,
                 num_clips=1,
                 temporal_jitter=False,
                 twice_sample=False,
                 out_of_bound_opt='loop',
                 test_mode=False,
                 start_index=None,
                 keep_tail_frames=False):
        self.clip_len = clip_len
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.temporal_jitter = temporal_jitter
        self.twice_sample = twice_sample
        self.out_of_bound_opt = out_of_bound_opt
        self.test_mode = test_mode
        self.keep_tail_frames = keep_tail_frames
        assert self.out_of_bound_opt in ['loop', 'repeat_last']
        if start_index is not None:
            warnings.warn('No longer support "start_index" in "SampleFrames", '
                          'it should be set in dataset class, see this pr: '
                          'https://github.com/open-mmlab/mmaction2/pull/89')
    def _get_train_clips(self, num_frames):
        """Get clip offsets in train mode.
        It will calculate the average interval for selected frames,
        and randomly shift them within offsets between [0, avg_interval].
        If the total number of frames is smaller than clips num or origin
        frames length, it will return all zero indices.
        Args:
            num_frames (int): Total number of frame in the video.
        Returns:
            np.ndarray: Sampled frame indices in train mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        if self.keep_tail_frames:
            avg_interval = (num_frames - ori_clip_len + 1) / float(
                self.num_clips)
            if num_frames > ori_clip_len - 1:
                base_offsets = np.arange(self.num_clips) * avg_interval
                clip_offsets = (base_offsets + np.random.uniform(
                    0, avg_interval, self.num_clips)).astype(int)
            else:
                clip_offsets = np.zeros((self.num_clips, ), dtype=int)
        else:
            avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips
            if avg_interval > 0:
                # Default path: one random offset inside each equal segment
                base_offsets = np.arange(self.num_clips) * avg_interval
                clip_offsets = base_offsets + np.random.randint(
                    avg_interval, size=self.num_clips)
            elif num_frames > max(self.num_clips, ori_clip_len):
                clip_offsets = np.sort(
                    np.random.randint(
                        num_frames - ori_clip_len + 1, size=self.num_clips))
            elif avg_interval == 0:
                ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
                clip_offsets = np.around(np.arange(self.num_clips) * ratio)
            else:
                clip_offsets = np.zeros((self.num_clips, ), dtype=int)
        return clip_offsets
    def _get_test_clips(self, num_frames):
        """Get clip offsets in test mode.
        Calculate the average interval for selected frames, and shift them
        fixedly by avg_interval/2. If set twice_sample True, it will sample
        frames together without fixed shift. If the total number of frames is
        not enough, it will return all zero indices.
        Args:
            num_frames (int): Total number of frame in the video.
        Returns:
            np.ndarray: Sampled frame indices in test mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)
        if num_frames > ori_clip_len - 1:
            base_offsets = np.arange(self.num_clips) * avg_interval
            # Deterministic: the middle of each segment
            clip_offsets = (base_offsets + avg_interval / 2.0).astype(int)
            if self.twice_sample:
                clip_offsets = np.concatenate([clip_offsets, base_offsets])
        else:
            clip_offsets = np.zeros((self.num_clips, ), dtype=int)
        return clip_offsets
    def _sample_clips(self, num_frames):
        """Choose clip offsets for the video in a given mode.
        Args:
            num_frames (int): Total number of frame in the video.
        Returns:
            np.ndarray: Sampled frame indices.
        """
        if self.test_mode:
            clip_offsets = self._get_test_clips(num_frames)
        else:
            clip_offsets = self._get_train_clips(num_frames)
        return clip_offsets
    def __call__(self, results):
        """Perform the SampleFrames loading.
        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        total_frames = results['total_frames']
        clip_offsets = self._sample_clips(total_frames)
        # Expand each clip start into clip_len frame indices
        frame_inds = clip_offsets[:, None] + np.arange(
            self.clip_len)[None, :] * self.frame_interval
        frame_inds = np.concatenate(frame_inds)
        if self.temporal_jitter:
            perframe_offsets = np.random.randint(
                self.frame_interval, size=len(frame_inds))
            frame_inds += perframe_offsets
        frame_inds = frame_inds.reshape((-1, self.clip_len))
        if self.out_of_bound_opt == 'loop':
            frame_inds = np.mod(frame_inds, total_frames)
        elif self.out_of_bound_opt == 'repeat_last':
            safe_inds = frame_inds < total_frames
            unsafe_inds = 1 - safe_inds
            last_ind = np.max(safe_inds * frame_inds, axis=1)
            new_inds = (safe_inds * frame_inds + (unsafe_inds.T * last_ind).T)
            frame_inds = new_inds
        else:
            raise ValueError('Illegal out_of_bound option.')
        start_index = results['start_index']
        frame_inds = np.concatenate(frame_inds) + start_index
        results['frame_inds'] = frame_inds.astype(int)
        results['clip_len'] = self.clip_len
        results['frame_interval'] = self.frame_interval
        results['num_clips'] = self.num_clips
        return results
    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'clip_len={self.clip_len}, '
                    f'frame_interval={self.frame_interval}, '
                    f'num_clips={self.num_clips}, '
                    f'temporal_jitter={self.temporal_jitter}, '
                    f'twice_sample={self.twice_sample}, '
                    f'out_of_bound_opt={self.out_of_bound_opt}, '
                    f'test_mode={self.test_mode})')
        return repr_str
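The default branches used by this config (clip_len=1, frame_interval=1) reduce to a few lines. A NumPy sketch under that assumption, omitting the short-video fallbacks, keep_tail_frames, twice_sample and the start_index offset; the function name is hypothetical.

```python
import numpy as np

def sample_segment_indices(num_frames, num_clips=8, test_mode=False, rng=None):
    """Segment-based sampling for clip_len=1, frame_interval=1.

    Splits the video into num_clips equal segments, then takes one random
    frame per segment (train) or the middle frame of each segment (test).
    """
    if test_mode:
        avg_interval = num_frames / float(num_clips)
        base = np.arange(num_clips) * avg_interval
        return (base + avg_interval / 2.0).astype(int)
    avg_interval = num_frames // num_clips
    assert avg_interval > 0, 'sketch assumes num_frames >= num_clips'
    rng = np.random.default_rng() if rng is None else rng
    base = np.arange(num_clips) * avg_interval
    return base + rng.integers(avg_interval, size=num_clips)  # offset in [0, avg_interval)

middle = sample_segment_indices(80, 8, test_mode=True)  # [5, 15, 25, ..., 75]
```

Because the training branch uses integer division, a 100-frame video gets segments of 12 frames, so frames 96 to 99 are never sampled unless keep_tail_frames is enabled.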
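3. Resize, MultiScaleCrop, Flip, Normalize, FormatShape, Collect, ToTensor
These are fairly self-explanatory, so I won't go into detail. One point worth stressing is that Collect gathers the video frames and the label into the sample that is returned. The final result is:
imgs: shape [8, 8, 3, 224, 224], i.e. batch size, number of frames, channels, frame height, frame width
labels: shape [8, 1], i.e. batch size and one class label per video.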
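Where the leading batch axis comes from can be shown with NumPy arrays standing in for tensors. The sketch assumes FormatShape(input_format='NCHW') leaves each sample as [T, C, H, W], and it uses np.stack as a stand-in for the dataloader's collate step; the label value is hypothetical.

```python
import numpy as np

videos_per_gpu, num_frames = 8, 8

# After FormatShape(input_format='NCHW'), one sample holds its frames as [T, C, H, W].
one_video = np.zeros((num_frames, 3, 224, 224), dtype=np.float32)
one_label = np.array([3])  # hypothetical class index, shape [1]

# Batching stacks samples along a new leading axis.
imgs = np.stack([one_video] * videos_per_gpu)    # [8, 8, 3, 224, 224]
labels = np.stack([one_label] * videos_per_gpu)  # [8, 1]
```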
2.2 Model structure
2.2.1 Pipeline
The main pipeline code is as follows:
# Imports this excerpt relies on (mmaction2 0.x recognizer2d.py; the last two are package-relative)
from torch import nn
from ..builder import RECOGNIZERS
from .base import BaseRecognizer

@RECOGNIZERS.register_module()
class Recognizer2D(BaseRecognizer):
    """2D recognizer model framework."""
    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training."""
        assert self.with_cls_head
        batches = imgs.shape[0]
        # Merge batch and frame axes: [N, T, C, H, W] -> [N*T, C, H, W]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches
        losses = dict()
        x = self.extract_feat(imgs)
        if self.backbone_from in ['torchvision', 'timm']:
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))
        if self.with_neck:
            x = [
                each.reshape((-1, num_segs) +
                             each.shape[1:]).transpose(1, 2).contiguous()
                for each in x
            ]
            x, loss_aux = self.neck(x, labels.squeeze())
            x = x.squeeze(2)
            num_segs = 1
            losses.update(loss_aux)
        cls_score = self.cls_head(x, num_segs)  # x shape [64, 2048, 7, 7]
        gt_labels = labels.squeeze()
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)
        return losses
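① batches: 8, the batch size.
② imgs = imgs.reshape((-1, ) + imgs.shape[2:]): imgs starts as (8, 8, 3, 224, 224); the batch and frame dimensions are then merged, so imgs becomes (64, 3, 224, 224).
③ num_segs: the number of frames per video (64 // 8 = 8).
④ extract_feat: feature extraction by the backbone.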
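The reshape in ② and the num_segs arithmetic in ③ can be checked with NumPy, which shares the reshape semantics used here. The marker value is only there to show where a given video's frames end up after the merge.

```python
import numpy as np

imgs = np.zeros((8, 8, 3, 224, 224), dtype=np.float32)  # [batch, frames, C, H, W]
imgs[1, 0] = 1.0  # mark the first frame of the second video

batches = imgs.shape[0]                          # 8
frames = imgs.reshape((-1, ) + imgs.shape[2:])   # [64, 3, 224, 224]
num_segs = frames.shape[0] // batches            # 64 // 8 = 8

# Frames of one video stay contiguous after the merge (row 8 is video 1, frame 0),
# so the split can be undone later, which TemporalShift relies on via view().
restored = frames.reshape((batches, num_segs) + frames.shape[1:])
```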
2.2.2 Feature extraction
Next comes the feature extraction module:
def forward(self, x):
    """Defines the computation performed at every call.
    Args:
        x (torch.Tensor): The input data.
    Returns:
        torch.Tensor: The feature of the input samples extracted
            by the backbone.
    """
    # x shape [64, 3, 224, 224]
    x = self.conv1(x)
    # x shape [64, 64, 112, 112] (7x7 conv, stride 2)
    x = self.maxpool(x)
    # x shape [64, 64, 56, 56] (3x3 max pool, stride 2)
    outs = []
    for i, layer_name in enumerate(self.res_layers):
        res_layer = getattr(self, layer_name)
        x = res_layer(x)
        if i in self.out_indices:
            outs.append(x)
    if len(outs) == 1:
        return outs[0]
    return tuple(outs)
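The spatial sizes in those comments follow from the usual output-size formula. This sketch assumes the standard ResNet stem (7x7 conv, stride 2, padding 3, then 3x3 max pool, stride 2, padding 1) and takes the stage strides from the module printouts below, where the first 3x3 conv of layer2 to layer4 has stride 2.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

s = 224
s = conv_out(s, kernel=7, stride=2, padding=3)  # conv1   -> 112
s_after_conv1 = s
s = conv_out(s, kernel=3, stride=2, padding=1)  # maxpool -> 56
sizes = [s]
for stage_stride in (1, 2, 2, 2):               # layer1..layer4 (stride on the 3x3 conv)
    s = conv_out(s, kernel=3, stride=stage_stride, padding=1)
    sizes.append(s)
# sizes == [56, 56, 28, 14, 7]: layer4 ends at 7x7, matching [64, 2048, 7, 7]
```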
The input first passes through conv1 and maxpool, and then through the four res_layer modules. The following code shows how these four res_layers are built:
def make_temporal_shift(self):
    """Make temporal shift for some layers."""
    if self.temporal_pool:
        num_segment_list = [
            self.num_segments, self.num_segments // 2,
            self.num_segments // 2, self.num_segments // 2
        ]
    else:
        num_segment_list = [self.num_segments] * 4
    if num_segment_list[-1] <= 0:
        raise ValueError('num_segment_list[-1] must be positive')
    if self.shift_place == 'block':
        def make_block_temporal(stage, num_segments):
            """Make temporal shift on some blocks.
            Args:
                stage (nn.Module): Model layers to be shifted.
                num_segments (int): Number of frame segments.
            Returns:
                nn.Module: The shifted blocks.
            """
            blocks = list(stage.children())
            for i, b in enumerate(blocks):
                blocks[i] = TemporalShift(
                    b, num_segments=num_segments, shift_div=self.shift_div)
            return nn.Sequential(*blocks)
        self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
        self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
        self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
        self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
    elif 'blockres' in self.shift_place:
        n_round = 1
        if len(list(self.layer3.children())) >= 23:
            n_round = 2
        def make_block_temporal(stage, num_segments):
            """Make temporal shift on some blocks.
            Args:
                stage (nn.Module): Model layers to be shifted.
                num_segments (int): Number of frame segments.
            Returns:
                nn.Module: The shifted blocks.
            """
            blocks = list(stage.children())
            for i, b in enumerate(blocks):
                if i % n_round == 0:
                    blocks[i].conv1.conv = TemporalShift(
                        b.conv1.conv,
                        num_segments=num_segments,
                        shift_div=self.shift_div)
            return nn.Sequential(*blocks)
        self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
        self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
        self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
        self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
    else:
        raise NotImplementedError
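make_temporal_shift gives num_segment_list = [8, 8, 8, 8], since temporal_pool is not enabled.
We use shift_place='blockres'. For ResNet-50, layer3 has only 6 blocks (fewer than 23), so n_round stays 1 and the conv1 of every Bottleneck is wrapped in a TemporalShift. Because the shift sits inside the residual branch, this is the residual shift from section 1.2.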
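The 'blockres' bookkeeping can be replayed without building a network. This sketch assumes the standard ResNet stage depths ([3, 4, 6, 3] for ResNet-50, consistent with the printouts below, and [3, 4, 23, 3] for ResNet-101); the function name is hypothetical.

```python
def shifted_blocks(blocks_per_stage, num_segments=8, temporal_pool=False):
    """Replay make_temporal_shift's 'blockres' logic.

    Returns, per stage, (num_segments used, indices of blocks whose conv1 is wrapped).
    """
    if temporal_pool:
        seg_list = [num_segments] + [num_segments // 2] * 3
    else:
        seg_list = [num_segments] * 4
    # Very deep layer3 (>= 23 blocks) -> shift only every other block
    n_round = 2 if blocks_per_stage[2] >= 23 else 1
    return [(segs, [i for i in range(n) if i % n_round == 0])
            for segs, n in zip(seg_list, blocks_per_stage)]

r50 = shifted_blocks([3, 4, 6, 3])    # every block is shifted, 8 segments everywhere
r101 = shifted_blocks([3, 4, 23, 3])  # every other block is shifted
```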
- Building self.layer1, self.layer2, self.layer3, self.layer4
These are built by the make_block_temporal function. Before the wrapping, self.layer1 to self.layer4 are plain stacks of residual (Bottleneck) blocks; their internal structure is shown below. After that, we look at them again once the TemporalShift mechanism has been inserted.
layer1:
Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
layer2:
Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (3): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
layer3:
Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (3): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (4): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (5): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
layer4:
Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
- TemporalShift
The logic of TemporalShift is as follows:
import torch
import torch.nn as nn

class TemporalShift(nn.Module):
    """Temporal shift module.
    This module is proposed in
    `TSM: Temporal Shift Module for Efficient Video Understanding
    <https://arxiv.org/abs/1811.08383>`_
    Args:
        net (nn.module): Module to make temporal shift.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of divisions for shift. Default: 8.
    """
    def __init__(self, net, num_segments=3, shift_div=8):
        super().__init__()
        self.net = net
        self.num_segments = num_segments
        self.shift_div = shift_div
    def forward(self, x):
        """Defines the computation performed at every call.
        Args:
            x (torch.Tensor): The input data.
        Returns:
            torch.Tensor: The output of the module.
        """
        # self.num_segments = 8
        x = self.shift(x, self.num_segments, shift_div=self.shift_div)
        return self.net(x)
    @staticmethod
    def shift(x, num_segments, shift_div=3):
        """Perform temporal shift operation on the feature.
        Args:
            x (torch.Tensor): The input feature to be shifted.
            num_segments (int): Number of frame segments.
            shift_div (int): Number of divisions for shift. Default: 3.
        Returns:
            torch.Tensor: The shifted feature.
        """
        # [N, C, H, W]
        n, c, h, w = x.size()
        # n=64, h=56, w=56, c=64
        # [N // num_segments, num_segments, C, H*W]
        # can't use 5 dimensional array on PPL2D backend for caffe
        x = x.view(-1, num_segments, c, h * w)  # x shape [8, 8, 64, 3136]
        # get shift fold
        fold = c // shift_div
        # split c channel into three parts:
        # left_split, mid_split, right_split
        left_split = x[:, :, :fold, :]  # shape [8, 8, 8, 3136]
        mid_split = x[:, :, fold:2 * fold, :]  # shape [8, 8, 8, 3136]
        right_split = x[:, :, 2 * fold:, :]  # shape [8, 8, 48, 3136]
        # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
        # because array on caffe inference must be got by computing
        # layout here is [BATCH, TIME, CHANNEL, HEIGHT*WIDTH]
        # shift left on num_segments channel in `left_split`:
        # frame t takes frame t+1, the last frame is zero-padded
        zeros = left_split - left_split
        blank = zeros[:, :1, :, :]  # shape [8, 1, 8, 3136]
        left_split = left_split[:, 1:, :, :]  # shape [8, 7, 8, 3136]
        left_split = torch.cat((left_split, blank), 1)
        # shift right on num_segments channel in `mid_split`:
        # frame t takes frame t-1, the first frame is zero-padded
        zeros = mid_split - mid_split
        blank = zeros[:, :1, :, :]
        mid_split = mid_split[:, :-1, :, :]
        mid_split = torch.cat((blank, mid_split), 1)
        # right_split: no shift
        # concatenate along the channel axis
        out = torch.cat((left_split, mid_split, right_split), 2)
        # [N, C, H, W]
        # restore the original dimension
        return out.view(n, c, h, w)
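The code above is the core of the model.
x = x.view(-1, num_segments, c, h * w) merges the height and width dimensions and separates the segments from the batch, giving a tensor of shape [8, 8, 64, 3136] (8 clips, 8 segments, 64 channels, 56 x 56 = 3136 spatial positions).
fold = c // shift_div: here c is the number of channels (64) and shift_div is 8, so fold is 8. This is what we described earlier: the 64 channels are split into 8 parts of 8 channels each. In one part (left_split) frame t takes the features of frame t+1, in another part (mid_split) frame t takes the features of frame t-1, and the remaining 6 parts (right_split) are left unchanged. The positions vacated at the clip boundaries are filled with zeros, and the three parts are finally concatenated along the channel dimension.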
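To make the three splits concrete, here is a small sketch that re-implements the same shift with plain PyTorch indexing and torch.zeros_like (the mmaction2 code builds its zeros as left_split - left_split only because of the caffe-export constraint noted in its comments). The function name temporal_shift and the toy clip are illustrative, not part of mmaction2. Every value in frame t is set to t + 1, so the printed lists show which frame each channel group now comes from.

```python
import torch


def temporal_shift(x: torch.Tensor, num_segments: int, shift_div: int = 8) -> torch.Tensor:
    """Illustrative re-implementation of TemporalShift.shift (not the mmaction2 code).

    x has shape [N * num_segments, C, H, W]. The first C // shift_div channels
    take their values from frame t + 1, the next C // shift_div channels from
    frame t - 1, the rest are unchanged. Clip boundaries are zero-filled.
    """
    nt, c, h, w = x.shape
    x = x.view(-1, num_segments, c, h * w)
    fold = c // shift_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # left_split: t <- t + 1
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # mid_split:  t <- t - 1
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # right_split: no shift
    return out.view(nt, c, h, w)


# Toy clip: 1 clip, 8 segments, 16 channels, 1x1 spatial map; frame t holds value t + 1.
num_segments, channels = 8, 16
frames = torch.arange(1, num_segments + 1, dtype=torch.float32)
x = frames.view(num_segments, 1, 1, 1).expand(num_segments, channels, 1, 1).contiguous()

y = temporal_shift(x, num_segments, shift_div=8)
fold = channels // 8  # 2 channels per shifted group

print(y[:, 0, 0, 0].tolist())         # [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.0]
print(y[:, fold, 0, 0].tolist())      # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(y[:, 2 * fold, 0, 0].tolist())  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```

The first group pulls from the next frame, the second from the previous frame, and the remaining 6/8 of the channels stay where they are; the zero-filled boundaries show up as 0.0.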
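- TemporalShift + residual convolution
Next, let's look at layer1, layer2, layer3 and layer4 of the model. In every Bottleneck, TemporalShift wraps only the conv of conv1 (the first 1x1 convolution of the residual branch), so the shift is applied on the residual branch only, while the shortcut path (identity or downsample) carries the unshifted input. This is the residual shift from section 1.2.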
layer1:
Sequential(
(0): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
(downsample): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
layer2:
Sequential(
(0): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
(downsample): ConvModule(
(conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
layer3:
Sequential(
(0): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
(downsample): ConvModule(
(conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
layer4:
Sequential(
(0): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
(downsample): ConvModule(
(conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): ConvModule(
(conv): TemporalShift(
(net): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv2): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(conv3): ConvModule(
(conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(relu): ReLU(inplace=True)
)
)
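The printouts show that the shift sits inside the residual branch. The sketch below contrasts this residual shift with an in-place shift, to illustrate the point from section 1.2 that in-place shift hurts spatial feature learning. TinyResidualBlock is a hypothetical, heavily simplified block (a single 1x1 conv instead of the real Bottleneck). With the branch switched off, the residual version still passes every frame's own features through the identity path, while the in-place version has already overwritten part of them.

```python
import torch
import torch.nn as nn


def temporal_shift(x, num_segments, shift_div=8):
    # Zero-padded shift: fold channels from t + 1, fold channels from t - 1, rest unchanged.
    nt, c, h, w = x.shape
    x = x.view(-1, num_segments, c, h * w)
    fold = c // shift_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]
    return out.view(nt, c, h, w)


class TinyResidualBlock(nn.Module):
    """Hypothetical simplified block: out = relu(identity + conv(...))."""

    def __init__(self, channels, num_segments, residual=True):
        super().__init__()
        self.num_segments, self.residual = num_segments, residual
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        if self.residual:
            # residual shift: only the branch sees shifted features, identity is untouched
            return torch.relu(x + self.conv(temporal_shift(x, self.num_segments)))
        # in-place shift: the block input is shifted, so the identity path is shifted too
        x = temporal_shift(x, self.num_segments)
        return torch.relu(x + self.conv(x))


torch.manual_seed(0)
num_segments, channels = 8, 16
x = torch.rand(2 * num_segments, channels, 4, 4) + 0.1  # 2 clips x 8 segments, all positive

res_block = TinyResidualBlock(channels, num_segments, residual=True)
inp_block = TinyResidualBlock(channels, num_segments, residual=False)
for block in (res_block, inp_block):
    nn.init.zeros_(block.conv.weight)  # switch the branch off, only the identity path remains

print(torch.equal(res_block(x), x))  # True: every frame keeps its own features
print(torch.equal(inp_block(x), x))  # False: 1/4 of the channels now hold neighbor frames or zeros
```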
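- Output
The final output feature map has shape [64, 2048, 7, 7], i.e. [N * num_segs, 2048, 7, 7] with N = 8 clips and 8 segments per clip.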
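The [64, 2048, 7, 7] shape can be traced with the usual convolution output-size formula. This sketch assumes a 224x224 input crop and the standard ResNet-50 stem (7x7 stride-2 conv, then 3x3 stride-2 max pool), which is not part of the printout above; the stage strides are taken from the conv2 of the first Bottleneck in each layer.

```python
def conv_out(size, kernel, stride, padding):
    # Output-size formula for Conv2d / MaxPool2d with dilation 1 and floor rounding.
    return (size + 2 * padding - kernel) // stride + 1


size = 224                      # assumed input crop size
size = conv_out(size, 7, 2, 3)  # stem conv 7x7, stride 2 -> 112
size = conv_out(size, 3, 2, 1)  # stem max pool 3x3, stride 2 -> 56 (layer1 input)

stages = {}
for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
    # the 3x3 conv2 of the first Bottleneck carries the stage stride (see the printouts)
    size = conv_out(size, 3, stride, 1)
    stages[name] = size

batch, num_segments = 8, 8
print(stages)                                    # {'layer1': 56, 'layer2': 28, 'layer3': 14, 'layer4': 7}
print([batch * num_segments, 2048, size, size])  # [64, 2048, 7, 7]
```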
2.2.3 head
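The head is simple: it is essentially an fc layer that maps the 2048-dimensional feature straight to the number of classes, so that the loss can be computed afterwards. Note, however, that the feature fed into the head has shape [N * num_segs, in_channels, 7, 7] (N is 8, num_segs is 8). The head therefore first applies global average pooling over the 7x7 spatial map, giving [N * num_segs, in_channels, 1, 1], and flattens it to [64, 2048]. This goes through fc to give [64, classes]. Finally the batch and segment dimensions are separated again ([8, 8, classes]) and the scores are averaged over the segments (AvgConsensus), so the final shape is [N, classes].
The code is shown below.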
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

# relative imports: this file lives inside mmaction2 (mmaction/models/heads/tsm_head.py, 0.x releases)
from ..builder import HEADS
from .base import AvgConsensus, BaseHead


@HEADS.register_module()
class TSMHead(BaseHead):
    """Class head for TSM.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        num_segments (int): Number of frame segments. Default: 8.
        loss_cls (dict): Config for building loss.
            Default: dict(type='CrossEntropyLoss')
        spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
        consensus (dict): Consensus config dict.
        dropout_ratio (float): Probability of dropout layer. Default: 0.8.
        init_std (float): Std value for Initiation. Default: 0.001.
        is_shift (bool): Indicating whether the feature is shifted.
            Default: True.
        temporal_pool (bool): Indicating whether feature is temporal pooled.
            Default: False.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_segments=8,
                 loss_cls=dict(type='CrossEntropyLoss'),
                 spatial_type='avg',
                 consensus=dict(type='AvgConsensus', dim=1),
                 dropout_ratio=0.8,
                 init_std=0.001,
                 is_shift=True,
                 temporal_pool=False,
                 **kwargs):
        super().__init__(num_classes, in_channels, loss_cls, **kwargs)
        self.spatial_type = spatial_type
        self.dropout_ratio = dropout_ratio
        self.num_segments = num_segments
        self.init_std = init_std
        self.is_shift = is_shift
        self.temporal_pool = temporal_pool
        consensus_ = consensus.copy()
        consensus_type = consensus_.pop('type')
        if consensus_type == 'AvgConsensus':
            self.consensus = AvgConsensus(**consensus_)
        else:
            self.consensus = None
        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None
        self.fc_cls = nn.Linear(self.in_channels, self.num_classes)
        if self.spatial_type == 'avg':
            # use `nn.AdaptiveAvgPool2d` to adaptively match the in_channels.
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = None

    def init_weights(self):
        """Initiate the parameters from scratch."""
        normal_init(self.fc_cls, std=self.init_std)

    def forward(self, x, num_segs):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.
            num_segs (int): Useless in TSMHead. By default, `num_segs`
                is equal to `clip_len * num_clips * num_crops`, which is
                automatically generated in Recognizer forward phase and
                useless in TSM models. The `self.num_segments` we need is a
                hyper parameter to build TSM models.

        Returns:
            torch.Tensor: The classification scores for input samples.
        """
        # [N * num_segs, in_channels, 7, 7]
        if self.avg_pool is not None:  # x shape [64, 2048, 7, 7]
            x = self.avg_pool(x)  # global average pooling: spatial size -> 1x1, channels unchanged
            # x shape [64, 2048, 1, 1]
        # [N * num_segs, in_channels, 1, 1]
        x = torch.flatten(x, 1)  # x shape [64, 2048]
        # [N * num_segs, in_channels]
        if self.dropout is not None:
            x = self.dropout(x)
        # [N * num_segs, num_classes]
        cls_score = self.fc_cls(x)  # cls shape [64, 400], e.g. with 400 classes
        if self.is_shift and self.temporal_pool:
            # [2 * N, num_segs // 2, num_classes]
            cls_score = cls_score.view((-1, self.num_segments // 2) +
                                       cls_score.size()[1:])
        else:
            # [N, num_segs, num_classes]
            cls_score = cls_score.view((-1, self.num_segments) +
                                       cls_score.size()[1:])  # separate batch and segments
            # cls shape [8, 8, 400]
        # [N, 1, num_classes]
        cls_score = self.consensus(cls_score)
        # [N, num_classes]
        return cls_score.squeeze(1)
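The same shape walk-through can be reproduced with plain PyTorch layers outside mmaction2. In this sketch the 400 classes are only an example, dropout is left out (as at test time), and the segment average is written as mean(dim=1, keepdim=True), which is assumed here to match what AvgConsensus(dim=1) does.

```python
import torch
import torch.nn as nn

batch, num_segments, in_channels, num_classes = 8, 8, 2048, 400  # 400 classes: example only

x = torch.randn(batch * num_segments, in_channels, 7, 7)  # backbone output [64, 2048, 7, 7]
avg_pool = nn.AdaptiveAvgPool2d(1)
fc_cls = nn.Linear(in_channels, num_classes)

x = avg_pool(x)                    # [64, 2048, 1, 1]
x = torch.flatten(x, 1)            # [64, 2048]
cls_score = fc_cls(x)              # [64, 400]: one score vector per frame
cls_score = cls_score.view(batch, num_segments, num_classes)  # [8, 8, 400]
cls_score = cls_score.mean(dim=1, keepdim=True)               # [8, 1, 400]: average over segments
cls_score = cls_score.squeeze(1)                              # [8, 400]
print(cls_score.shape)  # torch.Size([8, 400])
```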
2.3 Loss function
The loss used here is the cross-entropy loss, configured in TSMHead as loss_cls=dict(type='CrossEntropyLoss').
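As a sanity check of what this loss computes, the sketch below assumes that mmaction2's CrossEntropyLoss with hard integer labels reduces to PyTorch's F.cross_entropy on the [N, num_classes] scores (an assumption, not verified against its source here) and compares it with the hand-written negative log-softmax of the true class. The scores, the 400 classes and the labels are random placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
cls_score = torch.randn(8, 400)       # head output [N, num_classes]
labels = torch.randint(0, 400, (8,))  # one ground-truth class index per clip

loss = F.cross_entropy(cls_score, labels)  # averaged over the batch by default

# Same value by hand: -log softmax of the true class, averaged over the batch.
manual = -F.log_softmax(cls_score, dim=1)[torch.arange(8), labels].mean()
print(loss.item(), manual.item())
```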
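Takeaway:
From this model you can see that deeper convolutions carry information from more neighboring frames (action semantics), while shallower ones carry information from fewer frames (mostly spatial information). Quite a neat design when you think about it~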
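The takeaway can be checked with a bit of counting. Each TemporalShift lets frame t pull features from t-1 and t+1, and the 2D convolutions themselves do not mix frames, so after k stacked shifts frame t can reach at most frames t-k to t+k (clipped to the 8 segments). The sketch counts the Bottlenecks in the printouts (3, 4, 6, 3) and reports, for the worst-placed frame, how many frames it can in principle receive information from. It only counts reachable frames and says nothing about how strongly each one is weighted; frames_seen is a hypothetical helper.

```python
def frames_seen(num_segments, num_shifts):
    """Smallest number of frames any frame can draw information from after
    `num_shifts` stacked temporal shifts (each shift reaches t - 1 and t + 1)."""
    return min(
        min(t + num_shifts, num_segments - 1) - max(t - num_shifts, 0) + 1
        for t in range(num_segments)
    )


blocks_per_stage = {"layer1": 3, "layer2": 4, "layer3": 6, "layer4": 3}  # counted from the printouts
num_segments, shifts = 8, 0
for stage, blocks in blocks_per_stage.items():
    shifts += blocks
    print(stage, shifts, frames_seen(num_segments, shifts))
# layer1 3 4
# layer2 7 8
# layer3 13 8
# layer4 16 8
```

So by the end of layer2, every one of the 8 frames can already be influenced by the whole clip, while the layer1 features of the edge frames only cover half of it.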