上一篇讲了如何载入模型,这里写一下如何使用载入的模型初始化新网络的部分层:
我的理解在于,在pytorch中,模型的参数是按照字典的形式存储的,key为该层的名称,相应的value是这层的参数,理解了之后,其实更新一个新的网络的参数,也就是用一个已经存在的字典(也就是预训练的模型的参数)来更新新的字典(新的模型的参数):
- 网络结构的定义
# DenseNet这个类就是网络denesnet的结构的定义,这里参考了pytorch 里面models的源码
class DenseNet(nn.Module):
... ...(此处网络结构的定义省略
class fw_DenseNet(nn.Module):
... ...(这个是我修改后网络结构)
- 获取网络参数:
'''预训练的模型'''
net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
net.load_state_dict(torch.load('/home/wei.fan/.torch/models/densenet161-17b70270.pth'))
net_dict = net.state_dict() #获取预训练模型的参数
''' 自定义的网络模型'''
net1 = fw_DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
net1_dict = net1.state_dict() #获取参数,但其实没有参数(因为没有训练)
- 使用预训练的模型的参数,更新自定义模型的参数
net1_dict = {k: v for k, v in net_dict.items() if k in net1_dict} #把两个模型中名称不同的层去掉
net1_dict.update(net1_dict) #使用预训练模型更新新模型的参数
net1.load_state_dict(net1_dict) #更新模型
参考了这里:How to load part of pre trained model?
其实核心代码只有下面三句,但是因为pytorch的内部机制不清楚,所以搞了蛮久才弄懂,这里贴出来让后来者少踩几个坑吧。
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(pretrained_dict)