PaddlePaddle/PaddleDetection

打印模型的FLOPS

Open

#3464 aberto em 22 de jun. de 2021

Ver no GitHub
 (5 comments) (0 reactions) (0 assignees) · Python (2.731 forks) · batch import
help wanted

Métricas do repositório

Stars
 (11.414 stars)
Métricas de merge de PR
 (Tempo médio de mesclagem: 2d 3h) (1 PR mesclado em 30d)

Description

我想打印出 Faster R-CNN 的网络结构和参数信息，使用了 `paddle.flops()` 函数，在 `infer.py` 中修改如下：

```python
def run(FLAGS, cfg):
    # build trainer
    trainer = Trainer(cfg, mode='test')

    # load weights
    trainer.load_weights(cfg.weights)

    # get inference images
    images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)

    # inference
    trainer.predict(
        images,
        draw_threshold=FLAGS.draw_threshold,
        output_dir=FLAGS.output_dir,
        save_txt=FLAGS.save_txt)
    #打印模型的基础结构和参数信息
    FLOPs = paddle.flops(trainer.model, [1, 3, 608, 608], custom_ops= {nn.LeakyReLU: trainer.predict}, print_detail=True)
    print(FLOPs)

def main():
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    cfg['use_vdl'] = FLAGS.use_vdl
    cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
    merge_config(FLAGS.opt)

    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
        cfg['norm_type'] = 'bn'

    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    run(FLAGS, cfg)

if __name__ == '__main__':
    main()
```

结果报错如下（见截图 1 和截图 2）：

Guia do colaborador