PaddlePaddle/PaddleDetection

打印模型的FLOPS

Open

#3 464 ouverte le 22 juin 2021

Voir sur GitHub
 (5 commentaires) (0 réactions) (0 assignés) · Python (2 731 forks) · batch import
help wanted

Métriques du dépôt

Stars
 (11 414 stars)
Métriques de merge PR
 (Merge moyen 2j 3h) (1 PR mergée en 30 j)

Description

我想打印出faster rcnn的网络结构和参数信息,使用了paddle.flops()函数,在infer.py中修改如下:

```python
def run(FLAGS, cfg):
    # build trainer
    trainer = Trainer(cfg, mode='test')

    # load weights
    trainer.load_weights(cfg.weights)

    # get inference images
    images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)

    # inference
    trainer.predict(
        images,
        draw_threshold=FLAGS.draw_threshold,
        output_dir=FLAGS.output_dir,
        save_txt=FLAGS.save_txt)
    #打印模型的基础结构和参数信息
    FLOPs = paddle.flops(trainer.model, [1, 3, 608, 608], custom_ops= {nn.LeakyReLU: trainer.predict}, print_detail=True)
    print(FLOPs)


def main():
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    cfg['use_vdl'] = FLAGS.use_vdl
    cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
    merge_config(FLAGS.opt)

    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
        cfg['norm_type'] = 'bn'

    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    run(FLAGS, cfg)


if __name__ == '__main__':
    main()
```

结果报错如下 1 2

Guide contributeur