PyTorch如何打印模型详细信息

我们以resnet18为例,介绍几种获取模型摘要的方法。

import torchvistion
model = torchvision.models.resnet18()

1.直接使用PrettyTable

from prettytable import PrettyTable

table = PrettyTable(['Modules', 'Parameters']) 
total_params = 0 
for name, parameter in model.named_parameters():
    if not parameter.requires_grad: continue
    params = parameter.numel()
    table.add_row([name, params])
    total_params+=params
print(table) 
print(f'Total Trainable Params: {total_params}') 

效果如下:


PrettyTable

比较简单,也没有模型的输入输出情况。

2. TorchSummary

from torchsummary import summary
summary(model, input_size = (3, 64, 64), batch_size = -1)
TorchSummary

整体看美观了很多,也有了输出的维度。但是如果能打印出模型的层次结构就更好了。

3. torchinfo

import torchinfo 
torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 0)
torchinfo

这种方式更加美观,且内容详细,灰常棒。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容