一 写在前面
未经允许,不得转载,谢谢~~~
之前自己写的code都太乱了,发表了工作都不好意思开源,模型训练和测试都是习惯手动改源码的方式来改不同的参数。
这次想着学习一下,修改一下以后的coding习惯。
ps: 不过这也只是我个人觉得还不错的,不一定恰好符合你的审美,coding习惯本身没有好坏之分,有自己满意的一套习惯就很🉑️~
二 主要构成
以参考GraphCMR为主,之前跑他的repo,觉得整体的代码可读性还是非常高的。
--- sources/
--- config.py
--- utils.py (utils/)
--- dataloader.py (dataloaders/)
--- model.py (models/)
--- train.py
--- eval.py
--- demo.py
--- logs/
- sources/
- 用于存储一些需要用的文件或者数据
- 例如数据集的list
- 个人不建议将数据集也放在home目录下
- config.py
- 用于存放全局路径和全局变量
- 最好都用全大写命名
- utils.py 或者utils/
- 用于写一些经常需要被调用的函数或者是处理特定任务的函数;
- 比如对于输入数据image或者video的处理;
- 如果工程量很大的话可以分开写多个utils文件;
- dataloder.py 或dataloaders/
- 用于加载数据;
- 需要处理多个数据集且不同数据集处理方式相差较大的情况下可以写多个文件;
- models.py 或者 models/
- 写主要模型结构;
- 看情况写一个文件或者一个文件夹;
- train.py
- 模型训练
- eval.py
- 模型测试
- demo.py
- 模型demo
- logs/
- 可以用于保存结果。
- 子目录名表示某次特定的exp名,例如exp4
- exp4中可以包含:
-
config.json
(此次实验的setting记录,对应下文save_dump()
函数) -
checkpoints/
(可以放model ckp) -
tensorboard/
(可以存tensorboard结果)
-
三 关于argparse
用于接受输入,可以避免运行程序的时候每次都需要去修改脚本。
示例用法:
import argparse
# Define command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', default=None, help='Path to network checkpoint')
parser.add_argument('--dataset', default='h36m-p1', choices=['h36m-p1', 'h36m-p2', 'up-3d', 'lsp'], help='Choose evaluation dataset')
parser.add_argument('--config', default=None, help='Path to config file containing model architecture etc.')
parser.add_argument('--log_freq', default=50, type=int, help='Frequency of printing intermediate results')
parser.add_argument('--batch_size', default=32, help='Batch size for testing')
parser.add_argument('--shuffle', default=False, action='store_true', help='Shuffle data')
parser.add_argument('--num_workers', default=8, type=int, help='Number of processes for data loading')
if __name__ == '__main__':
args = parser.parse_args()
# Run evaluation
run_evaluation(model, args.dataset, dataset, mesh,
batch_size=args.batch_size,
shuffle=args.shuffle,
log_freq=args.log_freq)
对于train:
- 模型train的时候需要的setting会比较多,比如模型层面,数据层面等,而且为了便于复现实验,一般需要保存training的setting,所以可以写一个专门用于处理training过程的option的类,主要用于定义和存储模型training中的各种setting。
示例用法:
class TrainOptions(object):
"""Object that handles command line options."""
def __init__(self):
self.parser = argparse.ArgumentParser()
req = self.parser.add_argument_group('Required')
req.add_argument('--name', required=True, help='Name of the experiment')
gen = self.parser.add_argument_group('General')
gen.add_argument('--time_to_run', type=int, default=np.inf, help='Total time to run in seconds. Used for training in environments with timing constraints')
gen.add_argument('--resume', dest='resume', default=False, action='store_true', help='Resume from checkpoint (Use latest checkpoint by default')
gen.add_argument('--num_workers', type=int, default=8, help='Number of processes used for data loading')
return
def parse_args(self):
"""Parse input arguments."""
self.args = self.parser.parse_args()
# If config file is passed, override all arguments with the values from the config file
if self.args.from_json is not None:
path_to_json = os.path.abspath(self.args.from_json)
with open(path_to_json, "r") as f:
json_args = json.load(f)
json_args = namedtuple("json_args", json_args.keys())(**json_args)
return json_args
else:
self.args.log_dir = os.path.join(os.path.abspath(self.args.log_dir), self.args.name)
self.args.summary_dir = os.path.join(self.args.log_dir, 'tensorboard')
if not os.path.exists(self.args.log_dir):
os.makedirs(self.args.log_dir)
self.args.checkpoint_dir = os.path.join(self.args.log_dir, 'checkpoints')
if not os.path.exists(self.args.checkpoint_dir):
os.makedirs(self.args.checkpoint_dir)
self.save_dump()
return self.args
def save_dump(self):
"""Store all argument values to a json file.
The default location is logs/expname/config.json.
"""
if not os.path.exists(self.args.log_dir):
os.makedirs(self.args.log_dir)
with open(os.path.join(self.args.log_dir, "config.json"), "w") as f:
json.dump(vars(self.args), f, indent=4)
return
对于eval和demo:
一般不会有很多需要临时接受的参数,就可以像开头给的通用示例用法一样去做
另外如果模型定义也需要参数,可以把对应的training时候的配置文件传回去。
with open(args.config, 'r') as f:
options = json.load(f)
options = namedtuple('options', options.keys())(**options)