Lightning-AI/pytorch-lightning

Logger fails to create metrics.csv file when running on multiple TPU cores

Open

#19035 opened on November 20, 2023

Labels: bug, help wanted, strategy: xla, ver: 2.2.x

Description

Bug description

Running the mnist-tutorial from Lightning-AI doesn't create a metrics.csv file when run on a v2-8 Cloud TPU using all 8 cores. This issue reproduces even after killing all running python processes and restarting the python3 kernel on Jupyter.

With devices=1, so that the model trains on a single core, metrics.csv appears to be created every time. I reproduced the issue on the latest stable build (2.1.0) as well as on nightly (2.2.0dev).

What version are you seeing the problem on?

master

How to reproduce the bug

import lightning as L

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
# from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

BATCH_SIZE = 1024

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
     
class LitModel(L.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        self.num_classes = num_classes # Needed to calculate metrics in val step 
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # self.logger.log_metrics("train_loss", loss, step=batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # acc = accuracy(preds, y, task='multiclass', num_classes=self.num_classes)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        # self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

# Init DataModule
dm_2 = MNISTDataModule()
# Init model from datamodule's attributes
model_2 = LitModel(*dm_2.dims, dm_2.num_classes)
# Init trainer
trainer = L.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=8,
)
# Train
print(f"Running on {len(trainer.device_ids)} devices.")
print(f"Logging metrics under: {trainer.logger.log_dir}...")
trainer.fit(model_2, dm_2)

from pathlib import Path
import pandas as pd

csv_path = Path(trainer.logger.log_dir) / 'metrics.csv'
pd.read_csv(csv_path)

Error messages and logs

Running on 8 devices.
Logging metrics under: /home/carlos.gaitan/bigrna-torch/lightning_logs/version_11...
...
FileNotFoundError: [Errno 2] No such file or directory: '/home/carlos.gaitan/bigrna-torch/lightning_logs/version_11/metrics.csv'
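
The traceback alone does not show whether the whole run directory is missing or only the CSV file inside it. Below is a minimal, stdlib-only sketch for telling the two cases apart. `describe_log_dir` is a hypothetical helper name (it is not part of Lightning), and the file names checked are the ones mentioned in this report.

```python
from pathlib import Path


def describe_log_dir(log_dir):
    """Report which expected logger artifacts exist under log_dir.

    Hypothetical diagnostic helper, not part of Lightning.
    """
    log_dir = Path(log_dir)
    # File names taken from this report; adjust if your logger differs.
    expected = ("metrics.csv", "hparams.yaml")
    return {
        "dir_exists": log_dir.is_dir(),
        # Per-file existence; all False when the directory itself is missing.
        "files": {name: (log_dir / name).is_file() for name in expected},
    }
```

Calling `describe_log_dir(trainer.logger.log_dir)` after `trainer.fit(...)` would show whether the failure is a missing `version_X` folder or only a missing `metrics.csv`.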

Environment

  • CUDA:
    • GPU: None
    • available: False
    • version: 11.7
  • Lightning:
    • lightning: 2.2.0.dev0
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.1.0
    • torch: 2.0.0
    • torch-xla: 2.0
    • torchmetrics: 1.2.0
    • torchvision: 0.15.1
  • Packages:
    • absl-py: 1.4.0
    • aiohttp: 3.8.6
    • aiosignal: 1.3.1
    • anyio: 4.0.0
    • appdirs: 1.4.4
    • argon2-cffi: 23.1.0
    • argon2-cffi-bindings: 21.2.0
    • asttokens: 2.4.1
    • async-lru: 2.0.4
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • automat: 0.8.0
    • babel: 2.13.1
    • backcall: 0.2.0
    • beautifulsoup4: 4.12.2
    • bleach: 6.1.0
    • blinker: 1.4
    • cachetools: 5.3.0
    • certifi: 2019.11.28
    • cffi: 1.16.0
    • chardet: 3.0.4
    • charset-normalizer: 2.0.12
    • click: 8.1.7
    • cloud-init: 23.1.2
    • cloud-tpu-client: 0.10
    • cmake: 3.26.0
    • colorama: 0.4.3
    • comm: 0.2.0
    • command-not-found: 0.3
    • configobj: 5.0.6
    • constantly: 15.1.0
    • cryptography: 2.8
    • cython: 0.29.14
    • dbus-python: 1.2.16
    • debugpy: 1.8.0
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • distlib: 0.3.4
    • distro: 1.4.0
    • distro-info: 0.23ubuntu1
    • docker-pycreds: 0.4.0
    • einops: 0.7.0
    • entrypoints: 0.3
    • exceptiongroup: 1.1.3
    • executing: 2.0.1
    • fastjsonschema: 2.19.0
    • filelock: 3.7.1
    • frozenlist: 1.4.0
    • fsspec: 2023.10.0
    • gitdb: 4.0.11
    • gitpython: 3.1.40
    • google-api-core: 1.34.0
    • google-api-python-client: 1.8.0
    • google-auth: 2.23.4
    • google-auth-httplib2: 0.1.0
    • google-cloud-core: 2.3.3
    • google-cloud-storage: 2.13.0
    • google-crc32c: 1.5.0
    • google-resumable-media: 2.6.0
    • googleapis-common-protos: 1.58.0
    • httplib2: 0.14.0
    • hyperlink: 19.0.0
    • idna: 2.8
    • importlib-metadata: 6.8.0
    • importlib-resources: 6.1.1
    • incremental: 16.10.1
    • intel-openmp: 2022.1.0
    • ipykernel: 6.26.0
    • ipython: 8.12.3
    • ipywidgets: 8.1.1
    • jedi: 0.19.1
    • jinja2: 3.1.2
    • json5: 0.9.14
    • jsonpatch: 1.22
    • jsonpointer: 2.0
    • jsonschema: 4.19.2
    • jsonschema-specifications: 2023.11.1
    • jupyter: 1.0.0
    • jupyter-client: 8.6.0
    • jupyter-console: 6.6.3
    • jupyter-core: 5.5.0
    • jupyter-events: 0.9.0
    • jupyter-lsp: 2.2.0
    • jupyter-server: 2.10.1
    • jupyter-server-terminals: 0.4.4
    • jupyterlab: 4.0.8
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.25.1
    • jupyterlab-widgets: 3.0.9
    • keyring: 18.0.1
    • language-selector: 0.1
    • launchpadlib: 1.10.13
    • lazr.restfulclient: 0.14.2
    • lazr.uri: 1.0.3
    • libtpu-nightly: 0.1.dev20230213
    • lightning: 2.2.0.dev0
    • lightning-utilities: 0.9.0
    • lit: 15.0.7
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib-inline: 0.1.6
    • mdurl: 0.1.2
    • memray: 1.10.0
    • mistune: 3.0.2
    • mkl: 2022.1.0
    • mkl-include: 2022.1.0
    • more-itertools: 4.2.0
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • nbclient: 0.9.0
    • nbconvert: 7.11.0
    • nbformat: 5.9.2
    • nest-asyncio: 1.5.8
    • netifaces: 0.10.4
    • networkx: 3.0
    • notebook: 7.0.6
    • notebook-shim: 0.2.3
    • numpy: 1.24.2
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • oauth2client: 4.1.3
    • oauthlib: 3.1.0
    • overrides: 7.4.0
    • packaging: 20.3
    • pandas: 2.0.3
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathtools: 0.1.2
    • pexpect: 4.6.0
    • pickleshare: 0.7.5
    • pillow: 9.4.0
    • pip: 20.0.2
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 2.5.2
    • prometheus-client: 0.18.0
    • prompt-toolkit: 3.0.41
    • protobuf: 3.20.3
    • psutil: 5.9.6
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyasn1: 0.4.2
    • pyasn1-modules: 0.2.1
    • pycparser: 2.21
    • pydantic: 1.10.13
    • pydantic-cli: 4.3.0
    • pygments: 2.16.1
    • pygobject: 3.36.0
    • pyhamcrest: 1.9.0
    • pyjwt: 1.7.1
    • pymacaroons: 0.13.0
    • pynacl: 1.3.0
    • pyopenssl: 19.0.0
    • pyparsing: 2.4.6
    • pyrsistent: 0.15.5
    • pyserial: 3.4
    • python-apt: 2.0.0+ubuntu0.20.4.7
    • python-dateutil: 2.8.2
    • python-debian: 0.1.36ubuntu1
    • python-json-logger: 2.0.7
    • pytorch-lightning: 2.1.0
    • pytz: 2023.3.post1
    • pyyaml: 5.4.1
    • pyzmq: 25.1.1
    • qtconsole: 5.5.1
    • qtpy: 2.4.1
    • referencing: 0.31.0
    • requests: 2.31.0
    • requests-unixsocket: 0.2.0
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rich: 13.6.0
    • rpds-py: 0.12.0
    • rsa: 4.9
    • scipy: 1.10.1
    • secretstorage: 2.3.1
    • send2trash: 1.8.2
    • sentry-sdk: 1.34.0
    • service-identity: 18.1.0
    • setproctitle: 1.3.3
    • setuptools: 62.3.2
    • simplejson: 3.16.0
    • six: 1.14.0
    • smmap: 5.0.1
    • sniffio: 1.3.0
    • sos: 4.3
    • soupsieve: 2.5
    • ssh-import-id: 5.10
    • stack-data: 0.6.3
    • sympy: 1.11.1
    • systemd-python: 234
    • tbb: 2021.6.0
    • terminado: 0.18.0
    • tinycss2: 1.2.1
    • tomli: 2.0.1
    • torch: 2.0.0
    • torch-xla: 2.0
    • torchmetrics: 1.2.0
    • torchvision: 0.15.1
    • tornado: 6.3.3
    • tqdm: 4.66.1
    • traitlets: 5.13.0
    • triton: 2.0.0
    • twisted: 18.9.0
    • typing-extensions: 4.5.0
    • tzdata: 2023.3
    • ubuntu-advantage-tools: 27.8
    • ufw: 0.36
    • unattended-upgrades: 0.1
    • uritemplate: 3.0.1
    • urllib3: 1.25.8
    • virtualenv: 20.14.1
    • wadllib: 1.3.3
    • wandb: 0.15.12
    • wcwidth: 0.2.10
    • webencodings: 0.5.1
    • websocket-client: 1.6.4
    • wheel: 0.34.2
    • widgetsnbextension: 4.0.9
    • yarl: 1.9.2
    • zipp: 1.0.0
    • zope.interface: 4.7.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.10
    • release: 5.13.0-1027-gcp
    • version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022

More info

The CSV file sometimes shows up later (after several minutes, possibly as long as about an hour), but more often than not it never appears. It looks like the CSVLogger does not always write the logged metrics to disk when running in a distributed setting on TPUs. Sometimes the version_X folder is not created at all, so the hparams.yaml file containing metadata about the run is not written to disk either.
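
To put a number on how late the file appears, one option is to poll for it with a timeout instead of reading it immediately. The sketch below is a measurement aid only and does not fix the underlying problem. `wait_for_file` and its parameters are hypothetical names introduced for illustration.

```python
import time
from pathlib import Path


def wait_for_file(path, timeout_s=60.0, poll_s=1.0):
    """Poll until path exists as a non-empty file or timeout_s elapses.

    Returns the number of seconds waited on success, or None on timeout.
    Hypothetical helper, not part of Lightning.
    """
    path = Path(path)
    start = time.monotonic()
    while True:
        # Require a non-empty file so a created-but-unwritten CSV does not count.
        if path.is_file() and path.stat().st_size > 0:
            return time.monotonic() - start
        if time.monotonic() - start >= timeout_s:
            return None
        time.sleep(poll_s)
```

For example, `wait_for_file(Path(trainer.logger.log_dir) / "metrics.csv", timeout_s=3600)` would return the observed delay, or `None` if the file never showed up within the hour.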

cc @carmocca @JackCaoG @Liyang90 @gkroiz
