Advertisement

第四十二章 快速进入PyTorch(工具)

阅读量:

目录

一、简介

二、安装 PyTorch Lightning

三、定义 LightningModule

3.1 SYSTEM VS MODEL

3.2 FORWARD vs TRAINING_STEP

三、配置 Lightning Trainer

四、基本特性

4.1 Manual vs automatic optimization

4.1.1 自动优化 (Automatic optimization)

4.1.1 手动优化 (Manual optimization)

4.2 Predict or Deploy

4.2.1 选项一 —— 子模型 (Sub-models)

4.2.2 选项二 —— 前馈 (Forward)

4.2.2 选项三 —— 生产 (Production)

4.3 Using CPUs/GPUs/TPUs

4.4 Checkpoints

4.5 Data flow

4.6 Logging

4.7 Optional extensions

4.7.1 回调 (Callbacks)

4.7.2 LightningDataModules

4.8 Debugging

五、其他炫酷特性


项目地址https://github.com/PyTorchLightning/pytorch-lightning


一、简介

本指南将展示如何分两步将 PyTorch 代码组织成 Lightning。

使用 PyTorch Lightning 组织代码,可以使代码:

  • 保持所有灵活性(这些完全基于 PyTorch 的生态系统),但去除了冗余部分。
  • 将研究代码与其工程实现分离开来,并显著提升了代码的可读性。
  • 更容易复现。
  • 通过自动化的训练循环和复杂的系统设计实现了大部分的手动流程自动化。
  • 在各种硬件配置下运行,并且无需修改模型即可使用。

二、安装 PyTorch Lightning

pip 安装:

复制代码
    pip install pytorch-lightning
    
    
    AI写代码undefined

或 conda 安装:

复制代码
    conda install pytorch-lightning -c conda-forge
    
    
    AI写代码r
    
    运行

或在 conda 虚拟环境下安装:

复制代码
    conda activate my_env
    
    pip install pytorch-lightning
    
    
    AI写代码undefined

在新源文件中导入以下将用到的依赖:

复制代码
    import os
    
    import torch
    
    from torch import nn
    
    import torch.nn.functional as F
    
    from torchvision.datasets import MNIST
    
    from torchvision import transforms
    
    from torch.utils.data import DataLoader
    
    import pytorch_lightning as pl
    
    from torch.utils.data import random_split
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/r89ZfmtwABnKqH0pxIvYgX4JiEdM.png)

三、定义 LightningModule

复制代码
    class LitAutoEncoder(pl.LightningModule):
    
     
    
    def __init__(self):
    
        super().__init__()
    
        self.encoder = nn.Sequential(
    
            nn.Linear(28*28, 64),
    
            nn.ReLU(),
    
            nn.Linear(64, 3)
    
        )
    
        self.decoder = nn.Sequential(
    
            nn.Linear(3, 64),
    
            nn.ReLU(),
    
            nn.Linear(64, 28*28)
    
        )
    
     
    
    def forward(self, x):
    
        # in lightning, forward defines the prediction/inference actions
    
        embedding = self.encoder(x)
    
        return embedding
    
     
    
    def training_step(self, batch, batch_idx):
    
        # training_step defined the train loop.
    
        # It is independent of forward
    
        x, y = batch
    
        x = x.view(x.size(0), -1)
    
        z = self.encoder(x)
    
        x_hat = self.decoder(z)
    
        loss = F.mse_loss(x_hat, x)
    
        # Logging to TensorBoard by default
    
        self.log('train_loss', loss)
    
        return loss
    
     
    
    def configure_optimizers(self):
    
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    
        return optimizer
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/QBCGbnfzoHcI9Oqr02aUYAxVyuTZ.png)

3.1 SYSTEM VS MODEL

注意,LightningModule 构建了一个系统而非仅仅是一个模型:不仅是一个模型而是涵盖了系统的LightningModule。

img

关于系统 (system) 的例子还有:

在该模块内,LightningModule 仍然只是一个 torch.nn.Module 实例,并通过将所有研究代码整合到同一个文件中来实现其独立性

  • Training Loop
  • Validation Loop
  • Testing Loop
  • Model(s) employed or the system of Models used
  • Optimizer algorithm

可从 Available Callback hooks 这一模块中选择超过20个 hooks,并将其应用于包括反向传播在内的任何训练步骤。

复制代码
    class LitAutoEncoder(pl.LightningModule):
    
     
    
    def backward(self, loss, optimizer, optimizer_idx):
    
        loss.backward()
    
    
    AI写代码python
    
    运行

