在 Pytorch 框架下,如何手动修改训练的模型的参数?
我们以两个模型参数加权平均为例,步骤如下:
Step 1: Load Model
model0 = torch.load('*.pth')
model1 = torch.load('*.pth')
Step 2: Parse Parameters
param0 = model0.state_dict()
param1 = model1.state_dict()
Step 3: Modified Parameters
param_new = {}
for key in param0.keys():
paranew[key] = (param0[key] + param1[key])/2
Step 4: Generate New Model
model2 = model1
model2.load_state_dict(param_new)
param2 = model2.net.state_dict()