PaddlePaddle/PaddleX

`DeepLabV3P_CBAM` is not registered on BaseTrainer.

Open

#4739 opened on Nov 18, 2025

View on GitHub
 (1 comment) (1 reaction) (1 assignee)Python (4,520 stars) (894 forks)batch import
help wanted

Description

Checklist:

描述问题

我在用paddlex的语义分割模型做遥感的道路提取,因为默认的DeepLabV3_plus-Res v101 里面不包含注意力机制,所以我自己扩展了一下,现在我希望加入这个模型去做训练,可是会出现paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer. 的问题

复现

1.我扩展了两个文件。 deeplabv3p_attention.py,用来创建注意力模块

import paddle import paddle.nn as nn import paddle.nn.functional as F

class ChannelAttention(nn.Layer): """通道注意力机制""" def init(self, in_channels, ratio=16): super(ChannelAttention, self).init() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1)

    self.fc = nn.Sequential(
        nn.Conv2D(in_channels, in_channels // ratio, 1, bias_attr=False),
        nn.ReLU(),
        nn.Conv2D(in_channels // ratio, in_channels, 1, bias_attr=False)
    )
    self.sigmoid = nn.Sigmoid()

def forward(self, x):
    avg_out = self.fc(self.avg_pool(x))
    max_out = self.fc(self.max_pool(x))
    out = avg_out + max_out
    return x * self.sigmoid(out)

class SpatialAttention(nn.Layer): """空间注意力机制""" def init(self, kernel_size=7): super(SpatialAttention, self).init() self.conv = nn.Conv2D(2, 1, kernel_size, padding=kernel_size//2, bias_attr=False) self.sigmoid = nn.Sigmoid()

def forward(self, x):
    avg_out = paddle.mean(x, axis=1, keepdim=True)
    max_out = paddle.max(x, axis=1, keepdim=True)
    x_combined = paddle.concat([avg_out, max_out], axis=1)
    attention_map = self.conv(x_combined)
    return x * self.sigmoid(attention_map)

class CBAM(nn.Layer): """CBAM注意力模块(通道+空间)""" def init(self, in_channels, reduction=16, kernel_size=7): super(CBAM, self).init() self.channel_attention = ChannelAttention(in_channels, reduction) self.spatial_attention = SpatialAttention(kernel_size)

def forward(self, x):
    x = self.channel_attention(x)
    x = self.spatial_attention(x)
    return x

class RoadConnectivityAttention(nn.Layer): """专门针对道路连通性的注意力机制""" def init(self, in_channels, kernel_size=5): super(RoadConnectivityAttention, self).init() # 使用大卷积核捕捉长距离道路依赖 self.road_conv = nn.Conv2D(in_channels, in_channels, kernel_size, padding=kernel_size//2, groups=in_channels, bias_attr=False) self.road_bn = nn.BatchNorm2D(in_channels) self.road_act = nn.ReLU()

    # 通道重校准
    self.gap = nn.AdaptiveAvgPool2d(1)
    self.channel_fc = nn.Sequential(
        nn.Linear(in_channels, in_channels // 4),
        nn.ReLU(),
        nn.Linear(in_channels // 4, in_channels),
        nn.Sigmoid()
    )
    
def forward(self, x):
    b, c, h, w = x.shape
    
    # 道路结构增强
    road_feat = self.road_conv(x)
    road_feat = self.road_bn(road_feat)
    road_feat = self.road_act(road_feat)
    
    # 通道注意力
    channel_weights = self.gap(x).reshape([b, c])
    channel_weights = self.channel_fc(channel_weights).reshape([b, c, 1, 1])
    
    return x + road_feat * channel_weights

deeplabv3p_cbam.py,带有注意力模块的DPV3模型

import paddle import paddle.nn as nn from paddleseg.cvlibs import manager from paddleseg.models import DeepLabV3P from deeplabv3p_attention import CBAM, RoadConnectivityAttention

@manager.MODELS.add_component class DeepLabV3P_CBAM(DeepLabV3P): """集成CBAM注意力机制的DeepLabV3+""" def init(self, num_classes=19, backbone='ResNet101_vd', backbone_indices=(0, 1, 2, 3), aspp_ratios=(1, 3, 6, 12, 18), aspp_out_channels=256, align_corners=False, pretrained=None):

    super().__init__(
        num_classes=num_classes,
        backbone=backbone,
        backbone_indices=backbone_indices,
        aspp_ratios=aspp_ratios,
        aspp_out_channels=aspp_out_channels,
        align_corners=align_corners,
        pretrained=pretrained
    )
    
    # 在ASPP模块后添加CBAM注意力
    aspp_out_channels = 256  # 与您的配置一致
    self.aspp_cbam = CBAM(aspp_out_channels)
    
    # 在解码器路径添加道路连通性注意力
    decoder_low_level_channels = 256  # 调整为您backbone的实际通道数
    self.decoder_cbam = RoadConnectivityAttention(decoder_low_level_channels)
    
    # 最终预测前的注意力
    self.final_cbam = CBAM(decoder_low_level_channels // 2)

def forward(self, x):
    # 编码器特征提取
    feats = self.backbone(x)
    low_level_feat = feats[self.backbone_indices[1]]  # 低级特征
    high_level_feat = feats[self.backbone_indices[0]]  # 高级特征
    
    # ASPP模块处理高级特征
    aspp_out = self.aspp(high_level_feat)
    
    # 在ASPP后应用注意力机制
    aspp_out = self.aspp_cbam(aspp_out)
    
    # 上采样ASPP输出
    aspp_out = self.conv_bn_relu(aspp_out)
    aspp_out = F.interpolate(
        aspp_out,
        size=low_level_feat.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners
    )
    
    # 处理低级特征并应用道路连通性注意力
    low_level_feat = self.low_level_conv(low_level_feat)
    low_level_feat = self.decoder_cbam(low_level_feat)
    
    # 特征融合
    fused_feat = paddle.concat([aspp_out, low_level_feat], axis=1)
    fused_feat = self.conv_bn_relu_fusion(fused_feat)
    
    # 最终预测前的注意力增强
    fused_feat = self.final_cbam(fused_feat)
    
    # 上采样到原图大小
    output = F.interpolate(
        fused_feat,
        size=x.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners
    )
    
    return [output]

然后我把这两个文件放在了/root/PaddleX/paddlex/repo_manager/repos/PaddleSeg/paddleseg/models下面,并在__init__.py中引入

from .deeplabv3p_cbam import DeepLabV3P_CBAM

在register.py中进行了注册:

register_model_info( { "model_name": "DeepLabV3P_CBAM", "suite": "Seg", "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-CBAM-R-R101.yaml"), "supported_apis": ["train", "evaluate", "predict", "export"], } )

同时在model_list.py中添加了记录

MODELS = [ "DeepLabV3P_CBAM" "Deeplabv3_Plus-R101", "Deeplabv3_Plus-R50", "Deeplabv3-R101", "Deeplabv3-R50", "OCRNet_HRNet-W48", "OCRNet_HRNet-W18", "PP-LiteSeg-T", "PP-LiteSeg-B", "SegFormer-B0", "SegFormer-B1", "SegFormer-B2", "SegFormer-B3", "SegFormer-B4", "SegFormer-B5", "SeaFormer_base", "SeaFormer_tiny", "SeaFormer_small", "SeaFormer_large", "MaskFormer_tiny", "MaskFormer_small", ]

用命令行执行

python main.py -c paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-CBAM-R-R101.yaml -o Global.mode=train -o Global.dataset_dir=/root/data/dataset/changxin2 -o Global.device=gpu:1 -o Train.batch_size=2 -o Train.learning_rate=0.001 -o Train.num_classes=2 -o Train.epochs_iters=80000 -o Global.output=/root/PaddleX/output/rd_seg_model

但是执行的时候就会出现错误:

Traceback (most recent call last): File "/root/PaddleX/paddlex/utils/result_saver.py", line 28, in wrap result = func(self, *args, **kwargs) File "/root/PaddleX/paddlex/engine.py", line 41, in run self._model.train() File "/root/PaddleX/paddlex/model.py", line 118, in train trainer = build_trainer(self._config) File "/root/PaddleX/paddlex/modules/base/trainer.py", line 44, in build_trainer return BaseTrainer.get(model_name)(config) File "/root/PaddleX/paddlex/utils/misc.py", line 196, in get raise_class_not_found_error(name, cls, all_entities) File "/root/PaddleX/paddlex/utils/errors/others.py", line 112, in raise_class_not_found_error raise ClassNotFoundException(msg) paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer. The registied entities: [STFPM, CLIP_vit_base_patch16_224, CLIP_vit_large_patch14_224, ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384, MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1_x1_0, MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25, MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV4_conv_small, MobileNetV4_conv_medium, MobileNetV4_conv_large, MobileNetV4_hybrid_medium, MobileNetV4_hybrid_large, PP-HGNet_tiny, PP-HGNet_small, PP-HGNet_base, PP-HGNetV2-B0, PP-HGNetV2-B1, PP-HGNetV2-B2, PP-HGNetV2-B3, PP-HGNetV2-B4, PP-HGNetV2-B5, PP-HGNetV2-B6, PP-LCNet_x0_25, PP-LCNet_x0_25_textline_ori, PP-LCNet_x0_35, PP-LCNet_x0_5, PP-LCNet_x0_75, PP-LCNet_x1_0, PP-LCNet_x1_0_doc_ori, PP-LCNet_x1_0_textline_ori, PP-LCNet_x1_5, PP-LCNet_x2_0, PP-LCNet_x2_5, PP-LCNetV2_small, PP-LCNetV2_base, PP-LCNetV2_large, ResNet101, ResNet152, ResNet18, ResNet34, ResNet50, ResNet200_vd, ResNet101_vd, ResNet152_vd, ResNet18_vd, ResNet34_vd, ResNet50_vd, SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384, StarNet-S1, StarNet-S2, StarNet-S3, StarNet-S4, FasterNet-L, FasterNet-M, FasterNet-S, FasterNet-T0, FasterNet-T1, FasterNet-T2, PP-LCNet_x1_0_table_cls, MobileFaceNet, ResNet50_face, PP-ShiTuV2_rec, PP-ShiTuV2_rec_CLIP_vit_base, PP-ShiTuV2_rec_CLIP_vit_large, PicoDet-L, PicoDet-S, PP-YOLOE_plus-L, PP-YOLOE_plus-M, PP-YOLOE_plus-S, PP-YOLOE_plus-X, RT-DETR-H, RT-DETR-L, RT-DETR-R18, RT-DETR-R50, RT-DETR-X, PicoDet_layout_1x, PicoDet_layout_1x_table, PicoDet-S_layout_3cls, PicoDet-S_layout_17cls, PicoDet-L_layout_3cls, PicoDet-L_layout_17cls, RT-DETR-H_layout_3cls, RT-DETR-H_layout_17cls, YOLOv3-DarkNet53, YOLOv3-MobileNetV3, YOLOv3-ResNet50_vd_DCN, YOLOX-L, YOLOX-M, YOLOX-N, YOLOX-S, YOLOX-T, YOLOX-X, FasterRCNN-ResNet34-FPN, FasterRCNN-ResNet50, FasterRCNN-ResNet50-FPN, FasterRCNN-ResNet50-vd-FPN, FasterRCNN-ResNet50-vd-SSLDv2-FPN, FasterRCNN-ResNet101, FasterRCNN-ResNet101-FPN, FasterRCNN-ResNeXt101-vd-FPN, FasterRCNN-Swin-Tiny-FPN, Cascade-FasterRCNN-ResNet50-FPN, Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN, PicoDet-M, PicoDet-XS, FCOS-ResNet50, DETR-R50, PP-ShiTuV2_det, PP-YOLOE-L_human, PP-YOLOE-S_human, PP-YOLOE-L_vehicle, PP-YOLOE-S_vehicle, PP-YOLOE_plus_SOD-L, PP-YOLOE_plus_SOD-S, PP-YOLOE_plus_SOD-largesize-L, CenterNet-DLA-34, CenterNet-ResNet50, PicoDet_LCNet_x2_5_face, BlazeFace, BlazeFace-FPN-SSH, PP-YOLOE_plus-S_face, PP-YOLOE-R-L, Co-Deformable-DETR-R50, Co-Deformable-DETR-Swin-T, Co-DINO-R50, Co-DINO-Swin-L, RT-DETR-L_wired_table_cell_det, RT-DETR-L_wireless_table_cell_det, PP-DocLayout-L, PP-DocLayout-M, PP-DocLayout-S, PP-DocLayout_plus-L, PP-DocBlockLayout, Mask-RT-DETR-S, Mask-RT-DETR-M, Mask-RT-DETR-X, Mask-RT-DETR-H, Mask-RT-DETR-L, SOLOv2, MaskRCNN-ResNet50, MaskRCNN-ResNet50-FPN, MaskRCNN-ResNet50-vd-FPN, MaskRCNN-ResNet101-FPN, MaskRCNN-ResNet101-vd-FPN, MaskRCNN-ResNeXt101-vd-FPN, MaskRCNN-ResNet50-vd-SSLDv2-FPN, Cascade-MaskRCNN-ResNet50-FPN, Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN, PP-YOLOE_seg-S, PP-TinyPose_128x96, PP-TinyPose_256x192, BEVFusion, ResNet50_ML, PP-LCNet_x1_0_ML, PP-HGNetV2-B0_ML, PP-HGNetV2-B4_ML, PP-HGNetV2-B6_ML, CLIP_vit_base_patch16_448_ML, PP-LCNet_x1_0_pedestrian_attribute, PP-LCNet_x1_0_vehicle_attribute, whisper_large, whisper_medium, whisper_base, whisper_small, whisper_tiny, DeepLabV3P_CBAMDeeplabv3_Plus-R101, Deeplabv3_Plus-R50, Deeplabv3-R101, Deeplabv3-R50, OCRNet_HRNet-W48, OCRNet_HRNet-W18, PP-LiteSeg-T, PP-LiteSeg-B, SegFormer-B0, SegFormer-B1, SegFormer-B2, SegFormer-B3, SegFormer-B4, SegFormer-B5, SeaFormer_base, SeaFormer_tiny, SeaFormer_small, SeaFormer_large, MaskFormer_tiny, MaskFormer_small, SLANet, SLANet_plus, SLANeXt_wired, SLANeXt_wireless, PP-OCRv5_mobile_det, PP-OCRv5_server_det, PP-OCRv4_mobile_det, PP-OCRv4_server_det, PP-OCRv4_mobile_seal_det, PP-OCRv4_server_seal_det, PP-OCRv3_mobile_det, PP-OCRv3_server_det, PP-OCRv3_mobile_rec, en_PP-OCRv3_mobile_rec, korean_PP-OCRv3_mobile_rec, japan_PP-OCRv3_mobile_rec, chinese_cht_PP-OCRv3_mobile_rec, te_PP-OCRv3_mobile_rec, ka_PP-OCRv3_mobile_rec, ta_PP-OCRv3_mobile_rec, latin_PP-OCRv3_mobile_rec, arabic_PP-OCRv3_mobile_rec, cyrillic_PP-OCRv3_mobile_rec, devanagari_PP-OCRv3_mobile_rec, PP-OCRv4_mobile_rec, PP-OCRv4_server_rec, en_PP-OCRv4_mobile_rec, PP-OCRv4_server_rec_doc, ch_SVTRv2_rec, ch_RepSVTR_rec, PP-OCRv5_server_rec, PP-OCRv5_mobile_rec, AutoEncoder_ad, DLinear_ad, Nonstationary_ad, PatchTST_ad, TimesNet_ad, TimesNet_cls, DLinear, NLinear, Nonstationary, PatchTST, RLinear, TiDE, TimesNet, PP-TSM-R50_8frames_uniform, PP-TSMv2-LCNetV2_8frames_uniform, PP-TSMv2-LCNetV2_16frames_uniform, YOWO, PP-DocBee-2B, PP-DocBee-7B, PP-Chart2Table, PP-DocBee2-3B, LaTeX_OCR_rec, UniMERNet, PP-FormulaNet-S, PP-FormulaNet-L, PP-FormulaNet_plus-S, PP-FormulaNet_plus-M, PP-FormulaNet_plus-L, GroundingDINO-T, YOLO-Worldv2-L, SAM-H_point, SAM-H_box]

环境

  1. 请提供您使用的PaddlePaddle和PaddleX的版本号

paddlex:3.0

  1. 请提供您使用的操作系统信息,如Linux/Windows/MacOS

centos 8.0

  1. 请问您使用的Python版本是?

python 3.10.18

  1. 请问您使用的CUDA/cuDNN的版本号是? cuda 12 cuDNN8.7

Contributor guide