3.2 FORWARD vs TRAINING_STEP

在Lightning框架中,我们特意将训练过程与推理过程区分开来。其中training_step完整地描述了训练流程。建议用户通过forward操作来实现推理逻辑。

在这种情况中, 可将自动编码器用作嵌入提取器 (embedding extractor)

复制代码
    def forward(self, x):
    
    embeddings = self.encoder(x)
    
    return embeddings
    
    
    AI写代码python
    
    运行

当然,没有什么可以阻止你在 training_step 中使用 forward:

复制代码
    def training_step(self, batch, batch_idx):
    
    ...
    
    z = self(x)
    
    
    AI写代码python
    
    运行

这确实取决于个人的应用程序,但仍建议将两个意图分开。

  • 使用 forward 推理/预测
  • 使用 training_step 训练

详细信息位于 LightningModule 文档中


三、配置 Lightning Trainer

首先进行明确的数据定义。Lightning系统仅提供一个 DataLoader 供训练、验证及测试数据分片处理:

复制代码
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    
    train_loader = DataLoader(dataset)
    
    
    AI写代码python
    
    运行

随后, 初始化为 LightningModule 以及 PyTorch Lightning 的 Trainer 实例, 并对数据集与模型架构进行拟合训练。

复制代码
    # init model
    
    autoencoder = LitAutoEncoder()
    
     
    
    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    
    # trainer = pl.Trainer(gpus=8) (if you have GPUs)
    
    trainer = pl.Trainer()
    
    trainer.fit(autoencoder, train_loader)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/1EWzQwVbtHoNDRlsKmZXU6uF9d7i.png)

Trainer 自动化以下内容:

如果想要自行配置优化器,则建议采用Manual optimization mode(具体应用包括但不限于Reinforcement Learning和Generative Adversarial Networks等)。

这些便是我们在Lightning中要掌握的两大核心概念。此外,Lightning的所有其他属性都源自于Trainer或LightningModule。


四、基本特性

4.1 Manual vs automatic optimization

4.1.1 自动优化 (Automatic optimization)

借助Lightning框架,在无需关心何时开启或关闭梯度(grads)的情况下,实现反向传播或更新优化器;当在 training_step 中返回损失时自动处理这些流程。

复制代码
    def training_step(self, batch, batch_idx):
    
    loss = self.encoder(batch[0])
    
    return loss
    
    
    AI写代码python
    
    运行
4.1.1 手动优化 (Manual optimization)

值得注意的是,在特定领域内的研究中(例如GAN技术、强化学习方法以及涉及多优化器或多内循环机制的研究),人们可以通过放弃自动优化策略来实现全手动控制流程。

首先,关闭自动优化:

复制代码
    trainer = Trainer(automatic_optimization=False)
    
    
    AI写代码python
    
    运行

然后,构造自己的训练循环:

复制代码
    def training_step(self, batch, batch_idx, opt_idx):
    
    (opt_a, opt_b, opt_c) = self.optimizers()
    
     
    
    loss_a = self.generator(batch[0])
    
     
    
    # use this instead of loss.backward so we can automate half precision, etc...
    
    self.manual_backward(loss_a, opt_a, retain_graph=True)
    
    self.manual_backward(loss_a, opt_a)
    
    opt_a.step()
    
    opt_a.zero_grad()
    
     
    
    loss_b = self.discriminator(batch[0])
    
    self.manual_backward(loss_b, opt_b)
    
    ...
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/JOxzqEPUbLF7wu615lgHp8QZIR4f.png)

4.2 Predict or Deploy

完成训练后,可以使用 3 个选项将 LightningModule 用于预测。

4.2.1 选项一 —— 子模型 (Sub-models)

取出系统 (system) 内部的任何模型进行预测。

复制代码
    # ----------------------------------
    
    # to use as embedding extractor
    
    # ----------------------------------
    
    autoencoder = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint_file.ckpt')
    
    encoder_model = autoencoder.encoder
    
    encoder_model.eval()
    
     
    
    # ----------------------------------
    
    # to use as image generator
    
    # ----------------------------------
    
    decoder_model = autoencoder.decoder
    
    decoder_model.eval()
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/rvMHn0ax4gh2XUbyCFOYwWmSuVZi.png)
4.2.2 选项二 —— 前馈 (Forward)

如果需要,也可以加入一个 forward 方法进行预测:

