- 保存、恢复模型参数
参考:pytorch学习笔记(五):保存和加载模型
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
- 中断时保存参数
try:
train_net(net=net, epochs=args.epochs, batch_size=args.batchsize,
lr=args.lr, gpu=args.gpu, img_scale=args.scale)
except KeyboardInterrupt: # 用户中断执行(通常是输入^C)
import time
save_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
torch.save(net.state_dict(), '{}_INTERRUPTED.pth'.format(save_time))
print('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
将该代码添加至save_model合适的位置,可实现“Early Stopping”