我们以resnet18为例,介绍几种获取模型摘要的方法。
import torchvistion
model = torchvision.models.resnet18()
1.直接使用PrettyTable
from prettytable import PrettyTable
table = PrettyTable(['Modules', 'Parameters'])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
params = parameter.numel()
table.add_row([name, params])
total_params+=params
print(table)
print(f'Total Trainable Params: {total_params}')
效果如下:
比较简单,也没有模型的输入输出情况。
2. TorchSummary
from torchsummary import summary
summary(model, input_size = (3, 64, 64), batch_size = -1)
整体看美观了很多,也有了输出的维度。但是如果能打印出模型的层次结构就更好了。
3. torchinfo
import torchinfo
torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 0)
这种方式更加美观,且内容详细,灰常棒。