复制代码
    # ----------------------------------
    
    # using the AE to extract embeddings
    
    # ----------------------------------
    
    class LitAutoEncoder(pl.LightningModule):
    
    def forward(self, x):
    
        embedding = self.encoder(x)
    
        return embedding
    
     
    
    autoencoder = LitAutoencoder()
    
    autoencoder = autoencoder(torch.rand(1, 28 * 28))
    # ----------------------------------
    
    # or using the AE to generate images
    
    # ----------------------------------
    
    class LitAutoEncoder(pl.LightningModule):
    
    def forward(self):
    
        z = torch.rand(1, 3)
    
        image = self.decoder(z)
    
        image = image.view(1, 1, 28, 28)
    
        return image
    
     
    
    autoencoder = LitAutoencoder()
    
    image_sample = autoencoder(()
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/qa4Hgl3QDPsTBvUj8LAcu7GRiMKb.png)
4.2.2 选项三 —— 生产 (Production)

对于系统 (systems),onnx和torchscript在执行效率上表现出显著的优势。建议采取以下措施:首先确保已添加forward方法;其次可以选择仅跟踪所需子模型(sub-models)。

复制代码
    # ----------------------------------
    
    # torchscript
    
    # ----------------------------------
    
    autoencoder = LitAutoEncoder()
    
    torch.jit.save(autoencoder.to_torchscript(), "model.pt")
    
    os.path.isfile("model.pt")
    # ----------------------------------
    
    # onnx
    
    # ----------------------------------
    
    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
    
     autoencoder = LitAutoEncoder()
    
     input_sample = torch.randn((1, 28 * 28))
    
     autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    
     os.path.isfile(tmpfile.name)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/uJUEZ23o4gQDc1k9ALqRfmH5POw8.png)

4.3 Using CPUs/GPUs/TPUs

在 Lightning 中使用 CPU、GPU 或 TPU 非常容易。不需要更改代码,只需设置训练器选项即可.

复制代码
    # train on CPU
    
    trainer = pl.Trainer()
    # train on 8 CPUs
    
    trainer = pl.Trainer(num_processes=8)
    # train on 1024 CPUs across 128 machines
    
    trainer = pl.Trainer(
    
    num_processes=8,
    
    num_nodes=128
    
    )
    # train on 1 GPU
    
    trainer = pl.Trainer(gpus=1)
    # train on multiple GPUs across nodes (32 gpus here)
    
    trainer = pl.Trainer(
    
    gpus=4,
    
    num_nodes=8
    
    )
    # train on gpu 1, 3, 5 (3 gpus total)
    
    trainer = pl.Trainer(gpus=[1, 3, 5])
    # Multi GPU with mixed precision
    
    trainer = pl.Trainer(gpus=2, precision=16)
    # Train on TPUs
    
    trainer = pl.Trainer(tpu_cores=8)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/WFT6Ut0nSOHGyvfZYjMDNwc7h4JQ.png)

现在,无需更改自己代码,就可以用上述代码执行以下操作:

复制代码
    # train on TPUs using 16 bit precision
    
    # using only half the training data and checking validation every quarter of a training epoch
    
    trainer = pl.Trainer(
    
    tpu_cores=8,
    
    precision=16,
    
    limit_train_batches=0.5,
    
    val_check_interval=0.25
    
    )
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/CO91FW7zXK8fxioGSTnUqahYg6mA.png)

4.4 Checkpoints

Lightning 自动管理模型的存储过程。完成训练后,则可通过以下方法加载检查点 (checkpoints)。

Lightning 自动管理模型的存储过程。完成训练后,则可通过以下方法加载检查点 (checkpoints)。

复制代码
    model = LitModel.load_from_checkpoint(path)
    
    
    AI写代码python
    
    运行

上述(above)检查点(checkpoints)涉及(involve)两个主要步骤:一是(first)实现(achieve)模型初始化(model initialization),二是(second)配置(configure)状态字典(state dict)。在手动操作时(when manual operations are required),建议采用以下替代方案:

复制代码
    # load the ckpt
    
    ckpt = torch.load('path/to/checkpoint.ckpt')
    
     
    
    # equivalent to the above
    
    model = LitModel()
    
    model.load_state_dict(ckpt['state_dict'])
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/cembhD4nJXsWLQrjk1NtFdOUIgH2.png)

4.5 Data flow

每个循环 (训练,验证,测试) 都有三个钩子 (hooks) 函数可以执行:

  • x_step
  • x_step_end
  • x_epoch_end

