`DeepLabV3P_CBAM` is not registered on BaseTrainer.
#4739 opened on Nov 18, 2025
Description
Checklist:
- 查找历史相关issue寻求解答
- 翻阅FAQ
- 翻阅PaddleX 文档
- 确认bug是否在新版本里还未修复
描述问题
我在用paddlex的语义分割模型做遥感的道路提取,因为默认的DeepLabV3_plus-Res v101 里面不包含注意力机制,所以我自己扩展了一下,现在我希望加入这个模型去做训练,可是会出现paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer.
的问题
复现
1.我扩展了两个文件。 deeplabv3p_attention.py,用来创建注意力模块
import paddle import paddle.nn as nn import paddle.nn.functional as F
class ChannelAttention(nn.Layer): """通道注意力机制""" def init(self, in_channels, ratio=16): super(ChannelAttention, self).init() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2D(in_channels, in_channels // ratio, 1, bias_attr=False),
nn.ReLU(),
nn.Conv2D(in_channels // ratio, in_channels, 1, bias_attr=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return x * self.sigmoid(out)
class SpatialAttention(nn.Layer): """空间注意力机制""" def init(self, kernel_size=7): super(SpatialAttention, self).init() self.conv = nn.Conv2D(2, 1, kernel_size, padding=kernel_size//2, bias_attr=False) self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = paddle.mean(x, axis=1, keepdim=True)
max_out = paddle.max(x, axis=1, keepdim=True)
x_combined = paddle.concat([avg_out, max_out], axis=1)
attention_map = self.conv(x_combined)
return x * self.sigmoid(attention_map)
class CBAM(nn.Layer): """CBAM注意力模块(通道+空间)""" def init(self, in_channels, reduction=16, kernel_size=7): super(CBAM, self).init() self.channel_attention = ChannelAttention(in_channels, reduction) self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
x = self.channel_attention(x)
x = self.spatial_attention(x)
return x
class RoadConnectivityAttention(nn.Layer): """专门针对道路连通性的注意力机制""" def init(self, in_channels, kernel_size=5): super(RoadConnectivityAttention, self).init() # 使用大卷积核捕捉长距离道路依赖 self.road_conv = nn.Conv2D(in_channels, in_channels, kernel_size, padding=kernel_size//2, groups=in_channels, bias_attr=False) self.road_bn = nn.BatchNorm2D(in_channels) self.road_act = nn.ReLU()
# 通道重校准
self.gap = nn.AdaptiveAvgPool2d(1)
self.channel_fc = nn.Sequential(
nn.Linear(in_channels, in_channels // 4),
nn.ReLU(),
nn.Linear(in_channels // 4, in_channels),
nn.Sigmoid()
)
def forward(self, x):
b, c, h, w = x.shape
# 道路结构增强
road_feat = self.road_conv(x)
road_feat = self.road_bn(road_feat)
road_feat = self.road_act(road_feat)
# 通道注意力
channel_weights = self.gap(x).reshape([b, c])
channel_weights = self.channel_fc(channel_weights).reshape([b, c, 1, 1])
return x + road_feat * channel_weights
deeplabv3p_cbam.py,带有注意力模块的DPV3模型
import paddle import paddle.nn as nn from paddleseg.cvlibs import manager from paddleseg.models import DeepLabV3P from deeplabv3p_attention import CBAM, RoadConnectivityAttention
@manager.MODELS.add_component class DeepLabV3P_CBAM(DeepLabV3P): """集成CBAM注意力机制的DeepLabV3+""" def init(self, num_classes=19, backbone='ResNet101_vd', backbone_indices=(0, 1, 2, 3), aspp_ratios=(1, 3, 6, 12, 18), aspp_out_channels=256, align_corners=False, pretrained=None):
super().__init__(
num_classes=num_classes,
backbone=backbone,
backbone_indices=backbone_indices,
aspp_ratios=aspp_ratios,
aspp_out_channels=aspp_out_channels,
align_corners=align_corners,
pretrained=pretrained
)
# 在ASPP模块后添加CBAM注意力
aspp_out_channels = 256 # 与您的配置一致
self.aspp_cbam = CBAM(aspp_out_channels)
# 在解码器路径添加道路连通性注意力
decoder_low_level_channels = 256 # 调整为您backbone的实际通道数
self.decoder_cbam = RoadConnectivityAttention(decoder_low_level_channels)
# 最终预测前的注意力
self.final_cbam = CBAM(decoder_low_level_channels // 2)
def forward(self, x):
# 编码器特征提取
feats = self.backbone(x)
low_level_feat = feats[self.backbone_indices[1]] # 低级特征
high_level_feat = feats[self.backbone_indices[0]] # 高级特征
# ASPP模块处理高级特征
aspp_out = self.aspp(high_level_feat)
# 在ASPP后应用注意力机制
aspp_out = self.aspp_cbam(aspp_out)
# 上采样ASPP输出
aspp_out = self.conv_bn_relu(aspp_out)
aspp_out = F.interpolate(
aspp_out,
size=low_level_feat.shape[2:],
mode='bilinear',
align_corners=self.align_corners
)
# 处理低级特征并应用道路连通性注意力
low_level_feat = self.low_level_conv(low_level_feat)
low_level_feat = self.decoder_cbam(low_level_feat)
# 特征融合
fused_feat = paddle.concat([aspp_out, low_level_feat], axis=1)
fused_feat = self.conv_bn_relu_fusion(fused_feat)
# 最终预测前的注意力增强
fused_feat = self.final_cbam(fused_feat)
# 上采样到原图大小
output = F.interpolate(
fused_feat,
size=x.shape[2:],
mode='bilinear',
align_corners=self.align_corners
)
return [output]
然后我把这两个文件放在了/root/PaddleX/paddlex/repo_manager/repos/PaddleSeg/paddleseg/models下面,并在__init__.py中引入
from .deeplabv3p_cbam import DeepLabV3P_CBAM
在register.py中进行了注册:
register_model_info( { "model_name": "DeepLabV3P_CBAM", "suite": "Seg", "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-CBAM-R-R101.yaml"), "supported_apis": ["train", "evaluate", "predict", "export"], } )
同时在model_list.py中添加了记录
MODELS = [ "DeepLabV3P_CBAM" "Deeplabv3_Plus-R101", "Deeplabv3_Plus-R50", "Deeplabv3-R101", "Deeplabv3-R50", "OCRNet_HRNet-W48", "OCRNet_HRNet-W18", "PP-LiteSeg-T", "PP-LiteSeg-B", "SegFormer-B0", "SegFormer-B1", "SegFormer-B2", "SegFormer-B3", "SegFormer-B4", "SegFormer-B5", "SeaFormer_base", "SeaFormer_tiny", "SeaFormer_small", "SeaFormer_large", "MaskFormer_tiny", "MaskFormer_small", ]
用命令行执行
python main.py -c paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-CBAM-R-R101.yaml -o Global.mode=train -o Global.dataset_dir=/root/data/dataset/changxin2 -o Global.device=gpu:1 -o Train.batch_size=2 -o Train.learning_rate=0.001 -o Train.num_classes=2 -o Train.epochs_iters=80000 -o Global.output=/root/PaddleX/output/rd_seg_model
但是执行的时候就会出现错误:
Traceback (most recent call last): File "/root/PaddleX/paddlex/utils/result_saver.py", line 28, in wrap
result = func(self, *args, **kwargs)
File "/root/PaddleX/paddlex/engine.py", line 41, in run
self._model.train()
File "/root/PaddleX/paddlex/model.py", line 118, in train
trainer = build_trainer(self._config)
File "/root/PaddleX/paddlex/modules/base/trainer.py", line 44, in build_trainer
return BaseTrainer.get(model_name)(config)
File "/root/PaddleX/paddlex/utils/misc.py", line 196, in get
raise_class_not_found_error(name, cls, all_entities)
File "/root/PaddleX/paddlex/utils/errors/others.py", line 112, in raise_class_not_found_error
raise ClassNotFoundException(msg)
paddlex.utils.errors.others.ClassNotFoundException: DeepLabV3P_CBAM is not registered on BaseTrainer.
The registied entities: [STFPM, CLIP_vit_base_patch16_224, CLIP_vit_large_patch14_224, ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384, MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1_x1_0, MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25, MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV4_conv_small, MobileNetV4_conv_medium, MobileNetV4_conv_large, MobileNetV4_hybrid_medium, MobileNetV4_hybrid_large, PP-HGNet_tiny, PP-HGNet_small, PP-HGNet_base, PP-HGNetV2-B0, PP-HGNetV2-B1, PP-HGNetV2-B2, PP-HGNetV2-B3, PP-HGNetV2-B4, PP-HGNetV2-B5, PP-HGNetV2-B6, PP-LCNet_x0_25, PP-LCNet_x0_25_textline_ori, PP-LCNet_x0_35, PP-LCNet_x0_5, PP-LCNet_x0_75, PP-LCNet_x1_0, PP-LCNet_x1_0_doc_ori, PP-LCNet_x1_0_textline_ori, PP-LCNet_x1_5, PP-LCNet_x2_0, PP-LCNet_x2_5, PP-LCNetV2_small, PP-LCNetV2_base, PP-LCNetV2_large, ResNet101, ResNet152, ResNet18, ResNet34, ResNet50, ResNet200_vd, ResNet101_vd, ResNet152_vd, ResNet18_vd, ResNet34_vd, ResNet50_vd, SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384, StarNet-S1, StarNet-S2, StarNet-S3, StarNet-S4, FasterNet-L, FasterNet-M, FasterNet-S, FasterNet-T0, FasterNet-T1, FasterNet-T2, PP-LCNet_x1_0_table_cls, MobileFaceNet, ResNet50_face, PP-ShiTuV2_rec, PP-ShiTuV2_rec_CLIP_vit_base, PP-ShiTuV2_rec_CLIP_vit_large, PicoDet-L, PicoDet-S, PP-YOLOE_plus-L, PP-YOLOE_plus-M, PP-YOLOE_plus-S, PP-YOLOE_plus-X, RT-DETR-H, RT-DETR-L, RT-DETR-R18, RT-DETR-R50, RT-DETR-X, PicoDet_layout_1x, PicoDet_layout_1x_table, PicoDet-S_layout_3cls, PicoDet-S_layout_17cls, PicoDet-L_layout_3cls, PicoDet-L_layout_17cls, RT-DETR-H_layout_3cls, RT-DETR-H_layout_17cls, YOLOv3-DarkNet53, YOLOv3-MobileNetV3, YOLOv3-ResNet50_vd_DCN, YOLOX-L, YOLOX-M, YOLOX-N, YOLOX-S, YOLOX-T, YOLOX-X, FasterRCNN-ResNet34-FPN, FasterRCNN-ResNet50, FasterRCNN-ResNet50-FPN, FasterRCNN-ResNet50-vd-FPN, FasterRCNN-ResNet50-vd-SSLDv2-FPN, FasterRCNN-ResNet101, FasterRCNN-ResNet101-FPN, FasterRCNN-ResNeXt101-vd-FPN, FasterRCNN-Swin-Tiny-FPN, Cascade-FasterRCNN-ResNet50-FPN, Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN, PicoDet-M, PicoDet-XS, FCOS-ResNet50, DETR-R50, PP-ShiTuV2_det, PP-YOLOE-L_human, PP-YOLOE-S_human, PP-YOLOE-L_vehicle, PP-YOLOE-S_vehicle, PP-YOLOE_plus_SOD-L, PP-YOLOE_plus_SOD-S, PP-YOLOE_plus_SOD-largesize-L, CenterNet-DLA-34, CenterNet-ResNet50, PicoDet_LCNet_x2_5_face, BlazeFace, BlazeFace-FPN-SSH, PP-YOLOE_plus-S_face, PP-YOLOE-R-L, Co-Deformable-DETR-R50, Co-Deformable-DETR-Swin-T, Co-DINO-R50, Co-DINO-Swin-L, RT-DETR-L_wired_table_cell_det, RT-DETR-L_wireless_table_cell_det, PP-DocLayout-L, PP-DocLayout-M, PP-DocLayout-S, PP-DocLayout_plus-L, PP-DocBlockLayout, Mask-RT-DETR-S, Mask-RT-DETR-M, Mask-RT-DETR-X, Mask-RT-DETR-H, Mask-RT-DETR-L, SOLOv2, MaskRCNN-ResNet50, MaskRCNN-ResNet50-FPN, MaskRCNN-ResNet50-vd-FPN, MaskRCNN-ResNet101-FPN, MaskRCNN-ResNet101-vd-FPN, MaskRCNN-ResNeXt101-vd-FPN, MaskRCNN-ResNet50-vd-SSLDv2-FPN, Cascade-MaskRCNN-ResNet50-FPN, Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN, PP-YOLOE_seg-S, PP-TinyPose_128x96, PP-TinyPose_256x192, BEVFusion, ResNet50_ML, PP-LCNet_x1_0_ML, PP-HGNetV2-B0_ML, PP-HGNetV2-B4_ML, PP-HGNetV2-B6_ML, CLIP_vit_base_patch16_448_ML, PP-LCNet_x1_0_pedestrian_attribute, PP-LCNet_x1_0_vehicle_attribute, whisper_large, whisper_medium, whisper_base, whisper_small, whisper_tiny, DeepLabV3P_CBAMDeeplabv3_Plus-R101, Deeplabv3_Plus-R50, Deeplabv3-R101, Deeplabv3-R50, OCRNet_HRNet-W48, OCRNet_HRNet-W18, PP-LiteSeg-T, PP-LiteSeg-B, SegFormer-B0, SegFormer-B1, SegFormer-B2, SegFormer-B3, SegFormer-B4, SegFormer-B5, SeaFormer_base, SeaFormer_tiny, SeaFormer_small, SeaFormer_large, MaskFormer_tiny, MaskFormer_small, SLANet, SLANet_plus, SLANeXt_wired, SLANeXt_wireless, PP-OCRv5_mobile_det, PP-OCRv5_server_det, PP-OCRv4_mobile_det, PP-OCRv4_server_det, PP-OCRv4_mobile_seal_det, PP-OCRv4_server_seal_det, PP-OCRv3_mobile_det, PP-OCRv3_server_det, PP-OCRv3_mobile_rec, en_PP-OCRv3_mobile_rec, korean_PP-OCRv3_mobile_rec, japan_PP-OCRv3_mobile_rec, chinese_cht_PP-OCRv3_mobile_rec, te_PP-OCRv3_mobile_rec, ka_PP-OCRv3_mobile_rec, ta_PP-OCRv3_mobile_rec, latin_PP-OCRv3_mobile_rec, arabic_PP-OCRv3_mobile_rec, cyrillic_PP-OCRv3_mobile_rec, devanagari_PP-OCRv3_mobile_rec, PP-OCRv4_mobile_rec, PP-OCRv4_server_rec, en_PP-OCRv4_mobile_rec, PP-OCRv4_server_rec_doc, ch_SVTRv2_rec, ch_RepSVTR_rec, PP-OCRv5_server_rec, PP-OCRv5_mobile_rec, AutoEncoder_ad, DLinear_ad, Nonstationary_ad, PatchTST_ad, TimesNet_ad, TimesNet_cls, DLinear, NLinear, Nonstationary, PatchTST, RLinear, TiDE, TimesNet, PP-TSM-R50_8frames_uniform, PP-TSMv2-LCNetV2_8frames_uniform, PP-TSMv2-LCNetV2_16frames_uniform, YOWO, PP-DocBee-2B, PP-DocBee-7B, PP-Chart2Table, PP-DocBee2-3B, LaTeX_OCR_rec, UniMERNet, PP-FormulaNet-S, PP-FormulaNet-L, PP-FormulaNet_plus-S, PP-FormulaNet_plus-M, PP-FormulaNet_plus-L, GroundingDINO-T, YOLO-Worldv2-L, SAM-H_point, SAM-H_box]
环境
- 请提供您使用的PaddlePaddle和PaddleX的版本号
paddlex:3.0
- 请提供您使用的操作系统信息,如Linux/Windows/MacOS
centos 8.0
- 请问您使用的Python版本是?
python 3.10.18
- 请问您使用的CUDA/cuDNN的版本号是? cuda 12 cuDNN8.7