facebookresearch/maskrcnn-benchmark

Problem training Cityscapes with a COCO-pretrained model?

Open

#259 opened on December 10, 2018

(10 comments) (0 reactions) (0 assignees) Python (9,161 stars) (2,574 forks)
help wanted, question

Description

❓ Questions and Help

  • Thanks for the code for training on a new dataset, Cityscapes, for instance segmentation.
  • First, I trained on Cityscapes from scratch and the loss converged, but the box AP and segmentation AP I get are not high, as shown below. The numbers reported in the Mask R-CNN paper are much higher, and I don't know which details I overlooked.
2018-12-07 18:58:13,471 maskrcnn_benchmark.inference INFO: OrderedDict([('bbox', OrderedDict([('AP', 0.266143220179594), ('AP50', 0.4705279119903588), ('AP75', 0.2664711486678874), ('APs', 0.0742186384761436), ('APm', 0.26418817964465885), ('APl', 0.4618351991771723)])), ('segm', OrderedDict([('AP', 0.2169857479304357), ('AP50', 0.4159623962610022), ('AP75', 0.17807455425402843), ('APs', 0.029122872145021395), ('APm', 0.174442224182182), ('APl', 0.42977448859947454)]))])
  • Experiment settings, on a single GTX 1080 Ti:
--config-file "../configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.00125 SOLVER.MAX_ITER 200000 SOLVER.STEPS "(160000, 180000)" TEST.IMS_PER_BATCH 1
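For context, the linear scaling rule that usually motivates overrides like the ones above can be sketched as follows. The reference schedule used here (16 images per batch, base LR 0.02, 90000 iterations, steps at 60000 and 80000) is an assumption based on a generic COCO 1x schedule, not values read from the Cityscapes config file, and `scale_schedule` is a hypothetical helper.

```python
def scale_schedule(ref_batch, ref_lr, ref_max_iter, ref_steps, new_batch):
    """Linear scaling rule: the LR scales with the batch size,
    while iteration counts scale inversely with it."""
    factor = new_batch / ref_batch
    lr = ref_lr * factor
    max_iter = int(round(ref_max_iter / factor))
    steps = tuple(int(round(s / factor)) for s in ref_steps)
    return lr, max_iter, steps


# Assumed reference schedule (hypothetical values, see the note above)
lr, max_iter, steps = scale_schedule(16, 0.02, 90000, (60000, 80000), new_batch=2)
print(lr, max_iter, steps)
```

Under those assumed reference values, a batch size of 2 corresponds to a base LR of 0.0025, so the 0.00125 used above may be worth checking against the values in the actual config file.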
  • Second question: using COCO pre-training to train on Cityscapes.
  • When I load the COCO-pretrained model I run into a problem: the number of classes changes from 81 to 9, so the FC-layer parameters that depend on the class count should be ignored.
  • But the following code in maskrcnn-benchmark/maskrcnn_benchmark/utils/model_serialization.py causes a problem, because model_state_dict[key] = loaded_state_dict[key_old] overwrites the original value:
def load_state_dict(model, loaded_state_dict):
    model_state_dict = model.state_dict()
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching
    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
    # this call performs model_state_dict[key] = loaded_state_dict[key_old]
    align_and_update_state_dicts(model_state_dict, loaded_state_dict)

    # use strict loading
    model.load_state_dict(model_state_dict)
  • I use the following code instead:
def load_state_dict(model, loaded_state_dict):
    model_state_dict = model.state_dict()
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching
    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")

    # align_and_update_state_dicts(model_state_dict, loaded_state_dict)
    # fine-tune: keep only entries whose key and shape match the current model
    loaded_state_dict = {
        k: v
        for k, v in loaded_state_dict.items()
        if k in model_state_dict and model_state_dict[k].size() == v.size()
    }
    model_state_dict.update(loaded_state_dict)
    # use strict loading
    model.load_state_dict(model_state_dict)
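The shape-filtering idea above can be shown in isolation. This is a minimal sketch, not the repository's implementation: `FakeTensor` is a stand-in so the example runs without PyTorch, `filter_compatible` is a hypothetical helper, and the parameter names are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a tensor; only the shape matters for this sketch."""
    shape: tuple


def filter_compatible(loaded_state, model_state):
    """Split loaded entries into those that can be copied and those to skip.

    An entry is kept only if its key exists in the model and the shapes agree.
    """
    kept, skipped = {}, []
    for key, value in loaded_state.items():
        target = model_state.get(key)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped


# Illustrative names: a backbone weight that matches and a classifier that does
# not (81 COCO classes in the checkpoint versus 9 classes in the new model).
model_state = {
    "backbone.conv1.weight": FakeTensor((64, 3, 7, 7)),
    "predictor.cls_score.weight": FakeTensor((9, 1024)),
}
loaded_state = {
    "backbone.conv1.weight": FakeTensor((64, 3, 7, 7)),
    "predictor.cls_score.weight": FakeTensor((81, 1024)),
}

kept, skipped = filter_compatible(loaded_state, model_state)
merged = dict(model_state)
merged.update(kept)  # skipped layers keep the model's own initialization
print("skipped:", skipped)
```

Logging the skipped keys is a cheap sanity check. Exact-key filtering like this also bypasses whatever name alignment align_and_update_state_dicts performs, so checkpoint keys that are named differently from the model's keys would be dropped without warning.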
  • But then maskrcnn_benchmark/utils/checkpoint.py raises an error. I don't know why it has to call self.optimizer.load_state_dict and self.scheduler.load_state_dict. The optimizer state contains a 'momentum_buffer' parameter, and I don't understand why this parameter is loaded. Can you explain? And how can I use the COCO-pretrained model to fine-tune on Cityscapes? Thanks!
    def load(self, f=None):
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint
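On the 'momentum_buffer' question: SGD with momentum keeps one buffer per parameter in the optimizer state, and restoring it is what allows an interrupted run to resume exactly; it is not needed for fine-tuning. One possible workaround, sketched below under assumptions, is to strip the resume-only entries from the checkpoint before fine-tuning. The "optimizer" and "scheduler" keys appear in the `load` method above; the "model" and "iteration" keys and the `strip_training_state` helper are assumptions for illustration.

```python
def strip_training_state(checkpoint, drop=("optimizer", "scheduler", "iteration")):
    """Return a shallow copy of the checkpoint without resume-only entries."""
    return {k: v for k, v in checkpoint.items() if k not in drop}


# Toy checkpoint; in practice it would come from torch.load(path, map_location="cpu")
checkpoint = {
    "model": {"backbone.conv1.weight": "..."},
    "optimizer": {"state": {0: {"momentum_buffer": "..."}}},
    "scheduler": {"last_epoch": 90000},
    "iteration": 90000,
}

finetune_ckpt = strip_training_state(checkpoint)
print(sorted(finetune_ckpt))  # only the model weights remain
# torch.save(finetune_ckpt, new_path) would then write the trimmed file
```

Note that `load` above overrides the file argument whenever `has_checkpoint()` reports an existing checkpoint, so a trimmed file would only be picked up when no earlier checkpoint is found.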