为阐明数据如何流动,此处使用训练循环说明 (即 x=training):

复制代码
    outs = []
    
    for batch in data:
    
    out = training_step(batch)
    
    outs.append(out)
    
    training_epoch_end(outs)
    
    
    AI写代码python
    
    运行

这在 Lightning 中等价为:

复制代码
    def training_step(self, batch, batch_idx):
    
    prediction = ...
    
    return prediction
    
     
    
    def training_epoch_end(self, training_step_outputs):
    
    for prediction in predictions:
    
        # do something with these
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/ShGLT19HXPEMebJ5pYrjczWxDlyI.png)

当采用 DP 或 DDP2 分布式模式(包括将每个 batch 分配到各个 GPU 上),建议采取 x_step_end 方法手动完成数据汇总(aggregate)。如果不想手动操作,则可以让 Lightning 自行处理。

复制代码
    for batch in data:
    
    model_copies = copy_model_per_gpu(model, num_gpus)
    
    batch_split = split_batch_per_gpu(batch, num_gpus)
    
     
    
    gpu_outs = []
    
    for model, batch_part in zip(model_copies, batch_split):
    
        # LightningModule hook
    
        gpu_out = model.training_step(batch_part)
    
        gpu_outs.append(gpu_out)
    
     
    
    # LightningModule hook
    
    out = training_step_end(gpu_outs)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/t4LrIdKZUq2JBFnPlXYNi06fD95H.png)

这在 Lightning 中等价为:

复制代码
    def training_step(self, batch, batch_idx):
    
    loss = ...
    
    return loss
    
     
    
    def training_step_end(self, losses):
    
    gpu_0_loss = losses[0]
    
    gpu_1_loss = losses[1]
    
    return (gpu_0_loss + gpu_1_loss) * 1/2
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/wG2dDT5l48kb7uZUCMKLYnxhNt3E.png)

小贴士

验证和测试循环具有同样的结构。


4.6 Logging

建议您采用该函数用于将日志信息发送到Tensorboard或其他进度指示器中。该方法可以从LightningModule 中的任何方法调用:

复制代码
    def training_step(self, batch, batch_idx):
    
    self.log('my_metric', x)
    
    
    AI写代码python
    
    运行

log() 方法有以下几个选项:

  • step metrics(记录训练步骤中的度量指标)
    • epoch completion tracking(在 epoch 结束时自动追踪和记录)
    • progress indicator(进度显示器)
    • log aggregator(类似于 Tensorboard 的日志收集器)

基于从哪里调用日志 (log),Lightning 自动选择合适的模式。当然也可以通过配置标志并支持自定义的行为。

注意

设置 on_epoch = True 将在整个训练 epoch 内累积记录值。

复制代码
>     def training_step(self, batch, batch_idx):
>  
>         self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
>  
>  
>     AI写代码python
>  
>     运行
>  
>       
>       
>       
>  
>

注意

你也可以直接使用你的记录器 (loggers) 的任何方法:

复制代码
    def training_step(self, batch, batch_idx):
    
    tensorboard = self.logger.experiment
    
    tensorboard.any_summary_writer_method_you_want())
    
    
    AI写代码python
    
    运行

当训练启动时即可使用日志记录工具(logger)或启用 Tensorboard logs 来追踪模型性能数据

复制代码
    tensorboard --logdir ./lightning_logs
    
    
    AI写代码python
    
    运行

注意

Lightning 会自动在进度栏中显示从 training_step 返回的损失值。因此,无需显式记录此类日志,例如像这样 self.log(‘loss’, loss, prog_bar = True)。

深入探讨记录器(Loggers)的相关细节,并系统性地了解其功能特点及其在实际应用中的表现形式


4.7 Optional extensions

4.7.1 回调 (Callbacks)

该函数(Callbacks)被指定为一个具有任意性质的独立实体,并可在训练循环中的任何阶段被触发或运行。

以下是添加不太理想的学习率衰减规则的示例:

