PaddlePaddle/PaddleX

`DeepLabV3P_CBAM` is not registered on BaseTrainer.

Open

#4.739 aberto em 18 de nov. de 2025

Ver no GitHub
 (1 comment) (1 reaction) (1 assignee)Python (894 forks)batch import
help wanted

Métricas do repositório

Stars
 (4.520 stars)
Métricas de merge de PR
 (Mesclagem média 2d 15h) (15 fundiu PRs em 30d)

Description

Checklist:

描述问题

我在用paddlex的语义分割模型做遥感的道路提取,因为默认的DeepLabV3_plus-Res v101 里面不包含注意力机制,所以我自己扩展了一下,现在我希望加入这个模型去做训练,可是会出现paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer. 的问题

复现

1.我扩展了两个文件。 deeplabv3p_attention.py,用来创建注意力模块

import paddle import paddle.nn as nn import paddle.nn.functional as F

class ChannelAttention(nn.Layer): """通道注意力机制""" def init(self, in_channels, ratio=16): super(ChannelAttention, self).init() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1)

    self.fc = nn.Sequential(
        nn.Conv2D(in_channels, in_channels // ratio, 1, bias_attr=False),
        nn.ReLU(),
        nn.Conv2D(in_channels // ratio, in_channels, 1, bias_attr=False)
    )
    self.sigmoid = nn.Sigmoid()

def forward(self, x):
    avg_out = self.fc(self.avg_pool(x))
    max_out = self.fc(self.max_pool(x))
    out = avg_out + max_out
    return x * self.sigmoid(out)

class SpatialAttention(nn.Layer): """空间注意力机制""" def init(self, kernel_size=7): super(SpatialAttention, self).init() self.conv = nn.Conv2D(2, 1, kernel_size, padding=kernel_size//2, bias_attr=False) self.sigmoid = nn.Sigmoid()

def forward(self, x):
    avg_out = paddle.mean(x, axis=1, keepdim=True)
    max_out = paddle.max(x, axis=1, keepdim=True)
    x_combined = paddle.concat([avg_out, max_out], axis=1)
    attention_map = self.conv(x_combined)
    return x * self.sigmoid(attention_map)

class CBAM(nn.Layer): """CBAM注意力模块(通道+空间)""" def init(self, in_channels, reduction=16, kernel_size=7): super(CBAM, self).init() self.channel_attention = ChannelAttention(in_channels, reduction) self.spatial_attention = SpatialAttention(kernel_size)

def forward(self, x):
    x = self.channel_attention(x)
    x = self.spatial_attention(x)
    return x

class RoadConnectivityAttention(nn.Layer): """专门针对道路连通性的注意力机制""" def init(self, in_channels, kernel_size=5): super(RoadConnectivityAttention, self).init() # 使用大卷积核捕捉长距离道路依赖 self.road_conv = nn.Conv2D(in_channels, in_channels, kernel_size, padding=kernel_size//2, groups=in_channels, bias_attr=False) self.road_bn = nn.BatchNorm2D(in_channels) self.road_act = nn.ReLU()

    # 通道重校准
    self.gap = nn.AdaptiveAvgPool2d(1)
    self.channel_fc = nn.Sequential(
        nn.Linear(in_channels, in_channels // 4),
        nn.ReLU(),
        nn.Linear(in_channels // 4, in_channels),
        nn.Sigmoid()
    )
    
def forward(self, x):
    b, c, h, w = x.shape
    
    # 道路结构增强
    road_feat = self.road_conv(x)
    road_feat = self.road_bn(road_feat)
    road_feat = self.road_act(road_feat)
    
    # 通道注意力
    channel_weights = self.gap(x).reshape([b, c])
    channel_weights = self.channel_fc(channel_weights).reshape([b, c, 1, 1])
    
    return x + road_feat * channel_weights

deeplabv3p_cbam.py,带有注意力模块的DPV3模型

import paddle import paddle.nn as nn from paddleseg.cvlibs import manager from paddleseg.models import DeepLabV3P from deeplabv3p_attention import CBAM, RoadConnectivityAttention

@manager.MODELS.add_component class DeepLabV3P_CBAM(DeepLabV3P): """集成CBAM注意力机制的DeepLabV3+""" def init(self, num_classes=19, backbone='ResNet101_vd', backbone_indices=(0, 1, 2, 3), aspp_ratios=(1, 3, 6, 12, 18), aspp_out_channels=256, align_corners=False, pretrained=None):

    super().__init__(
        num_classes=num_classes,
        backbone=backbone,
        backbone_indices=backbone_indices,
        aspp_ratios=aspp_ratios,
        aspp_out_channels=aspp_out_channels,
        align_corners=align_corners,
        pretrained=pretrained
    )
    
    # 在ASPP模块后添加CBAM注意力
    aspp_out_channels = 256  # 与您的配置一致
    self.aspp_cbam = CBAM(aspp_out_channels)
    
    # 在解码器路径添加道路连通性注意力
    decoder_low_level_channels = 256  # 调整为您backbone的实际通道数
    self.decoder_cbam = RoadConnectivityAttention(decoder_low_level_channels)
    
    # 最终预测前的注意力
    self.final_cbam = CBAM(decoder_low_level_channels // 2)

def forward(self, x):
    # 编码器特征提取
    feats = self.backbone(x)
    low_level_feat = feats[self.backbone_indices[1]]  # 低级特征
    high_level_feat = feats[self.backbone_indices[0]]  # 高级特征
    
    # ASPP模块处理高级特征
    aspp_out = self.aspp(high_level_feat)
    
    # 在ASPP后应用注意力机制
    aspp_out = self.aspp_cbam(aspp_out)
    
    # 上采样ASPP输出
    aspp_out = self.conv_bn_relu(aspp_out)
    aspp_out = F.interpolate(
        aspp_out,
        size=low_level_feat.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners
    )
    
    # 处理低级特征并应用道路连通性注意力
    low_level_feat = self.low_level_conv(low_level_feat)
    low_level_feat = self.decoder_cbam(low_level_feat)
    
    # 特征融合
    fused_feat = paddle.concat([aspp_out, low_level_feat], axis=1)
    fused_feat = self.conv_bn_relu_fusion(fused_feat)
    
    # 最终预测前的注意力增强
    fused_feat = self.final_cbam(fused_feat)
    
    # 上采样到原图大小
    output = F.interpolate(
        fused_feat,
        size=x.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners
    )
    
    return [output]

然后我把这两个文件放在了/root/PaddleX/paddlex/repo_manager/repos/PaddleSeg/paddleseg/models下面,并在__init__.py中引入

from .deeplabv3p_cbam import DeepLabV3P_CBAM

在register.py中进行了注册:

register_model_info( { "model_name": "DeepLabV3P_CBAM", "suite": "Seg", "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-CBAM-R-R101.yaml"), "supported_apis": ["train", "evaluate", "predict", "export"], } )

同时在model_list.py中添加了记录

MODELS = [ "DeepLabV3P_CBAM" "Deeplabv3_Plus-R101", "Deeplabv3_Plus-R50", "Deeplabv3-R101", "Deeplabv3-R50", "OCRNet_HRNet-W48", "OCRNet_HRNet-W18", "PP-LiteSeg-T", "PP-LiteSeg-B", "SegFormer-B0", "SegFormer-B1", "SegFormer-B2", "SegFormer-B3", "SegFormer-B4", "SegFormer-B5", "SeaFormer_base", "SeaFormer_tiny", "SeaFormer_small", "SeaFormer_large", "MaskFormer_tiny", "MaskFormer_small", ]

用命令行执行

python main.py -c paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-CBAM-R-R101.yaml -o Global.mode=train -o Global.dataset_dir=/root/data/dataset/changxin2 -o Global.device=gpu:1 -o Train.batch_size=2 -o Train.learning_rate=0.001 -o Train.num_classes=2 -o Train.epochs_iters=80000 -o Global.output=/root/PaddleX/output/rd_seg_model

但是执行的时候就会出现错误:

Traceback (most recent call last): File "/root/PaddleX/paddlex/utils/result_saver.py", line 28, in wrap result = func(self, *args, **kwargs) File "/root/PaddleX/paddlex/engine.py", line 41, in run self._model.train() File "/root/PaddleX/paddlex/model.py", line 118, in train trainer = build_trainer(self._config) File "/root/PaddleX/paddlex/modules/base/trainer.py", line 44, in build_trainer return BaseTrainer.get(model_name)(config) File "/root/PaddleX/paddlex/utils/misc.py", line 196, in get raise_class_not_found_error(name, cls, all_entities) File "/root/PaddleX/paddlex/utils/errors/others.py", line 112, in raise_class_not_found_error raise ClassNotFoundException(msg) paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer. The registied entities: [STFPM, CLIP_vit_base_patch16_224, CLIP_vit_large_patch14_224, ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384, MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1_x1_0, MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25, MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV4_conv_small, MobileNetV4_conv_medium, MobileNetV4_conv_large, MobileNetV4_hybrid_medium, MobileNetV4_hybrid_large, PP-HGNet_tiny, PP-HGNet_small, PP-HGNet_base, PP-HGNetV2-B0, PP-HGNetV2-B1, PP-HGNetV2-B2, PP-HGNetV2-B3, PP-HGNetV2-B4, PP-HGNetV2-B5, PP-HGNetV2-B6, PP-LCNet_x0_25, PP-LCNet_x0_25_textline_ori, PP-LCNet_x0_35, PP-LCNet_x0_5, PP-LCNet_x0_75, PP-LCNet_x1_0, PP-LCNet_x1_0_doc_ori, PP-LCNet_x1_0_textline_ori, PP-LCNet_x1_5, PP-LCNet_x2_0, PP-LCNet_x2_5, PP-LCNetV2_small, PP-LCNetV2_base, PP-LCNetV2_large, ResNet101, ResNet152, ResNet18, ResNet34, ResNet50, ResNet200_vd, ResNet101_vd, ResNet152_vd, ResNet18_vd, ResNet34_vd, ResNet50_vd, SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384, StarNet-S1, StarNet-S2, StarNet-S3, StarNet-S4, FasterNet-L, FasterNet-M, FasterNet-S, FasterNet-T0, FasterNet-T1, FasterNet-T2, PP-LCNet_x1_0_table_cls, MobileFaceNet, ResNet50_face, PP-ShiTuV2_rec, PP-ShiTuV2_rec_CLIP_vit_base, PP-ShiTuV2_rec_CLIP_vit_large, PicoDet-L, PicoDet-S, PP-YOLOE_plus-L, PP-YOLOE_plus-M, PP-YOLOE_plus-S, PP-YOLOE_plus-X, RT-DETR-H, RT-DETR-L, RT-DETR-R18, RT-DETR-R50, RT-DETR-X, PicoDet_layout_1x, PicoDet_layout_1x_table, PicoDet-S_layout_3cls, PicoDet-S_layout_17cls, PicoDet-L_layout_3cls, PicoDet-L_layout_17cls, RT-DETR-H_layout_3cls, RT-DETR-H_layout_17cls, YOLOv3-DarkNet53, YOLOv3-MobileNetV3, YOLOv3-ResNet50_vd_DCN, YOLOX-L, YOLOX-M, YOLOX-N, YOLOX-S, YOLOX-T, YOLOX-X, FasterRCNN-ResNet34-FPN, FasterRCNN-ResNet50, FasterRCNN-ResNet50-FPN, FasterRCNN-ResNet50-vd-FPN, FasterRCNN-ResNet50-vd-SSLDv2-FPN, FasterRCNN-ResNet101, FasterRCNN-ResNet101-FPN, FasterRCNN-ResNeXt101-vd-FPN, FasterRCNN-Swin-Tiny-FPN, Cascade-FasterRCNN-ResNet50-FPN, Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN, PicoDet-M, PicoDet-XS, FCOS-ResNet50, DETR-R50, PP-ShiTuV2_det, PP-YOLOE-L_human, PP-YOLOE-S_human, PP-YOLOE-L_vehicle, PP-YOLOE-S_vehicle, PP-YOLOE_plus_SOD-L, PP-YOLOE_plus_SOD-S, PP-YOLOE_plus_SOD-largesize-L, CenterNet-DLA-34, CenterNet-ResNet50, PicoDet_LCNet_x2_5_face, BlazeFace, BlazeFace-FPN-SSH, PP-YOLOE_plus-S_face, PP-YOLOE-R-L, Co-Deformable-DETR-R50, Co-Deformable-DETR-Swin-T, Co-DINO-R50, Co-DINO-Swin-L, RT-DETR-L_wired_table_cell_det, RT-DETR-L_wireless_table_cell_det, PP-DocLayout-L, PP-DocLayout-M, PP-DocLayout-S, PP-DocLayout_plus-L, PP-DocBlockLayout, Mask-RT-DETR-S, Mask-RT-DETR-M, Mask-RT-DETR-X, Mask-RT-DETR-H, Mask-RT-DETR-L, SOLOv2, MaskRCNN-ResNet50, MaskRCNN-ResNet50-FPN, MaskRCNN-ResNet50-vd-FPN, MaskRCNN-ResNet101-FPN, MaskRCNN-ResNet101-vd-FPN, MaskRCNN-ResNeXt101-vd-FPN, MaskRCNN-ResNet50-vd-SSLDv2-FPN, Cascade-MaskRCNN-ResNet50-FPN, Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN, PP-YOLOE_seg-S, PP-TinyPose_128x96, PP-TinyPose_256x192, BEVFusion, ResNet50_ML, PP-LCNet_x1_0_ML, PP-HGNetV2-B0_ML, PP-HGNetV2-B4_ML, PP-HGNetV2-B6_ML, CLIP_vit_base_patch16_448_ML, PP-LCNet_x1_0_pedestrian_attribute, PP-LCNet_x1_0_vehicle_attribute, whisper_large, whisper_medium, whisper_base, whisper_small, whisper_tiny, DeepLabV3P_CBAMDeeplabv3_Plus-R101, Deeplabv3_Plus-R50, Deeplabv3-R101, Deeplabv3-R50, OCRNet_HRNet-W48, OCRNet_HRNet-W18, PP-LiteSeg-T, PP-LiteSeg-B, SegFormer-B0, SegFormer-B1, SegFormer-B2, SegFormer-B3, SegFormer-B4, SegFormer-B5, SeaFormer_base, SeaFormer_tiny, SeaFormer_small, SeaFormer_large, MaskFormer_tiny, MaskFormer_small, SLANet, SLANet_plus, SLANeXt_wired, SLANeXt_wireless, PP-OCRv5_mobile_det, PP-OCRv5_server_det, PP-OCRv4_mobile_det, PP-OCRv4_server_det, PP-OCRv4_mobile_seal_det, PP-OCRv4_server_seal_det, PP-OCRv3_mobile_det, PP-OCRv3_server_det, PP-OCRv3_mobile_rec, en_PP-OCRv3_mobile_rec, korean_PP-OCRv3_mobile_rec, japan_PP-OCRv3_mobile_rec, chinese_cht_PP-OCRv3_mobile_rec, te_PP-OCRv3_mobile_rec, ka_PP-OCRv3_mobile_rec, ta_PP-OCRv3_mobile_rec, latin_PP-OCRv3_mobile_rec, arabic_PP-OCRv3_mobile_rec, cyrillic_PP-OCRv3_mobile_rec, devanagari_PP-OCRv3_mobile_rec, PP-OCRv4_mobile_rec, PP-OCRv4_server_rec, en_PP-OCRv4_mobile_rec, PP-OCRv4_server_rec_doc, ch_SVTRv2_rec, ch_RepSVTR_rec, PP-OCRv5_server_rec, PP-OCRv5_mobile_rec, AutoEncoder_ad, DLinear_ad, Nonstationary_ad, PatchTST_ad, TimesNet_ad, TimesNet_cls, DLinear, NLinear, Nonstationary, PatchTST, RLinear, TiDE, TimesNet, PP-TSM-R50_8frames_uniform, PP-TSMv2-LCNetV2_8frames_uniform, PP-TSMv2-LCNetV2_16frames_uniform, YOWO, PP-DocBee-2B, PP-DocBee-7B, PP-Chart2Table, PP-DocBee2-3B, LaTeX_OCR_rec, UniMERNet, PP-FormulaNet-S, PP-FormulaNet-L, PP-FormulaNet_plus-S, PP-FormulaNet_plus-M, PP-FormulaNet_plus-L, GroundingDINO-T, YOLO-Worldv2-L, SAM-H_point, SAM-H_box]

环境

  1. 请提供您使用的PaddlePaddle和PaddleX的版本号

paddlex:3.0

  1. 请提供您使用的操作系统信息,如Linux/Windows/MacOS

centos 8.0

  1. 请问您使用的Python版本是?

python 3.10.18

  1. 请问您使用的CUDA/cuDNN的版本号是? cuda 12 cuDNN8.7

Guia do colaborador