def set_bn_eval(m):
    """Put ``m`` into eval mode if it is a BatchNorm-like layer.

    Meant to be passed to ``model.apply(...)``, which calls it once for every
    submodule of the model. A module is treated as BatchNorm when its class
    name contains the substring ``'BatchNorm'`` (e.g. ``BatchNorm1d``,
    ``BatchNorm2d``, ``SyncBatchNorm``). ``eval()`` is called on such modules,
    and every other module is left untouched.

    Note: in PyTorch, ``eval()`` makes a BatchNorm layer normalize with its
    stored running mean/variance and stop updating them. It does NOT stop its
    affine parameters (``weight``/``bias``) from receiving gradients. Also set
    ``requires_grad = False`` on those if they must be frozen as well.

    Args:
        m: a module, as passed in by ``model.apply``.
    """
    # Matching on the class name (rather than isinstance) catches any class
    # with 'BatchNorm' in its name, including custom/third-party variants.
    # It also matches any module whose name merely contains that substring.
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()
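# Use model.apply(set_bn_eval) to freeze the BatchNorm layers: apply() calls
# set_bn_eval on every submodule, switching each BatchNorm layer to eval mode so
# its running mean/variance stop updating (its weight/bias still train unless
# their requires_grad is set to False).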
def train(model, data_loader, criterion, epoch):
    """Run one training epoch with every BatchNorm layer kept in eval mode.

    Args:
        model: the network being trained; its ``train()`` and ``apply()``
            methods are called here.
        data_loader: presumably the training-batch loader (not used in the
            visible skeleton; consumed by the elided training code).
        criterion: presumably the loss function (not used in the visible
            skeleton).
        epoch: current epoch index (not used in the visible skeleton).
    """
    # Order matters: model.train() recursively switches every submodule,
    # BatchNorm layers included, back to training mode. It must therefore run
    # first, and set_bn_eval then flips only the BatchNorm layers back to eval
    # mode. Doing this at the start of every call keeps BN frozen even if the
    # model was switched to another mode between epochs.
    model.train()  # switch to train mode
    model.apply(set_bn_eval)  # then freeze the BN layers for this training pass
    ###
    # training code
    ###
# Putting it together: the commonly used training-loop pattern.
def main():
    """Top-level driver: alternate one training pass and one test pass per epoch."""
    # ...
    # NOTE(review): `epochs` must be an iterable of epoch indices
    # (e.g. range(num_epochs)); a bare int would raise TypeError here.
    for epoch in epochs:
        # train() calls model.train() and then re-applies set_bn_eval on every
        # call, so the BatchNorm layers are re-frozen each epoch even if
        # test() changes the model's mode (presumably via model.eval()).
        train(model, train_loader, criterion, epoch)
        test(model, eval_loader, epoch)
    # ...