
Section 13: A Walkthrough of the Hugging Face Trainer, with a Demo


Table of Contents

Introduction

  • I. A complete training and prediction demo with Trainer and TrainingArguments
  • II. Helper functions
    • What yield returns
    • Demo: resuming an interrupted iterator
    • The yield from construct
    • torch.Generator() state: the get_state() function
    • Demo: restoring iterator state and index
  • III. Reading the trainer.train() source
  • IV. Reading the _inner_training_loop source
    • self.accelerator.free_memory(): releasing accelerator memory
    • self.get_train_dataloader(): building the train dataloader
    • Deriving training control variables from train_dataloader
    • Step settings (logging/eval/save)
    • Optimizer settings and the delay_optimizer_creation flag
    • Creating the optimizer and the learning-rate scheduler
    • self.state = TrainerState()
    • gradient_checkpointing
    • The self._wrap_model function
    • Resume: weight reload, model acceleration, optimizer/scheduler loading
    • Reloading training state
    • Pre-training updates and preparation
    • The training loop
    • Post-training metrics and logging
    • The _sorted_checkpoints function
    • Deleting surplus checkpoints
    • self.callback_handler.on_train_end
    • Return value

Preface

The Trainer in Hugging Face covers a lot of ground, and doing it justice takes more than one article. After some thought I split the topic into six systematic posts. The first covers TrainingArguments and the basic principles of Trainer parameters; the second walks through a complete code example and the overall architecture of the Trainer core; the third through fifth dig into the epoch-level and step-level training loops; the last focuses on resuming interrupted training via the resume mechanism. This is the second post: it shows how to train a basic model through the Trainer interface, with a full worked demo, and then walks the source to explain how it runs. Important functions encountered along the way are explained one by one; the deeper implementation details are deferred to the later posts, and the final post also covers resume and practical tips such as how model weights are saved.


Part 1
Part 2
Part 3
Part 4
Part 5
Part 6

I. A complete training and prediction demo with Trainer and TrainingArguments

I'll build a small end-to-end example covering both training and prediction, using the Hugging Face Trainer framework, a handful of simple text samples, and the bert-base-uncased model.

1. Building the data

First, how to subclass torch.utils.data.Dataset to build a custom dataset; the implementation speaks for itself:

    class SentimentDataset(torch.utils.data.Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx])
            return item

        def __len__(self):
            return len(self.labels)

    # Create the datasets
    train_dataset = SentimentDataset(train_encodings, train_labels)
    val_dataset = SentimentDataset(val_encodings, val_labels)

Each item returned by __getitem__ is a dict of tensors: the encoding keys (input_ids, attention_mask, etc.) plus labels (screenshot omitted).

Note: train_dataset and val_dataset behave like ordinary PyTorch datasets, and they are what get passed to the Trainer as input data.

2. Building TrainingArguments

As introduced in the previous post, just add whatever arguments you need:

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir='./out_dirs',
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        logging_dir="./logs",
        evaluation_strategy="epoch",
        report_to="none"
    )

A word on how this works: TrainingArguments comes with a full set of default values, and any argument you pass explicitly overrides the corresponding default.

3. Initializing the Trainer

Now the main actor: we instantiate a Trainer. Note that the training and validation datasets are passed in much the same way as plain PyTorch datasets.

    # Define the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=lambda pred: {"accuracy": accuracy_score(pred.label_ids, pred.predictions.argmax(axis=1))}
    )
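If you prefer a named function over the lambda, an equivalent compute_metrics (same logic, just factored out) looks like this:

    from sklearn.metrics import accuracy_score

    # Equivalent to the lambda above, factored into a named function.
    def compute_metrics(pred):
        preds = pred.predictions.argmax(axis=1)
        return {"accuracy": accuracy_score(pred.label_ids, preds)}

    # trainer = Trainer(..., compute_metrics=compute_metrics)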

4. Training the model

With the above in place, training is straightforward:

    # Start training
    trainer.train()

    # Evaluate the model
    eval_results = trainer.evaluate()
    # print(f"Accuracy: {eval_results['eval_accuracy']:.4f}")

    # Save the model
    model.save_pretrained("./sentiment_model")

5. Inference

Once training is done, the saved files can be used directly for inference:

    # Load the model
    model = BertForSequenceClassification.from_pretrained("./sentiment_model")

    # Predict on a few example texts
    example_texts = ["I love this!", "I hate it."]
    inputs = tokenizer(example_texts, padding=True, truncation=True, return_tensors="pt")
    outputs = model(**inputs)
    predicted_labels = torch.argmax(outputs.logits, dim=1).tolist()

    # Print the predictions
    for text, label in zip(example_texts, predicted_labels):
        print(f"Text: {text} -- Predicted Label: {'positive' if label == 1 else 'negative'}")

The tokenized inputs is a dict of batched tensors (input_ids, attention_mask, etc.) (screenshot omitted).

6. Complete demo code

The full, ready-to-run code:

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    from torch.utils.data import Dataset, DataLoader

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    import random

    # Random seeds
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)

    # Toy data
    texts = ["I love Hugging Face!", "I hate this.", "This is fantastic!", "I dislike it."]
    labels = [1, 0, 1, 0]  # 1 = positive, 0 = negative

    # Train/validation split
    train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=seed)

    # Load the pretrained BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Encode the data
    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)

    class SentimentDataset(torch.utils.data.Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx])
            return item

        def __len__(self):
            return len(self.labels)

    # Create the datasets
    train_dataset = SentimentDataset(train_encodings, train_labels)
    val_dataset = SentimentDataset(val_encodings, val_labels)

    # Load the pretrained BERT model
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir='./out_dirs',
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        logging_dir="./logs",
        evaluation_strategy="epoch",
        report_to="none"
    )

    # Define the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=lambda pred: {"accuracy": accuracy_score(pred.label_ids, pred.predictions.argmax(axis=1))}
    )

    # Start training
    trainer.train()

    # Evaluate the model
    eval_results = trainer.evaluate()
    print(f"Accuracy: {eval_results['eval_accuracy']:.4f}")

    # Save the model
    model.save_pretrained("./sentiment_model")

    # Reload the model
    model = BertForSequenceClassification.from_pretrained("./sentiment_model")

    # Predict on a few example texts
    example_texts = ["I love this!", "I hate it."]
    inputs = tokenizer(example_texts, padding=True, truncation=True, return_tensors="pt")
    outputs = model(**inputs)
    predicted_labels = torch.argmax(outputs.logits, dim=1).tolist()

    # Print the predictions
    for text, label in zip(example_texts, predicted_labels):
        print(f"Text: {text} -- Predicted Label: {'positive' if label == 1 else 'negative'}")

7. Full run output

(screenshot of the full run output omitted)

II. Helper functions

1. What yield returns

The dataloader sampler we will meet later iterates via yield from torch.randperm(n, generator=generator).tolist(), so let me first sketch what yield can return.

    yield x: yields the value of the variable x.
    yield 10: yields the integer 10.
    yield func(): yields the result of calling func().
    yield from iterable: yields the elements of an iterable one by one.

In short, whatever expression follows yield is what gets produced, as the sketch below shows.
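A minimal sketch (a toy generator of my own, not from the Trainer source) exercising all four forms:

    def demo_generator(x):
        yield x               # yield a variable's value
        yield 10              # yield a literal
        yield len("abc")      # yield a function call's result
        yield from [1, 2, 3]  # yield each element of an iterable

    print(list(demo_generator(7)))  # [7, 10, 3, 1, 2, 3]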

2. Demo: resuming an interrupted iterator

Here is a small demo of how to pick an iterator back up after an interruption:

    import itertools

    # Define a list
    numbers_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    # Create an iterator over it
    numbers_iterator = iter(numbers_list)
    # Consume (and remember) the first 6 items
    for a in itertools.islice(numbers_iterator, 6):
        print(a)
    # When we need to continue iterating
    print("Restarting iteration:")
    for num in itertools.chain([next(numbers_iterator)], numbers_iterator):
        print(num)

3. The yield from construct

yield from was introduced in Python 3.3 and lets one generator delegate to another. When a generator executes yield from sub_generator(), it pauses itself; the sub-generator produces values that are forwarded directly to the outer caller. Once the sub-generator is exhausted and releases control, the delegating generator resumes with the code after the yield from line.

The mechanism amounts to cooperative composition: one generator drives another and relays its values in order, which keeps the code readable and well structured, as in the sketch below.
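A minimal delegation sketch (toy names of my own):

    def sub_generator():
        yield "a"
        yield "b"

    def main_generator():
        yield "start"
        yield from sub_generator()  # pause here; forward the sub-generator's values
        yield "end"

    print(list(main_generator()))  # ['start', 'a', 'b', 'end']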

4. torch.Generator() state: the get_state() function

generator.get_state() returns the generator's current internal RNG state. Since a fixed seed makes the pseudo-random sequence reproducible, capturing this state lets you restore the generator later and replay exactly the same sequence. A demo:

    import torch

    # Create a generator object
    generator = torch.Generator()
    # Seed the generator
    generator.manual_seed(42)
    # Fetch the generator's state
    generator_state = generator.get_state()

    print(generator_state)

5. Demo: restoring iterator state and index

To resume iteration from a breakpoint after reloading the code, the generator's state must have been saved beforehand. In Python you can pickle the state (together with the index reached) to a file, then read it back after restart. The following demo shows saving and restoring generator state:

    import torch
    import pickle

    # A generator function that replays a random permutation from a given state and index
    def random_permutation_generator(n, generator_state, index):
        generator = torch.Generator()
        generator.set_state(generator_state)
        random_permutation = torch.randperm(n, generator=generator).tolist()

        for i in range(index, len(random_permutation)):
            yield random_permutation[i], i


    # Seed the global RNG for reproducibility
    torch.manual_seed(42)

    n = 10  # permute the integers 0 .. n-1

    # Create a dedicated generator
    generator = torch.Generator()
    generator.manual_seed(42)

    # Build the generator object
    gen = random_permutation_generator(n, generator.get_state(), 0)

    # Simulate an interruption: save the generator state and current index to a file
    for i, (element, index) in enumerate(gen):
        print(element)
        if i == 4:
            with open('generator_state.pkl', 'wb') as f:
                state_index = (generator.get_state(), index)
                pickle.dump(state_index, f)
            break

    # After reloading the code, restore the generator state and index
    with open('generator_state.pkl', 'rb') as f:
        generator_state, index = pickle.load(f)

    gen = random_permutation_generator(n, generator_state, index)

    print("Code reloaded; continuing the random permutation from the breakpoint...")

    # Continue producing the permutation from where we stopped
    for element, _ in gen:
        print(element)

Because we stored both the generator state and the index already produced, after an interruption we can restore the state and continue generating exactly the expected permutation from the breakpoint.

Note: this is the key to understanding how data loading is resumed.

III. Reading the trainer.train() source

With the complete Trainer demo above in hand, I will now walk through how the source executes internally, which is also useful groundwork for customizing it later.

Note: train() accepts **kwargs.

1. The full train() code

To get the full picture of what the train() entry point does, here is the complete function; the subsections below then walk through the important parts one by one.

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.
    
        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*)
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments used to hide deprecated arguments
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None
    
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()
    
        args = self.args
    
        self.is_in_train = True
    
        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)
    
        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size
    
        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None
    
        # Load potential model checkpoint; this is where resume happens
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
    
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)
    
        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model
    
        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )  # this wraps the inner training loop
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

Note: all the code below comes from inside this function.

2. The resume_from_checkpoint argument

Setting resume_from_checkpoint to True, or to a checkpoint directory path, enables resumption. The next post focuses on how the Trainer resumes and restores state, including optimizer state and the position in the training data.

    if resume_from_checkpoint is False:
        resume_from_checkpoint = None

If you are not resuming, the value is normalized to None. The two usual call patterns are shown below.
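For reference (the checkpoint path here is illustrative):

    # Resume from the latest checkpoint found under args.output_dir
    trainer.train(resume_from_checkpoint=True)

    # Or resume from a specific checkpoint directory
    trainer.train(resume_from_checkpoint="./out_dirs/checkpoint-500")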

3. The self._memory_tracker.start() call

self._memory_tracker.start() starts the Trainer's memory tracker, which monitors and records memory usage throughout training. This makes it easier to see how memory is allocated and released, to spot problems such as usage exceeding expectations, and to inform later memory optimization.

4. Basic setup

Two simple assignments follow:

    args = self.args
    self.is_in_train = True

Here args comes from the TrainingArguments(output_dir=output_dir) you passed in.

5. Temporarily moving self.model to the device

do_train is not a reliable flag, since it may be unset while .train() is still called; the code below is the workaround:

    # do_train is not a reliable argument, as it might not be set and .train() still called, so the following is a workaround:
    if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
        self._move_model_to_device(self.model, args.device)

Looking further into _move_model_to_device:

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

6. model_path is deprecated

`model_path` is deprecated and will be removed in a future version; use `resume_from_checkpoint` instead:

    if "model_path" in kwargs:
        resume_from_checkpoint = kwargs.pop("model_path")
        warnings.warn(
            "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
            "instead.",
            FutureWarning,
        )

7. The self._hp_search_setup function

self._hp_search_setup(trial) configures the Trainer for hyperparameter search, where trial represents one candidate configuration. When called, the Trainer applies that trial's hyperparameter choices (note it may also change the seed, which is why it runs first). This is the hook that lets Trainer.hyperparameter_search systematically explore settings and find the best configuration for the task.

    self._hp_search_setup(trial)

8. Getting the batch size

    self._train_batch_size = self.args.train_batch_size

9. Using self.model_init

As covered in the previous post, this is the model factory callable. If it is not provided, this block does not run.

    # Model re-init
    model_reloaded = False
    if self.model_init is not None:
        # Seed must be set before instantiating the model when using model_init.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.model = self.call_model_init(trial)
        model_reloaded = True
        # Reinitializes optimizer and scheduler
        self.optimizer, self.lr_scheduler = None, None

10. Resume: loading checkpoint weights

This part matters: it governs how training resumes after an interruption, and it will be covered in depth later.

    # Load potential model checkpoint; this is where resume happens
    if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
        resume_from_checkpoint = get_last_checkpoint(args.output_dir)
        if resume_from_checkpoint is None:
            raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

    if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
        self._load_from_checkpoint(resume_from_checkpoint)

If resume_from_checkpoint is the bool True, the first branch resolves it to the last checkpoint path under args.output_dir; the second branch then loads the weights from that checkpoint.

11. The model_reloaded flag

If the model was re-initialized, it is moved to the right device and self.model_wrapped is updated:

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

12. The self._inner_training_loop method

This is the actual training entry point. find_executable_batch_size binds self._inner_training_loop to a batch size that fits:

    inner_training_loop = find_executable_batch_size(
        self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
    )  # wraps the inner training loop

train() then simply calls it and returns the result:

        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

IV. Reading the trainer _inner_training_loop source

This function is the real entry point of model training and touches a lot of machinery; below I highlight the parts I consider most important.

1. self.accelerator.free_memory(): releasing accelerator memory

In the Trainer class, self.accelerator.free_memory() releases memory on the accelerator (hardware such as a GPU or TPU used to speed up training). Calling it frees memory that is no longer needed during training, tightening memory use and improving training efficiency.

2. The self.get_train_dataloader() function

a. The data-processing function

This function is very important: it is where the data pipeline is assembled. It is invoked as:

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

I will dig into get_train_dataloader's internals in the next post, where it gets a full walkthrough.

b. What it returns: train_dataloader

The details come in the next post, but two fragments already show the essence of what it builds.

    dataloader_params = {
        "batch_size": self._train_batch_size,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
    }  # the DataLoader keyword arguments

And the second fragment:

    dataloader_params["sampler"] = self._get_train_sampler()
    dataloader_params["drop_last"] = self.args.dataloader_drop_last
    dataloader_params["worker_init_fn"] = seed_worker

The key entries in dataloader_params are the collate_fn and the sampler; these underpin the sampling behavior discussed earlier. The accelerator wrapping and batching details are covered further below.

    return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))  # wrap with accelerate

Note: because the Trainer funnels everything through this function, you can supply your own collate_fn and sampling strategy as needed; see the sketch below.
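For instance, a hedged sketch of supplying an explicit padding collator, using DataCollatorWithPadding from transformers and the tokenizer/model from the demo above (if no data_collator is passed, Trainer falls back to a default collator):

    from transformers import DataCollatorWithPadding

    # Explicit padding collator; Trainer would otherwise pick a default.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )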

3. Deriving training control variables from train_dataloader

This block sets up the training control variables: the number of training epochs (num_train_epochs), the number of update steps per epoch (num_update_steps_per_epoch), and the total number of update steps to execute (max_steps).

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
    
        len_dataloader = None
        if has_length(train_dataloader): 
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)  # number of samples
            if args.max_steps > 0:  # defaults to -1
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else: 
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

The has_length helper simply tries len() on the dataset:

    def has_length(dataset):
        """
        Checks if the dataset implements __len__() and it doesn't raise an error
        """
        try:
            return len(dataset) is not None
        except TypeError:
            # TypeError: len() of unsized object
            return False

Then num_examples fetches the sample count:

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

The following lines compute max_steps, the number of training epochs, and the total sample count num_train_samples (the dataloader's sample count times the number of epochs):

    max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
    num_train_epochs = math.ceil(args.num_train_epochs)
    num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs

Note: this is where args.num_train_epochs comes into play; a worked example follows.
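To make the arithmetic concrete, here is a small worked example (the numbers are illustrative, not from the demo run):

    import math

    len_dataloader = 1000               # batches per epoch
    gradient_accumulation_steps = 4
    num_train_epochs_arg = 3.0          # args.num_train_epochs

    num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 250
    max_steps = math.ceil(num_train_epochs_arg * num_update_steps_per_epoch)            # 750
    num_train_epochs = math.ceil(num_train_epochs_arg)                                  # 3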

4. Step settings (logging/eval/save)

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)
    
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

5. Optimizer settings

Next come the optimizer-related settings. Pay particular attention to the self.is_deepspeed_enabled branch: it is the path used when training large models with DeepSpeed, which this demo does not use.

    delay_optimizer_creation = (
        self.sharded_ddp is not None
        and self.sharded_ddp != ShardedDDPOption.SIMPLE
        or is_sagemaker_mp_enabled()
        or self.fsdp is not None
    )  # False

    # We need to reset the scheduler, as its parameters may be different on subsequent calls
    if self._created_lr_scheduler:  # False
        self.lr_scheduler = None
        self._created_lr_scheduler = False

    if self.is_deepspeed_enabled:  # used with DeepSpeed; False here
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    if not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

6. The delay_optimizer_creation flag

If delay_optimizer_creation is False (the default path here), the optimizer is created right away; if True, creation is deferred until after the model is wrapped.

    delay_optimizer_creation = (
        self.sharded_ddp is not None
        and self.sharded_ddp != ShardedDDPOption.SIMPLE
        or is_sagemaker_mp_enabled()
        or self.fsdp is not None
    )  # False

7. Creating the optimizer and the learning-rate scheduler

This step is important: self.create_optimizer_and_scheduler builds the optimizer and the learning-rate scheduler. These are exactly the objects the resume path later restores, so it pays to know where they come from.

Note: the finer details are covered in the next post.

a. Where they are created

Setting aside the scheduler-reset and DeepSpeed branches (unused in this demo), the call goes straight to create_optimizer_and_scheduler:

    # We need to reset the scheduler, as its parameters may be different on subsequent calls
    if self._created_lr_scheduler:  # False
        self.lr_scheduler = None
        self._created_lr_scheduler = False

    if self.is_deepspeed_enabled:  # used with DeepSpeed; False here
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    # the path taken in this demo
    if not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

b. The create_optimizer_and_scheduler source

The docstring says it all: this sets up the optimizer and learning-rate scheduler with reasonable defaults that work well. To use something else, either pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or `create_scheduler`):

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.
    
        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
     		
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

Plainly, the function calls self.create_optimizer() to build and assign self.optimizer, then passes that same optimizer into create_scheduler (unwrapping it first under SageMaker MP with fp16).

create_optimizer and create_scheduler(num_training_steps, optimizer) themselves are dissected in the next post; a usage sketch of the override route follows.
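If you want to bypass these defaults entirely, the docstring's first suggestion looks roughly like this (a sketch assuming the model and training_args from the demo; the warmup and step counts are illustrative):

    import torch
    from transformers import get_linear_schedule_with_warmup

    # Build your own optimizer/scheduler pair
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=750)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        optimizers=(optimizer, scheduler),  # skips the default creation path above
    )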

c. Optimizer construction details

The optimizer built by self.create_optimizer() takes its learning rate and related hyperparameters from args; the relevant fragment is:

    optimizer_kwargs = {"lr": args.learning_rate}  # the learning rate comes from args; important

    adam_kwargs = {
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_epsilon,
    }  # betas and eps also come from args

So make sure the learning rate and related values are configured in args before they get passed along.

d. Scheduler construction details

The scheduler is created by:

    self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

And that function is:

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.
    
        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
            self._created_lr_scheduler = True
        return self.lr_scheduler

Internally, get_scheduler returns a LambdaLR instance (as the original post's screenshot shows), which manages the learning-rate decay: it is initialized with the optimizer, the decay function, and a starting epoch, and the result is stored in self.lr_scheduler for the training loop to step. A standalone sketch follows.
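You can also call get_scheduler on its own to see what comes back (a sketch with illustrative numbers; "linear" plays the role of args.lr_scheduler_type):

    import torch
    from transformers import get_scheduler

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=750,
    )
    print(type(lr_scheduler))  # a torch.optim.lr_scheduler.LambdaLR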

8. self.state = TrainerState()

TrainerState is a record of the training state.

a. The call site

self.state is instantiated as follows:

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

b. The TrainerState source

Here is the TrainerState class source:

    @dataclass
    class TrainerState:
        """
        <Tip>

        In this class, one step is understood as one update step. When using gradient accumulation, one update step
        may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
        step requires going through *n* batches.

        </Tip>

        Args:
            epoch (`float`, *optional*):
                Only set during training, represents the epoch the training is at (the decimal part being the
                percentage of the current epoch completed).
            global_step (`int`, *optional*, defaults to 0):
                During training, represents the number of update steps completed.
            max_steps (`int`, *optional*, defaults to 0):
                The number of update steps to do during the current training.
            total_flos (`float`, *optional*, defaults to 0):
                The total number of floating operations done by the model since the beginning of training (stored as a
                float to avoid overflow).
            log_history (`List[Dict[str, float]]`, *optional*):
                The list of logs done since the beginning of training.
            best_metric (`float`, *optional*):
                When tracking the best model, the value of the best metric encountered so far.
            best_model_checkpoint (`str`, *optional*):
                When tracking the best model, the name of the checkpoint of the best model encountered so far.
            is_local_process_zero (`bool`, *optional*, defaults to `True`):
                Whether this process is the local main process (e.g. on one machine when training in a distributed
                fashion on several machines).
            is_world_process_zero (`bool`, *optional*, defaults to `True`):
                Whether this process is the global main process (when training in a distributed fashion on several
                machines, this is only going to be `True` for one process).
            is_hyper_param_search (`bool`, *optional*, defaults to `False`):
                Whether we are in the process of a hyperparameter search using Trainer.hyperparameter_search. This
                affects the way data is logged in TensorBoard.
        """

        epoch: Optional[float] = None
        global_step: int = 0
        max_steps: int = 0
        num_train_epochs: int = 0
        total_flos: float = 0
        log_history: List[Dict[str, float]] = None
        best_metric: Optional[float] = None
        best_model_checkpoint: Optional[str] = None
        is_local_process_zero: bool = True
        is_world_process_zero: bool = True
        is_hyper_param_search: bool = False
        trial_name: str = None
        trial_params: Dict[str, Union[str, float, int, bool]] = None

        def __post_init__(self):
            if self.log_history is None:
                self.log_history = []

        def save_to_json(self, json_path: str):
            """Save the content of this instance in JSON format inside `json_path`."""
            json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
            with open(json_path, "w", encoding="utf-8") as f:
                f.write(json_string)

        @classmethod
        def load_from_json(cls, json_path: str):
            """Create an instance from the content of `json_path`."""
            with open(json_path, "r", encoding="utf-8") as f:
                text = f.read()
            return cls(**json.loads(text))
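Since save_to_json/load_from_json are how the trainer_state.json inside a checkpoint directory is written and read back, here is a quick round-trip sketch (the file name matches TRAINER_STATE_NAME; the values are illustrative):

    from transformers import TrainerState

    state = TrainerState()
    state.global_step = 500
    state.save_to_json("trainer_state.json")

    restored = TrainerState.load_from_json("trainer_state.json")
    print(restored.global_step)  # 500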

9. gradient_checkpointing

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing: 
            self.model.gradient_checkpointing_enable()

10. The self._wrap_model function

This function wraps the model; it is called as:

    model = self._wrap_model(self.model_wrapped)

Part of its body (below) shows how it returns the model differently depending on the environment:

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
        ...

11. Resume: weight reload, model acceleration, optimizer/scheduler loading

a. Full code

Right after the model is wrapped comes the decision of whether to load checkpoint weights and how to prepare the wrapped model; the full code is:

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)
    
        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False
    
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    
        # prepare using `accelerator` prepare
        if use_accelerator_prepare: 
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else: 
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
    
        if self.is_fsdp_enabled:
            self.model = model
    
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model: 
            self.model_wrapped = model
    
        # backward compatibility
        if self.is_deepspeed_enabled: 
            self.deepspeed = self.model_wrapped
    
        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
    
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

b. Weight reloading

Weights are reloaded in two places, under different conditions. The first:

    if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
        self._load_from_checkpoint(resume_from_checkpoint, model)

The second:

    # deepspeed ckpt loading
    if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
        deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

c. Model acceleration

Whether the model goes through accelerate's prepare:

    # prepare using `accelerator` prepare
    if use_accelerator_prepare:
        self.model.train()
        if hasattr(self.lr_scheduler, "step"):
            if self.use_apex:
                model = self.accelerator.prepare(self.model)
            else:
                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        else:
            # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )

d. Optimizer state reloading

Optimizer and scheduler state is reloaded by:

       # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

I won't unpack these further here; they return in the resume post.

12. Reloading training state

This block restores the training bookkeeping from resume_from_checkpoint. TRAINER_STATE_NAME is trainer_state.json, which must exist in the checkpoint directory.

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    steps_trained_progress_bar = None

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        epochs_trained = self.state.global_step // num_update_steps_per_epoch
        if not args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f"  Continuing training from epoch {epochs_trained}")
        logger.info(f"  Continuing training from global step {self.state.global_step}")
        if not args.ignore_data_skip:
            logger.info(
                f"  Will skip the first {epochs_trained} epochs then the first"
                f" {steps_trained_in_current_epoch} batches in the first epoch."
            )

13. Pre-training updates and preparation

Broadly, this runs just before training to update references and prepare state; note in particular how self.control is set.

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else: # finetune-lora
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
    
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
    
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
    
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                for _ in train_dataloader:
                    break

14. The training loop

Next is (part of) the training loop itself; the other pieces it touches are explained in the following posts.

    total_batched_samples = 0
    for epoch in range(epochs_trained, num_train_epochs):  # total epochs to train; our argument
        epoch_iterator = train_dataloader

        # Reset the past mems state at the beginning of each epoch if necessary.
        if args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        ...
        for step, inputs in enumerate(epoch_iterator):
            total_batched_samples += 1
            if rng_to_sync:
                self._load_rng_state(resume_from_checkpoint)
                rng_to_sync = False

15. Post-training metrics and logging

After training finishes, metrics and logs are finalized; the source is:

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")
        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sur the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()
    
            self._load_best_model()
    
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step
    
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss
    
        self.is_in_train = False
    
        self._memory_tracker.stop_and_update_metrics(metrics)
    
        self.log(metrics)

16. The _sorted_checkpoints function

a. Call site

    run_dir = self._get_output_dir(trial)  # get the output path
    checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

b. The _sorted_checkpoints source

This function enumerates and sorts the saved checkpoint directories. It really only matters when best_model_checkpoint is set, since it rotates the best checkpoint toward the end of the list so it is never deleted; otherwise its effect is negligible.

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

17. Deleting surplus checkpoints

When self.args.save_total_limit is set, checkpoints beyond the limit are removed at this stage:

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)
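As a usage sketch, here is how the relevant arguments might be combined so only the best checkpoint survives (the values are illustrative; load_best_model_at_end requires matching eval and save strategies):

    training_args = TrainingArguments(
        output_dir="./out_dirs",
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=100,
        save_steps=100,
        save_total_limit=1,              # prune everything beyond the limit
        load_best_model_at_end=True,     # keeps best_model_checkpoint safe from deletion
        metric_for_best_model="accuracy",
    )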

18. self.callback_handler.on_train_end

This call writes the end-of-training state back into self.control:

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)

I will come back to what gets stored there in the training-loop post.

19. Return value

Finally, the computed self.state.global_step, train_loss, and metrics are returned wrapped in a TrainOutput, which stores the key results of the run:

    return TrainOutput(self.state.global_step, train_loss, metrics)

The TrainOutput class source:

    class TrainOutput(NamedTuple):
        global_step: int
        training_loss: float
        metrics: Dict[str, float]
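So on the caller's side you can read the results directly, e.g.:

    train_output = trainer.train()
    print(train_output.global_step)    # total update steps performed
    print(train_output.training_loss)  # average training loss
    print(train_output.metrics)        # speed metrics, total_flos, train_loss, ...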
