
Part 15: Reading the Source of Hugging Face Trainer's _inner_training_loop (epoch loop)


Table of Contents

  • Preface

  • I. Full Source Code

    • 1. Source code before the training loop
    • 2. Training source code (epoch loop)

  • II. Reading the Epoch Training Loop Source

    • 1. Variables before the epoch loop
    • 2. Starting the epochs loop
    • 3. Training data
    • 4. Getting the number of steps per epoch
    • 5. self.callback_handler.on_epoch_begin
    • 6. Loading the RNG state (self._load_rng_state)
    • 7. Skipping already-trained batches with skip_first_batches (used on resume)
    • 8. The training step loop
    • 9. Updating the epoch-loop exit condition (self.control.should_training_stop)
    • 10. The self.callback_handler.on_epoch_end method
      • a. The call site
      • b. The return value
      • c. Notes on the result
    • 11. Evaluation and checkpoint saving (self._maybe_log_save_evaluate)
      • a. The call site
      • b. Reading the function source
      • c. Reading the self._save_checkpoint source
      • d. Key points about self._save_checkpoint
    • 12. Stopping the outer epoch loop

Preface

The Trainer component in Hugging Face covers a great deal of ground, and explaining it properly takes more than one article, so I decided to split the material into a series. The first article introduces TrainingArguments and the basic Trainer parameters; the second uses a complete example to walk through the overall structure of Trainer (train and _inner_training_loop); the third digs into the data preparation, optimizer setup, and other core code inside _inner_training_loop; the fourth focuses on the outer loop of the training phase in _inner_training_loop; the fifth covers the inner step loop in detail; and the sixth explains resume, i.e. how data, optimizer, and model state are inherited so training can continue from a checkpoint. This article is the fourth in the series and walks through the outer epoch loop of the training phase inside Hugging Face Trainer's _inner_training_loop.

The other articles in this series on the Hugging Face Trainer:

[Article 1]: discusses ... in detail (published on the blog, January 1, 2024)
[Article 2]: analyzes ... in depth (published on the blog, January 2, 2024)
[Article 3]: provides the concrete ... (published on the blog, January 3, 2024)
[Article 4]: further discusses ... (published on the blog, January 4, 2024)
[Article 5]: provides a detailed ... (published on the blog, January 5, 2024)
[Article 6]: further explores ... (published on the blog, January 6, 2024)

I. Full Source Code

This part presents the setup code that runs before the training loop, followed by the full training-loop source.

1. Source code before the training loop

My first plan was to jump straight into the training code. After some thought, however, I realized that to understand how the whole thing works it is worth first going over the preparation work: the trainer state, the training-related variables, and so on are all assigned or updated before the loop starts. So I have collected that part of the code here for readers.

     
        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None
    
        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0
    
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )
    
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else: # finetune-lora
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
    
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
    
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
    
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                for _ in train_dataloader:
                    break
    
        total_batched_samples = 0

2. Training source code (epoch loop)

This part shows the training-loop source itself (the epoch loop).

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):  # total number of training epochs, from our args
            epoch_iterator = train_dataloader
    
            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None
    
            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
    
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)
    
            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True
    
            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                
                ...
    
    
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
    
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
    
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

II. Reading the Epoch Training Loop Source

1. Variables before the epoch loop

Here I briefly introduce the variables that the epoch loop relies on.

        len_dataloader = None
        if has_length(train_dataloader): 
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)  # get the number of training examples
     		
     		...
     
        self.state.epoch = 0  # default epoch is 0
        start_time = time.time()  # start time
        epochs_trained = 0  # starting epoch for `for epoch in range(epochs_trained, num_train_epochs)`; updated below when resuming
        steps_trained_in_current_epoch = 0  # number of steps already trained in the current epoch
        steps_trained_progress_bar = None
        # Check if continuing training from a checkpoint; on resume these values are refreshed from resume_from_checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))  # reload the state from trainer_state.json in the checkpoint folder
            epochs_trained = self.state.global_step // num_update_steps_per_epoch  # total recorded steps divided by steps per epoch gives the epochs already trained
            if not args.ignore_data_skip:  # keep the step position inside the interrupted epoch unless ignore_data_skip is set
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)  # remainder gives the steps already done in the current epoch
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps  # scale by gradient accumulation to get batches
            else:
                steps_trained_in_current_epoch = 0  # skip nothing; the epoch restarts from the beginning
            
            ...
    
        # Update the references handed to the callback handler
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
    		...
    		
    
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)  # initialize the running loss
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
    
    
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):  # skip the first epochs_trained epochs
                for _ in train_dataloader:
                    break

To make these variables concrete, assume the following setup: 48 samples, a batch size of 3, a single GPU, 2 training epochs, and a run that was interrupted half-way through an epoch (at 0.5 epoch).

    len_dataloader: the number of batches in the dataloader; with the assumptions above this is 16;
    self.state.global_step: the total number of update steps recorded so far; on resume it comes from the checkpoint's trainer_state.json;
    num_update_steps_per_epoch: the number of update steps per epoch; here 16 = len_dataloader // args.gradient_accumulation_steps;
    args.ignore_data_skip: decides what happens to the unfinished epoch on resume; False (the default) means the already-trained batches are skipped over, True means they are ignored and the epoch restarts from the beginning;
    steps_trained_in_current_epoch: the number of batches already trained inside the current epoch; set together with args.ignore_data_skip as commented in the code above;
    total_batched_samples: defaults to 0;
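
To make the resume arithmetic concrete, here is a minimal sketch using the assumed numbers above (48 samples, batch size 3, gradient accumulation of 1, interruption at 0.5 epoch, i.e. global_step = 8); all values are made up for illustration:

    # Hypothetical numbers matching the running example, not taken from a real run.
    num_examples = 48
    per_device_batch_size = 3
    gradient_accumulation_steps = 1

    len_dataloader = num_examples // per_device_batch_size  # 16 batches per epoch
    num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 16

    global_step = 8  # training was interrupted half-way through the first epoch

    epochs_trained = global_step // num_update_steps_per_epoch  # 0 full epochs done
    steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 8 update steps
    steps_trained_in_current_epoch *= gradient_accumulation_steps  # 8 batches to skip on resume

    print(epochs_trained, steps_trained_in_current_epoch)  # 0 8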

2. Starting the epochs loop

The epoch loop itself is entered with:

    for epoch in range(epochs_trained, num_train_epochs):  # total number of training epochs, from our args

As you can see, the loop range runs from epochs_trained to num_train_epochs, where num_train_epochs comes from the args we passed in. epochs_trained was covered above: it defaults to 0 and is updated when resuming from a checkpoint.

3. Training data

At the start of each epoch the training data is simply re-assigned: epoch_iterator is set to train_dataloader, an object that was covered in detail in the previous article. The next task is to work out how many steps each epoch needs and store that value in steps_in_epoch; this is straightforward and needs no further elaboration.

            epoch_iterator = train_dataloader

4. Getting the number of steps per epoch

This computes the number of steps per epoch, which is normally just the length of the dataloader; for a dataloader without a length it falls back to args.max_steps * args.gradient_accumulation_steps.

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None
    
            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
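
As a quick sanity check (a made-up example matching the running numbers, not part of the Trainer source), a sized dataloader's len() is exactly what steps_in_epoch picks up:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical sized dataloader: 48 samples, batch size 3.
    loader = DataLoader(TensorDataset(torch.arange(48)), batch_size=3)
    print(len(loader))  # 16 -> what steps_in_epoch becomes for a sized dataloader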

5. self.callback_handler.on_epoch_begin

This invokes the on_epoch_begin callback; if a callback changes the control object, the change is kept.

The call site

    self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

(Figure: the arguments passed in — args, self.state, and self.control; the screenshot of args shows only part of the TrainingArguments.)

The on_epoch_begin function

In this code block, control.should_epoch_stop is first reset to False, and then self.call_event() dispatches the event named "on_epoch_begin" to the registered callbacks.

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_epoch_stop = False
        return self.call_event("on_epoch_begin", args, state, control)

call_event iterates over the callbacks, calls on_epoch_begin on each one, and if a callback returns a non-None result it replaces the current control object.

    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
            result = getattr(callback, event)(
                args,
                state,
                control,
                model=self.model,
                tokenizer=self.tokenizer,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                **kwargs,
            )
            # A Callback can skip the return of `control` if it doesn't change it.
            if result is not None:
                control = result
        return control
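
For illustration (not from the original post), a minimal custom callback hooking on_epoch_begin might look like the sketch below; it relies only on the standard transformers.TrainerCallback interface:

    from transformers import TrainerCallback

    class LogEpochStartCallback(TrainerCallback):
        """Hypothetical callback: logs the epoch start and may adjust the control flags."""

        def on_epoch_begin(self, args, state, control, **kwargs):
            print(f"starting epoch {state.epoch}, global step {state.global_step}")
            # Returning the (possibly modified) control object replaces the current one;
            # returning None keeps it unchanged.
            return control

    # trainer = Trainer(..., callbacks=[LogEpochStartCallback()])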

6. Loading the RNG state (self._load_rng_state)

self._load_rng_state is called when all of the following hold: we are in the first epoch after resuming (epoch == epochs_trained), a checkpoint path was given, and there are no batches left to skip inside the epoch.

    if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
        self._load_rng_state(resume_from_checkpoint)
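
As a hedged sketch of what that call boils down to (the real Trainer._load_rng_state also handles XLA and per-process rng_state_{rank}.pth files), restoring the states saved in rng_state.pth looks roughly like this:

    import os
    import random

    import numpy as np
    import torch

    # Hypothetical checkpoint path for illustration; single-process case only.
    checkpoint_dir = "output/checkpoint-48"
    rng_states = torch.load(os.path.join(checkpoint_dir, "rng_state.pth"))

    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    torch.random.set_rng_state(rng_states["cpu"])
    if torch.cuda.is_available() and "cuda" in rng_states:
        # Distributed checkpoints store a list from get_rng_state_all() instead.
        torch.cuda.random.set_rng_state(rng_states["cuda"])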

7. Skipping already-trained batches with skip_first_batches (used on resume)

The call site

When resuming, the batches that were already trained in the interrupted epoch have to be skipped. steps_trained_in_current_epoch holds the number of batches to skip; after wrapping the dataloader with skip_first_batches it is reset to zero and rng_to_sync is set so the RNG state is restored on the first step. On a fresh run, or when ignore_data_skip is set, the value is already zero and nothing is skipped.

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

Reading the skip_first_batches source

    def skip_first_batches(dataloader, num_batches=0):
        """
        Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
        """
        dataset = dataloader.dataset
        sampler_is_batch_sampler = False
        if isinstance(dataset, IterableDataset):
            new_batch_sampler = None
        else:
            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
            new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

        # We ignore all of those since they are all dealt with by our new_batch_sampler
        ignore_kwargs = [
            "batch_size",
            "shuffle",
            "sampler",
            "batch_sampler",
            "drop_last",
        ]

        kwargs = {
            k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
            for k in _PYTORCH_DATALOADER_KWARGS
            if k not in ignore_kwargs
        }

        # Need to provide batch_size as batch_sampler is None for Iterable dataset
        if new_batch_sampler is None:
            kwargs["drop_last"] = dataloader.drop_last
            kwargs["batch_size"] = dataloader.batch_size

        if isinstance(dataloader, DataLoaderDispatcher):
            if new_batch_sampler is None:
                # Need to manually skip batches in the dataloader
                kwargs["skip_batches"] = num_batches
            dataloader = DataLoaderDispatcher(
                dataset,
                split_batches=dataloader.split_batches,
                batch_sampler=new_batch_sampler,
                _drop_last=dataloader._drop_last,
                **kwargs,
            )
        elif isinstance(dataloader, DataLoaderShard):
            if new_batch_sampler is None:
                # Need to manually skip batches in the dataloader
                kwargs["skip_batches"] = num_batches
            elif sampler_is_batch_sampler:
                kwargs["sampler"] = new_batch_sampler
                kwargs["batch_size"] = dataloader.batch_size
            else:
                kwargs["batch_sampler"] = new_batch_sampler
            dataloader = DataLoaderShard(
                dataset,
                device=dataloader.device,
                rng_types=dataloader.rng_types,
                synchronized_generator=dataloader.synchronized_generator,
                **kwargs,
            )
        else:
            if new_batch_sampler is None:
                # Need to manually skip batches in the dataloader
                dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
            else:
                dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

        return dataloader

This is where the data state is inherited on resume.
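
To see the effect in isolation, here is a small usage sketch (assuming a recent accelerate release that exports skip_first_batches at the top level; the dataset and numbers are made up):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accelerate import skip_first_batches  # assumption: exported in recent accelerate versions

    dataset = TensorDataset(torch.arange(48).float())
    loader = DataLoader(dataset, batch_size=3, shuffle=False)

    # Pretend 8 batches were already consumed before the interruption.
    resumed_loader = skip_first_batches(loader, num_batches=8)

    first = next(iter(resumed_loader))
    print(first)  # [tensor([24., 25., 26.])] -> iteration resumes at the 9th batch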

8. The training step loop

Next comes the loop over the dataloader. At the start of each epoch, step is initialized to -1, and the loop then iterates over the batches much like any standard training loop. The body of this inner step loop is involved enough that it gets its own article later in the series.

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

9. Updating the epoch-loop exit condition (self.control.should_training_stop)

step is set to -1 before the dataloader loop and is updated by enumerate as the dataloader is iterated. If step is still negative after the loop, the epoch produced no batches at all: a warning is logged and self.control.should_training_stop is set to True so that the outer epoch loop exits early. This typically happens with an IterableDataset when max_steps is set higher than the number of available samples.

    step = -1
    ...
    if step < 0:
        logger.warning(
            "There seems to be not a single sample in your epoch_iterator, stopping training at step"
            f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
            f" num_steps ({max_steps}) higher than the number of available samples."
        )
        self.control.should_training_stop = True
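
The reason the `step < 0` check works is simply that enumerate over an empty iterator never runs the loop body, so step keeps its sentinel value; a tiny illustration (not from the Trainer source):

    step = -1
    for step, inputs in enumerate([]):  # an empty iterator, standing in for an exhausted dataloader
        pass
    print(step)  # -1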

10. The self.callback_handler.on_epoch_end method

a. The call site

This is the post-processing at the end of each epoch. My reading is that the resulting flags and information end up in self.control. The call looks like this:

    self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)

b. The return value

The reason for that conclusion is that the call eventually runs through call_event, shown below:

    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
            result = getattr(callback, event)(
                args,
                state,
                control,
                model=self.model,
                tokenizer=self.tokenizer,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                **kwargs,
            )
            # A Callback can skip the return of `control` if it doesn't change it.
            if result is not None:
                control = result
        return control

Whenever a callback returns a non-None result, the returned control replaces the current one, and each callback receives all the relevant objects (model, tokenizer, optimizer, lr_scheduler, dataloaders) as keyword arguments, as shown above.

(Figure: the keyword arguments received by the callbacks at on_epoch_end, including the current learning rate.)

c. Notes on the result

You can see that the callbacks receive our customized objects: the tokenizer, the optimizer, the learning-rate scheduler, the training dataloader, and so on. The learning rate was configured as 2e-4, but the value shown in the figure above is 1e-4, which is clearly the result of the lr_scheduler's adjustment. The trainer state information passed along here is also worth paying attention to.
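
A minimal, made-up illustration of that adjustment (this is not the scheduler the Trainer actually builds, just a toy LinearLR schedule showing how a configured 2e-4 can become roughly 1e-4 part-way through training):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=2e-4)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=32
    )

    for _ in range(16):  # half-way through the toy schedule
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())  # approximately [1e-4]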

11. Evaluation and checkpoint saving (self._maybe_log_save_evaluate)

a. The call site

    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

b. Reading the function source

Stepping into the function, its (abridged) source looks like this:

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
           ...
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()
    			...
            self.log(logs)
        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
    				...
            self._report_to_hp_search(trial, self.state.global_step, metrics)
    
            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])
    
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

As the source above shows, what gets saved depends on the control flags. In particular, self.control.should_save decides whether the heavier artifacts are written out, most notably the weights and the related training state.
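
Those flags are normally driven by the strategies in TrainingArguments (the DefaultFlowCallback sets should_log / should_evaluate / should_save at the end of each epoch when the corresponding strategy is "epoch"). A hedged configuration sketch, with made-up values:

    from transformers import TrainingArguments

    # Illustrative values only; with "epoch" strategies, _maybe_log_save_evaluate logs,
    # evaluates, and saves a checkpoint once per epoch.
    args = TrainingArguments(
        output_dir="output",
        num_train_epochs=2,
        per_device_train_batch_size=3,
        logging_strategy="epoch",
        evaluation_strategy="epoch",
        save_strategy="epoch",
    )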

c. Reading the self._save_checkpoint source

This function saves the weights and training state, writing files such as optimizer.pt, scheduler.pt, trainer_state.json, and rng_state.pth into the checkpoint folder.

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" # folder name derived from global_step, e.g. checkpoint-48
        if self.hp_search_backend is None and trial is None:
            self.store_flos()
    
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)
    
        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()
    
        if self.fsdp or self.is_fsdp_enabled:
            if self.is_fsdp_enabled:
                save_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                )
            else:
                # FSDP has a different interface for saving optimizer states.
                # Needs to be called on all ranks to gather all states.
                # full_optim_state_dict will be deprecated after Pytorch 2.2!
                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
    
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) # OPTIMIZER_NAME=optimizer.pt
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) # SCHEDULER_NAME=scheduler.pt
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),   # OPTIMIZER_NAME=optimizer.pt
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp and not self.is_fsdp_enabled:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
    
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
    
        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]
    
            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir
    
        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
    
        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()
    
        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()
    
        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)
    
        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
    
        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)
    
        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

d. Key points about self._save_checkpoint

Note that what gets saved is the state_dict() of each object, as these lines show:

     torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
     torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
     torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
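
These state_dicts are what the Trainer reads back on resume (via _load_optimizer_and_scheduler). A hedged, stand-alone sketch of restoring them outside the Trainer, with a made-up model and checkpoint path:

    import os

    import torch

    checkpoint_dir = "output/checkpoint-48"  # hypothetical checkpoint folder

    # Stand-ins for the real model/optimizer/scheduler, rebuilt the same way as before the interruption.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))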

12. Stopping the outer epoch loop

The stop is triggered by self.control.should_training_stop, which was set earlier (by the empty-iterator check above or by a callback); when it is True, the outer epoch loop breaks.

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break
