在pytorch中,常会load已有模型甚至pretrained的模型,用其中几层作为特征提取(feature extraction)。比如用pytorch内置的pretrained ResNet作为特征提取器,需要把fully connected layer去掉。可以用children()方法提出需要的层
import torch.nn as nn
from torchvision import models
model = models.resnet50(pretrained=True)
truncated_model = nn.Sequential(*list(model.children())[:8])
print(truncated_model)
truncted_model可作为feature extractor,需要注意输入输出大小即可。
PS: *list
可以达到以下效果
l = ["./foo", "bar", "quux"]
funcXXX(*l)
# 等价于
funcXXX("./foo", "bar", "quux")
也即是,iterate 提取list中的内容,并以逗号分隔。满足nn.Sequential()的输入条件