PaddlePaddle/PaddleDetection

打印模型的FLOPS

Open

#3,464 创建于 2021年6月22日

在 GitHub 查看
 (5 评论) (0 反应) (0 负责人)Python (11,414 star) (2,731 fork)batch import
help wanted

描述

我想打印出faster rcnn的网络结构和参数信息,使用了paddle.flops()函数,在infer.py中修改如下: ` def run(FLAGS, cfg): # build trainer trainer = Trainer(cfg, mode='test')

# load weights
trainer.load_weights(cfg.weights)

# get inference images
images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)

# inference
trainer.predict(
    images,
    draw_threshold=FLAGS.draw_threshold,
    output_dir=FLAGS.output_dir,
    save_txt=FLAGS.save_txt)
#打印模型的基础结构和参数信息
FLOPs = paddle.flops(trainer.model, [1, 3, 608, 608], custom_ops= {nn.LeakyReLU: trainer.predict}, print_detail=True)
print(FLOPs)

def main(): FLAGS = parse_args() cfg = load_config(FLAGS.config) cfg['use_vdl'] = FLAGS.use_vdl cfg['vdl_log_dir'] = FLAGS.vdl_log_dir merge_config(FLAGS.opt)

place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
    cfg['norm_type'] = 'bn'

if FLAGS.slim_config:
    cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()

run(FLAGS, cfg)

if name == 'main': main() ` 结果报错如下 1 2

贡献者指南