复制代码
    class DecayLearningRate(pl.Callback)
    
     
    
    def __init__(self):
    
        self.old_lrs = []
    
     
    
    def on_train_start(self, trainer, pl_module):
    
        # track the initial learning rates
    
        for opt_idx in optimizer in enumerate(trainer.optimizers):
    
            group = []
    
            for param_group in optimizer.param_groups:
    
                group.append(param_group['lr'])
    
            self.old_lrs.append(group)
    
     
    
    def on_train_epoch_end(self, trainer, pl_module, outputs):
    
        for opt_idx in optimizer in enumerate(trainer.optimizers):
    
            old_lr_group = self.old_lrs[opt_idx]
    
            new_lr_group = []
    
            for p_idx, param_group in enumerate(optimizer.param_groups):
    
                old_lr = old_lr_group[p_idx]
    
                new_lr = old_lr * 0.98
    
                new_lr_group.append(new_lr)
    
                param_group['lr'] = new_lr
    
             self.old_lrs[opt_idx] = new_lr_group
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/kI2H3hExfe1YU4OJSpt086adFbQA.png)

可以通过回调执行的操作:

  • 在特定的训练阶段向客户端发送消息
    • 增长方案
    • 调整学习速率参数
    • 观察梯度分布情况
    • 受益于你的创新思维框架

探索更多关于自定义回调 (custom callbacks) 的详细信息。

4.7.2 LightningDataModules

相关工具组件(如 DataLoader 和数据处理代码)趋于分布。通过将它们整合至 LightningDataModule 中来提升代码复用性。

复制代码
    class MNISTDataModule(pl.LightningDataModule):
    
     
    
      def __init__(self, batch_size=32):
    
          super().__init__()
    
          self.batch_size = batch_size
    
     
    
      # When doing distributed training, Datamodules have two optional arguments for
    
      # granular control over download/prepare/splitting data:
    
     
    
      # OPTIONAL, called only on 1 GPU/machine
    
      def prepare_data(self):
    
          MNIST(os.getcwd(), train=True, download=True)
    
          MNIST(os.getcwd(), train=False, download=True)
    
     
    
      # OPTIONAL, called for every GPU/machine (assigning state is OK)
    
      def setup(self, stage):
    
          # transforms
    
          transform=transforms.Compose([
    
              transforms.ToTensor(),
    
              transforms.Normalize((0.1307,), (0.3081,))
    
          ])
    
          # split dataset
    
          if stage == 'fit':
    
              mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
    
              self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
    
          if stage == 'test':
    
              self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
    
     
    
      # return the dataloader for each split
    
      def train_dataloader(self):
    
          mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
    
          return mnist_train
    
     
    
      def val_dataloader(self):
    
          mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
    
          return mnist_val
    
     
    
      def test_dataloader(self):
    
          mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
    
          return mnist_test
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/ICUPYOGfgyjRXQvceT10zJbqNWsl.png)

LightningDataModule旨在促进跨项目数据片段和转换的共享与再利用(splits and transforms)。该模块整合了从获取到标注化的所有数据处理步骤(包括下载、标注化、处理)。

现在,只需将 LightningDataModule 传递给 Trainer

复制代码
    # init model
    
    model = LitModel()
    
     
    
    # init data
    
    dm = MNISTDataModule()
    
     
    
    # train
    
    trainer = pl.Trainer()
    
    trainer.fit(model, dm)
    
     
    
    # test
    
    trainer.test(datamodule=dm)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/clX8DLICvp354hoxnzamYEO1HRKw.png)

DataModules对于基于数据构建模型表现出色。了解更多信息有关LightningDataModule可能会有所帮助。


4.8 Debugging

Lightning 有许多调试工具。这是其中一些示例:

复制代码
    # use only 10 train batches and 3 val batches
    
    trainer = pl.Trainer(limit_train_batches=10, limit_val_batches=3)
    # Automatically overfit the sane batch of your model for a sanity test
    
    trainer = pl.Trainer(overfit_batches=1)
    # unit test all the code- hits every line of your code once to see if you have bugs,
    
    # instead of waiting hours to crash on validation
    
    trainer = pl.Trainer(fast_dev_run=True)
    # train only 20% of an epoch
    
    trainer = pl. Trainer(limit_train_batches=0.2)
    # run validation every 25% of a training epoch
    
    trainer = pl.Trainer(val_check_interval=0.25)
    # Profile your code to find speed/memory bottlenecks
    
    Trainer(profiler=True)
    
    
    AI写代码python
    
    运行
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-30/jmok6HlIVMbNxJFuKgBU4dz5eTOR.png)

五、其他炫酷特性

定义并训练了第一个 Lightning 模型之后,在探索其他卓越的功能之前,请考虑尝试以下选项:

也许你可以查看Step-by-step walk-through来获取更多细节。


全部评论 (0)

还没有任何评论哟~