trainval_net.py

入口

if __name__ == '__main__': #主文件入口
  args = parse_args()#解析参数

  print('Called with args:')
  print(args)

  if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)#加载配置文件并合入到默认项
  if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)#加载配置列表并合入到默认项

  print('Using config:')
  pprint.pprint(cfg)

  np.random.seed(cfg.RNG_SEED)#生成随机种子,预测随机值

  # train set
  imdb, roidb = combined_roidb(args.imdb_name)
  print('{:d} roidb entries'.format(len(roidb)))
 # output directory where the models are saved
  output_dir = get_output_dir(imdb, args.tag)
  print('Output will be saved to `{:s}`'.format(output_dir))

  # tensorboard directory where the summaries are saved during training
  tb_dir = get_output_tb_dir(imdb, args.tag)
  print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

  # also add the validation set, but with no flipping images
  orgflip = cfg.TRAIN.USE_FLIPPED
  cfg.TRAIN.USE_FLIPPED = False
  _, valroidb = combined_roidb(args.imdbval_name)
  print('{:d} validation roidb entries'.format(len(valroidb)))
  cfg.TRAIN.USE_FLIPPED = orgflip

  # load network
  if args.net == 'vgg16':
    net = vgg16()
  elif args.net == 'res50':
    net = resnetv1(num_layers=50)
  elif args.net == 'res101':
    net = resnetv1(num_layers=101)
  elif args.net == 'res152':
    net = resnetv1(num_layers=152)
  elif args.net == 'mobile':
    net = mobilenetv1()
  else:
    raise NotImplementedError
    
  train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
            pretrained_model=args.weight,
            max_iters=args.max_iters)

combined_roidb(imdb_names)

def combined_roidb(imdb_names):
  """
  Combine multiple roidbs
  """
  #内部函数
  def get_roidb(imdb_name):
    imdb = get_imdb(imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
    roidb = get_training_roidb(imdb)
    return roidb

  roidbs = [get_roidb(s) for s in imdb_names.split('+')]
  roidb = roidbs[0]
  if len(roidbs) > 1:
    for r in roidbs[1:]:
      roidb.extend(r)
    tmp = get_imdb(imdb_names.split('+')[1])
    imdb = datasets.imdb.imdb(imdb_names, tmp.classes)
  else:
    imdb = get_imdb(imdb_names)
  return imdb, roidb

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 文章摘自swungover.wordpress.com This is an index of the basic...
    TravisShout阅读 1,146评论 0 3
  • NAME dnsmasq - A lightweight DHCP and caching DNS server....
    ximitc阅读 2,936评论 0 0
  • 雪落秦安, 寒锁故园。 一场春雪,风情万般。 辜负了冬的望穿秋水, 你却在破茧的春光里, ...
    魏文晶阅读 337评论 0 2
  • 一:今天睡之前,我跑到果果身边和他说晚安,已经躺下了的小家伙爬起来,亲了亲我说:“妈妈,晚安!我是世界上最爱你的人...
    果果菠萝蜜妈妈姗妮谢阅读 464评论 2 0
  • 连续223 【箴29:1】人屡次受责罚,仍然硬着颈项,他必顷刻败坏,无法可治。【箴29:26】求王恩的人多,定人事...
    报佳音阅读 1,425评论 0 0