Lightning-AI/pytorch-lightning

Validation runs only for one iteration when restarting from checkpoint mid-epoch, wrongly reporting validation loss

Open

#19549, created on February 29, 2024

Labels: bug, help wanted, loops

Description

Bug description

When resuming from a mid-epoch checkpoint (which I have to use because my dataset is large), the training loop runs the validation loop for only one iteration, so the logged validation loss is wrong.

It appears that the batch_progress of lightning.pytorch.loops._EvaluationLoop is wrongly restored from the checkpoint as if the validation loop had already finished, and is not properly reset after the checkpoint is loaded.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Training", batch.shape, loss.item(), batch_idx)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Validation", batch.shape, loss.item())
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_end(self) -> None:
        return super().on_train_epoch_end()


def run(ckpt_path):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=4,
        limit_val_batches=3,
        val_check_interval=2,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        num_sanity_val_steps=0,
        logger=False,
        callbacks=[
            ModelCheckpoint(
                save_last=False,
                save_top_k=10,
                monitor="valid_loss",
                every_n_train_steps=1,
                dirpath="./checkpoints",
                enable_version_counter=False,
            )
        ],
    )
    trainer.fit(
        model,
        train_dataloaders=train_data,
        val_dataloaders=val_data,
        ckpt_path=ckpt_path,
    )


run(ckpt_path=None)
run(ckpt_path="checkpoints/epoch=0-step=3.ckpt")

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/.../lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /local/mnt/workspace/pim/projects/equi-scaling/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Training torch.Size([2, 32]) -0.8351932168006897 0
/.../lightning/pytorch/callbacks/model_checkpoint.py:382: `ModelCheckpoint(monitor='valid_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'epoch', 'step']. HINT: Did you call `log('valid_loss', value)` in the `LightningModule`?
Training torch.Size([2, 32]) -6.302826881408691 1
Validation torch.Size([2, 32]) -3.5122809410095215
Validation torch.Size([2, 32]) 2.169618844985962
Validation torch.Size([2, 32]) -6.107339859008789
Training torch.Size([2, 32]) -7.338858604431152 2
Training torch.Size([2, 32]) 9.845891952514648 3
Validation torch.Size([2, 32]) -4.606845378875732
Validation torch.Size([2, 32]) 8.005867004394531
Validation torch.Size([2, 32]) -4.361298561096191
`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at checkpoints/epoch=0-step=3.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Restored all states from the checkpoint at checkpoints/epoch=0-step=3.ckpt
/.../lightning/pytorch/loops/training_epoch_loop.py:156: You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
Validation torch.Size([2, 32]) 8.310140609741211   # <--- only one iteration
Training torch.Size([2, 32]) -4.653291702270508 2
Training torch.Size([2, 32]) -8.187255859375 3
Validation torch.Size([2, 32]) 9.539518356323242
Validation torch.Size([2, 32]) -8.617881774902344
Validation torch.Size([2, 32]) 0.5192334651947021
`Trainer.fit` stopped: `max_epochs=1` reached.

Environment

* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 2080 Ti
        - available:         True
        - version:           11.7
* Lightning:
        - lightning:         2.2.0.post0
        - lightning-utilities: 0.10.1
        - pytorch-lightning: 2.2.0.post0
        - torch:             2.0.1
        - torch-ema:         0.3
        - torch-geometric:   2.5.0
        - torch-scatter:     2.1.2+pt20cu117
        - torchmetrics:      1.3.1
        - torchvision:       0.15.2
* System:
        - OS:                Linux
        - architecture:
                - 64bit
        - processor:         x86_64
        - python:            3.10.13
        - release:           5.4.0-152-generic

More info

A fix/workaround for this issue is to add self.batch_progress.reset_on_run() at the end of _EvaluationLoop.run.
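To illustrate what I think is going on, here is a toy model in plain Python. It does not use Lightning's actual classes: ToyProgress, ToyEvalLoop and simulate are hypothetical names made up for this sketch, and the loop logic is a simplification that assumes the checkpoint captured the counter of an already finished validation run. It only shows why a stale counter would cut the resumed validation short, and why resetting the counter at the end of run would avoid that.

```python
from dataclasses import dataclass


@dataclass
class ToyProgress:
    """Hypothetical stand-in for a batch progress tracker (not Lightning's class)."""

    completed: int = 0

    def reset_on_run(self) -> None:
        self.completed = 0


class ToyEvalLoop:
    """Hypothetical evaluation loop that trusts its restored counter on restart."""

    def __init__(self, max_batches: int, reset_after_run: bool = False) -> None:
        self.max_batches = max_batches
        self.reset_after_run = reset_after_run
        self.progress = ToyProgress()
        self.restarting = False

    def state_dict(self) -> dict:
        return {"completed": self.progress.completed}

    def load_state_dict(self, state: dict) -> None:
        self.progress.completed = state["completed"]
        self.restarting = True

    def run(self) -> int:
        # A fresh run starts from zero; a restarted run keeps the loaded counter.
        if not self.restarting:
            self.progress.reset_on_run()
        self.restarting = False

        steps = 0
        while True:
            # Process one batch, then check whether the limit has been reached.
            self.progress.completed += 1
            steps += 1
            if self.progress.completed >= self.max_batches:
                break

        # The proposed workaround: clear the counter once the run is finished,
        # so a checkpoint saved afterwards does not carry a "finished" state.
        if self.reset_after_run:
            self.progress.reset_on_run()
        return steps


def simulate(reset_after_run: bool) -> tuple[int, int]:
    first = ToyEvalLoop(max_batches=3, reset_after_run=reset_after_run)
    steps_first = first.run()
    checkpoint = first.state_dict()  # saved after validation has finished

    resumed = ToyEvalLoop(max_batches=3, reset_after_run=reset_after_run)
    resumed.load_state_dict(checkpoint)
    steps_resumed = resumed.run()
    return steps_first, steps_resumed


print(simulate(reset_after_run=False))  # (3, 1): resumed validation stops after one batch
print(simulate(reset_after_run=True))   # (3, 3): resumed validation runs all batches
```

In this toy version the resumed loop without the reset runs a single batch, which matches the single "Validation" line in the log above, but the real cause inside Lightning may differ in detail.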

cc @carmocca @justusschock
