Lightning-AI/pytorch-lightning

Training time increases epoch by epoch

Open

#20076 opened on July 12, 2024

View on GitHub
 (2 comments) (0 reactions) (0 assignees) · Python (26,687 stars) (3,233 forks) · batch import
Labels: bug, help wanted, performance, repro needed, ver: 2.2.x

Description

Bug description

When I run the following code, the training time per epoch increases with every epoch. For example, the first epoch takes 3:39 min, the second takes 4:21 min, the third takes 5:46 min, and so on, and I don't know why. My code is below. The Lightning version I am using is 2.3.1.

What version are you seeing the problem on?

master

How to reproduce the bug

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   train_pl.py
@Time    :   2024/07/04 17:07:44
@Author  :   Lin Weiquan 
@Version :   1.0
@Desc    :   基于pytorch-lightning训练
'''

from lib.models.besizer_crnn_v3_mask import CRNN_V3_1
from lib.models.besizer_crnn_v7_mask import CRNN_v7
from torch.utils.data import DataLoader
from lib.dataset import get_dataset
from lib.dataset.variable_width import DistCollateFn
from easydict import EasyDict as edict
from lib.utils.utils import model_info
from numpy import *
from pathlib import Path
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
import lib.utils.utils as utils
import torch
import torch.nn as nn
import editdistance
import argparse
import yaml
import time
import os
import torch.backends.cudnn as cudnn
import lightning as pl
import lib.config.alphabets as alphabets
import lib.config.alphabets_shuffle as alphabets_shuffle
import lib.config.alphabets_shuffle2 as alphabets_shuffle2
import lib.config.alphabets_shuffle3 as alphabets_shuffle3



def parse_arg(argv=None):
    """Parse ``--cfg`` and load the YAML experiment configuration.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The config as an ``EasyDict``. ``DATASET.ALPHABETS`` is set from the
        alphabet module selected by ``DATASET.SHUFFLE``, and
        ``MODEL.NUM_CLASSES`` is set to that alphabet's length.
    """
    parser = argparse.ArgumentParser(description="train crnn_v7")
    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
    args = parser.parse_args(argv)

    # Explicit UTF-8 so the same config file loads identically regardless of
    # the platform's locale-dependent default encoding.
    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # prefer yaml.safe_load if configs never rely on Python-specific YAML tags.
    with open(args.cfg, 'r', encoding='utf-8') as f:
        config = edict(yaml.load(f, Loader=yaml.FullLoader))

    # Non-shuffled runs use the base alphabet; shuffled runs use the
    # "shuffle3" alphabet variant.
    config.DATASET.ALPHABETS = (
        alphabets_shuffle3.alphabet if config.DATASET.SHUFFLE else alphabets.alphabet
    )
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    print("NUM_CLASSES: ", config.MODEL.NUM_CLASSES)

    return config


class LightningFreeWrite(pl.LightningModule):
    """LightningModule that trains a CRNN text recognizer with CTC loss.

    Wraps ``CRNN_V3_1``. At the end of every validation epoch it logs:

    * ``avg_val_loss``: mean of the per-batch validation CTC losses;
    * ``avg_char_acc``: mean of per-batch character accuracy, where a batch's
      character accuracy is ``1 - total_edit_distance / total_target_chars``;
    * ``avg_tl_acc``: mean of per-batch exact-match (whole line) accuracy.
    """

    def __init__(self, config):
        """Build the model, CTC criterion and label converter.

        Args:
            config: EasyDict-style config. This class reads
                ``MODEL.NUM_CLASSES``, ``MODEL.NUM_HIDDEN`` and
                ``DATASET.ALPHABETS``, plus ``TRAIN.LR``, ``TRAIN.BEGIN_EPOCH``,
                ``TRAIN.LR_STEP`` and ``TRAIN.LR_FACTOR`` in
                ``configure_optimizers``.
        """
        super().__init__()
        # NOTE(review): the extra output class is presumably the CTC blank
        # (CTCLoss uses blank=0 by default); confirm against strLabelConverter.
        self.model = CRNN_V3_1(nclass=config.MODEL.NUM_CLASSES + 1, nh=config.MODEL.NUM_HIDDEN)
        # zero_infinity=True zeroes infinite losses (and their gradients),
        # which occur when a prediction sequence is too short to align with
        # its target.
        self.criterion = torch.nn.CTCLoss(zero_infinity=True)
        self.config = config
        self.converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
        # Kept as public attributes for backward compatibility; nothing in
        # this class reads them.
        self.best_char_acc = 0.0
        self.train_step_loss_outputs = []
        # Per-batch validation results; reduced and cleared in
        # on_validation_epoch_end.
        self.validation_step_loss_outputs = []
        self.validation_step_char_outputs = []
        self.validation_step_accuracy_outputs = []

    def forward(self, x, mask):
        """Run the wrapped CRNN and return its ``(output, mask)`` pair."""
        output, mask = self.model(x, mask)
        return output, mask

    def configure_optimizers(self):
        """Create AdamW plus a per-epoch MultiStepLR schedule from ``config.TRAIN``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.TRAIN.LR)
        last_epoch = self.config.TRAIN.BEGIN_EPOCH
        # NOTE: torch LR schedulers raise KeyError when last_epoch != -1 unless
        # every param group already has an 'initial_lr' entry, and AdamW does
        # not add one. So BEGIN_EPOCH must be 0 here; resume via a Lightning
        # checkpoint instead.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, self.config.TRAIN.LR_STEP,
            self.config.TRAIN.LR_FACTOR, last_epoch - 1
        )
        opt_sched = {"scheduler": lr_scheduler, "interval": "epoch"}
        return [optimizer], [opt_sched]

    def ctc_loss(self, preds, gts, input_lens, length):
        """Apply the module's CTC criterion."""
        return self.criterion(preds, gts, input_lens, length)

    def _ctc_forward(self, batch):
        """Run the model on one batch and compute its CTC loss.

        Shared by training and validation.

        Returns:
            ``(preds, preds_size, labels, loss)``:
            * ``preds`` is the model output after ``permute(1, 0, 2)``;
            * ``preds_size`` is a CPU int32 tensor giving every sample the
              full output length ``preds.size(0)``;
            * ``labels`` are the raw labels taken from the batch.
        """
        inp, labels, masks, _input_lens = batch
        preds, _ = self.forward(inp, masks)
        # CTCLoss expects time-major (T, N, C) log-probabilities.
        # NOTE(review): this assumes the model returns (N, T, C) log-softmax
        # outputs; confirm in CRNN_V3_1.
        preds = preds.permute(1, 0, 2)

        gts, length, _ace_labels = self.converter.encode(labels)
        gts = gts.long()

        # Every sample uses the full output length. The batch's input_lens is
        # deliberately not used, matching the existing training behaviour.
        preds_size = torch.full((inp.size(0),), preds.size(0), dtype=torch.int32)
        loss = self.ctc_loss(preds, gts, preds_size, length)
        return preds, preds_size, labels, loss

    def training_step(self, train_batch, batch_idx):
        """Compute and log the CTC loss for one training batch."""
        _, _, _, loss = self._ctc_forward(train_batch)
        # sync_dist=True all-reduces this value across DDP ranks on every
        # step. Lightning warns this can add significant communication
        # overhead.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute loss, exact-match and character accuracy for one batch.

        The per-batch values are buffered and reduced in
        ``on_validation_epoch_end``.
        """
        preds, preds_size, labels, loss = self._ctc_forward(val_batch)

        # Greedy decoding: take the arg-max class at each time step, then put
        # the indices back in sample-major order and flatten them for the
        # converter.
        _, best_path = preds.max(2)
        best_path = best_path.transpose(1, 0).contiguous().view(-1)
        sim_preds = self.converter.decode(best_path.data, preds_size.data, raw=False)

        n_correct = 0
        sum_char = 0
        error_char = 0
        for pred, target in zip(sim_preds, labels):
            if pred == target:
                n_correct += 1
            sum_char += len(target)
            error_char += editdistance.eval(pred, target)

        accuracy = n_correct / len(labels)
        if sum_char:
            char_acc = 1 - error_char / sum_char
        else:
            # Every target in this batch is empty. Count the batch as fully
            # correct only if nothing was predicted either, instead of
            # dividing by zero.
            char_acc = 1.0 if error_char == 0 else 0.0

        self.validation_step_loss_outputs.append(loss.detach())
        self.validation_step_char_outputs.append(char_acc)
        self.validation_step_accuracy_outputs.append(accuracy)
        return loss

    def on_validation_epoch_end(self):
        """Log the epoch means of the buffered validation metrics, then reset the buffers."""
        # Guard: torch.stack on an empty list raises. This happens if no
        # validation batch ran.
        if self.validation_step_loss_outputs:
            avg_loss = torch.stack(self.validation_step_loss_outputs).mean().item()
            avg_char_acc = mean(self.validation_step_char_outputs)
            avg_tl_acc = mean(self.validation_step_accuracy_outputs)
            self.log('avg_val_loss', avg_loss, sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
            self.log('avg_char_acc', avg_char_acc, sync_dist=True, on_epoch=True, logger=True, prog_bar=True)
            self.log('avg_tl_acc', avg_tl_acc, sync_dist=True, on_epoch=True, logger=True, prog_bar=True)

        # Reset per-epoch buffers (this also runs after the sanity-check pass).
        self.validation_step_loss_outputs.clear()
        self.validation_step_char_outputs.clear()
        self.validation_step_accuracy_outputs.clear()
    
def main():
    """Load the config, build datasets, loaders, model and callbacks, then run DDP training."""
    # Take the logger from the same `lightning` package as the Trainer and
    # callbacks. The module-level TensorBoardLogger import comes from the
    # separate `pytorch_lightning` distribution.
    from lightning.pytorch.loggers import TensorBoardLogger

    pl.seed_everything(9958, workers=True)
    torch.set_float32_matmul_precision('high')

    config = parse_arg()

    custom_preprocess = None
    train_dataset = get_dataset(config)(config, custom_preprocess, is_train=True)
    val_dataset = get_dataset(config)(config, custom_preprocess, is_train=False)

    # NOTE(review): no shuffle or sampler is set. Under strategy="ddp",
    # Lightning injects its own DistributedSampler; confirm it shuffles the
    # way you intend for your Lightning version.
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        drop_last=True,
        collate_fn=DistCollateFn(training=True),
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    # drop_last=False so every validation sample is evaluated. With
    # drop_last=True, a val set smaller than one batch yields no validation
    # at all.
    # NOTE(review): the collate fn is built with training=True for
    # validation too; confirm that is intended.
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        drop_last=False,
        collate_fn=DistCollateFn(training=True),
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    time_step = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    run_dir = os.path.join('/result', config.OUTPUT_DIR, config.DATASET.DATASET, config.MODEL.NAME, time_step)
    save_log_path = os.path.join(run_dir, "log")
    save_ckpt_path = os.path.join(run_dir, "checkpoint")

    each_checkpoint = ModelCheckpoint(
        dirpath=save_ckpt_path,
        every_n_epochs=1,
        monitor="avg_char_acc",
        # Higher character accuracy is better. The default mode="min" would
        # make best_model_path point at the worst epoch.
        mode="max",
        save_top_k=-1,
        save_last=True,
        filename='{epoch}-{step}-{avg_char_acc:.4f}',
    )

    model = LightningFreeWrite(config=config)
    # Progress-bar callback that refreshes every 500 batches.
    progress_bar = TQDMProgressBar(refresh_rate=500)
    # log_graph is off: the model in this file defines no example_input_array,
    # so graph logging cannot run and only emits a warning.
    # NOTE: version=0 makes every run write to tb_logs/crnn_v7/version_0.
    logger = TensorBoardLogger('tb_logs', name='crnn_v7', version=0, log_graph=False)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp",
        logger=logger,
        max_epochs=config.TRAIN.END_EPOCH,
        default_root_dir=save_log_path,
        enable_checkpointing=True,
        # Clip gradients at norm 1.0. This is the value the previous
        # `gradient_clip_val=True` was coerced to, now written as the number
        # the Trainer expects.
        gradient_clip_val=1.0,
        callbacks=[each_checkpoint, progress_bar],
    )

    trainer.fit(model, train_dataloader, val_dataloader)
    


if __name__ == '__main__':
    # Script entry point. Under DDP each launched rank re-executes this file
    # and reaches this call.
    main()

Error messages and logs

# Error messages and logs here please

Environment

#- PyTorch Lightning Version (e.g., 1.5.0):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @borda

Contributor guide