Section 15: Reading the source code of HuggingFace Trainer's _inner_training_loop (the epoch loop)
Table of Contents
- Preface
- I. Full source code
  - 1. Source code before the training loop
  - 2. Training source code (the epoch loop)
- II. Walking through the epoch training loop
  - 1. Variables set before the epoch loop
  - 2. Starting the epochs loop
  - 3. Training data
  - 4. Getting the number of steps per epoch
  - 5. self.callback_handler.on_epoch_begin
  - 6. Loading the RNG state (self._load_rng_state)
  - 7. Skipping the already-trained batches of the current epoch (skip_first_batches, used on resume)
    - Where it is called
    - skip_first_batches source code
  - 8. The step loop
  - 9. Updating the exit condition of the epoch loop (self.control.should_training_stop)
  - 10. self.callback_handler.on_epoch_end
    - a. How it is called
    - b. Return value
    - c. Notes on the result
  - 11. Evaluation and checkpoint saving (self._maybe_log_save_evaluate)
    - a. How it is called
    - b. Source code walkthrough
    - c. self._save_checkpoint source code
    - d. Key points of self._save_checkpoint
  - 12. Stopping the outer epoch loop
Preface
The Trainer in HuggingFace covers a lot of ground, and explaining it thoroughly takes more than one article, so I decided to split the material into a series. The first article introduces TrainingArguments and the Trainer's parameters; the second uses a complete example to show the overall structure of the Trainer (train and _inner_training_loop); the third digs into the data preparation, optimizer setup, and other core code at the top of _inner_training_loop; the fourth focuses on the outer (epoch-level) loop of the training phase in _inner_training_loop; the fifth goes into the details of the inner step loop; and the sixth explains resume, i.e. how inheriting the data position, the optimizer, and the model lets training continue from a checkpoint. This article is the fourth of the series, so it concentrates on the outer epoch training loop of the Trainer and the code around it.
The articles in this series:
[Article 1]: ... (published on the blog, Jan 1, 2024)
[Article 2]: ... (published on the blog, Jan 2, 2024)
[Article 3]: ... (published on the blog, Jan 3, 2024)
[Article 4]: ... (published on the blog, Jan 4, 2024)
[Article 5]: ... (published on the blog, Jan 5, 2024)
[Article 6]: ... (published on the blog, Jan 6, 2024)
I. Full source code
This part shows the source code that prepares the training state, followed by the full source of the epoch loop.
1. Source code before the training loop
Initially I planned to jump straight into the training code. On reflection, though, understanding the loop requires knowing what has been prepared beforehand: the trainer state, the training-related variables, and so on are all assigned or updated during this setup phase. So I have collected that setup code here first and share it with the reader.
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None

# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
    os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
    self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
    epochs_trained = self.state.global_step // num_update_steps_per_epoch
    if not args.ignore_data_skip:
        steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
        steps_trained_in_current_epoch *= args.gradient_accumulation_steps
    else:
        steps_trained_in_current_epoch = 0

    logger.info(" Continuing training from checkpoint, will skip to saved global_step")
    logger.info(f" Continuing training from epoch {epochs_trained}")
    logger.info(f" Continuing training from global step {self.state.global_step}")
    if not args.ignore_data_skip:
        logger.info(
            f" Will skip the first {epochs_trained} epochs then the first"
            f" {steps_trained_in_current_epoch} batches in the first epoch."
        )

# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
if self.hp_name is not None and self._trial is not None:
    # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
    # parameter to Train when using DDP.
    self.state.trial_name = self.hp_name(self._trial)
if trial is not None:
    assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
    self.state.trial_params = hp_params(assignments)
else:  # finetune-lora
    self.state.trial_params = None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()

# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
    for epoch in range(epochs_trained):
        for _ in train_dataloader:
            break

total_batched_samples = 0
2. Training source code (the epoch loop)
This part shows the source code of the epoch loop itself.
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):  # total number of epochs, from the arguments we pass in
    epoch_iterator = train_dataloader

    # Reset the past mems state at the beginning of each epoch if necessary.
    if args.past_index >= 0:
        self._past = None

    steps_in_epoch = (
        len(epoch_iterator)
        if len_dataloader is not None
        else args.max_steps * args.gradient_accumulation_steps
    )
    self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

    if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
        self._load_rng_state(resume_from_checkpoint)

    rng_to_sync = False
    steps_skipped = 0
    if steps_trained_in_current_epoch > 0:
        epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
        steps_skipped = steps_trained_in_current_epoch
        steps_trained_in_current_epoch = 0
        rng_to_sync = True

    step = -1
    for step, inputs in enumerate(epoch_iterator):
        total_batched_samples += 1
        ...

    if step < 0:
        logger.warning(
            "There seems to be not a single sample in your epoch_iterator, stopping training at step"
            f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
            f" num_steps ({max_steps}) higher than the number of available samples."
        )
        self.control.should_training_stop = True

    self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

    if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
        if is_torch_tpu_available():
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        else:
            logger.warning(
                "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                "configured. Check your training configuration if this is unexpected."
            )

    if self.control.should_training_stop:
        break
II. Walking through the epoch training loop
1. Variables set before the epoch loop
Here is a brief look at the variables the epoch loop relies on.
len_dataloader = None
if has_length(train_dataloader):
    len_dataloader = len(train_dataloader)
    num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
    num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
    num_examples = self.num_examples(train_dataloader)  # number of training samples
...
self.state.epoch = 0  # default epoch value of 0
start_time = time.time()  # start time
epochs_trained = 0  # starting epoch for `for epoch in range(epochs_trained, num_train_epochs)`; updated below when resuming
steps_trained_in_current_epoch = 0  # number of steps already trained inside the current epoch
steps_trained_progress_bar = None

# Check if continuing training from a checkpoint; when resuming, these values are re-assigned from resume_from_checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
    os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
    self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))  # reload the state from trainer_state.json under that path
    epochs_trained = self.state.global_step // num_update_steps_per_epoch  # total recorded steps // steps per epoch = epochs already trained
    if not args.ignore_data_skip:  # if ignore_data_skip is False, inherit the position inside the epoch; otherwise restart the epoch's data
        steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)  # remainder = steps already done in the current epoch
        steps_trained_in_current_epoch *= args.gradient_accumulation_steps  # convert update steps into batches
    else:
        steps_trained_in_current_epoch = 0  # nothing to skip; every epoch restarts from the beginning
...
# Update the references handed to the callbacks
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
...
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)  # default loss
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0

# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
    for epoch in range(epochs_trained):  # skip over the epochs that were already trained
        for _ in train_dataloader:
            break
To make these variables concrete, assume the following setup: 48 training samples, a batch size of 3, a single GPU, gradient accumulation of 1, 2 training epochs in total, and training interrupted half-way through an epoch. The variables then mean the following (a worked example with these numbers comes right after the list):
len_dataloader: the number of batches in the train dataloader; 16 under the assumption above;
self.state.global_step: the total number of training steps taken so far; on resume it is read from the checkpoint's trainer_state.json;
num_update_steps_per_epoch: the number of update steps per epoch; here 16 = len_dataloader // args.gradient_accumulation_steps;
args.ignore_data_skip: decides whether the batches already consumed in the interrupted epoch are skipped on resume; False (the default) means they are skipped so training continues with the remaining data, True means the skipping is ignored and the epoch's data starts over;
steps_trained_in_current_epoch: the number of batches already trained in the current (interrupted) epoch; set together with args.ignore_data_skip, as annotated in the code above;
total_batched_samples: initialized to 0.
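A quick sanity check of this bookkeeping with the numbers above. The interrupted global_step value (8, i.e. half an epoch) is my own assumption for illustration; on a real resume it comes from trainer_state.json:

# Worked example of the resume bookkeeping, using the assumed numbers above.
num_examples = 48
batch_size = 3
gradient_accumulation_steps = 1

len_dataloader = num_examples // batch_size  # 16 batches per epoch
num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 16

global_step = 8  # assumed: training was interrupted half-way through the first epoch
epochs_trained = global_step // num_update_steps_per_epoch  # 8 // 16 = 0 full epochs finished
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 8 steps into epoch 0
steps_trained_in_current_epoch *= gradient_accumulation_steps  # 8 batches to skip on resume

print(len_dataloader, num_update_steps_per_epoch, epochs_trained, steps_trained_in_current_epoch)
# -> 16 16 0 8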
2. Starting the epochs loop
The epoch loop itself is driven by:
for epoch in range(epochs_trained, num_train_epochs):  # total number of epochs, from the arguments we pass in
As the line shows, the loop runs from epochs_trained up to num_train_epochs, and num_train_epochs comes from the args we pass in. epochs_trained defaults to 0, as described above, and is updated accordingly when resuming from a checkpoint.
3. Training data
At the start of every epoch, the training data is simply the train_dataloader covered in the previous article, assigned to epoch_iterator. The next task is to work out how many steps the epoch needs and store that number in steps_in_epoch; both are straightforward.
epoch_iterator = train_dataloader
4. Getting the number of steps per epoch
This computes how many steps each epoch takes, which is normally just the length of the dataloader.
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
    self._past = None
steps_in_epoch = (
    len(epoch_iterator)
    if len_dataloader is not None
    else args.max_steps * args.gradient_accumulation_steps
)
5. self.callback_handler.on_epoch_begin
The on_epoch_begin callback is fired here; if any callback modifies the control object, the returned value replaces it.
How it is called
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
[Screenshot in the original post: the args, self.state, and self.control arguments passed in; only part of args is visible.]
The on_epoch_begin function
Inside this function, control.should_epoch_stop is reset to False, and then self.call_event is invoked with the event name "on_epoch_begin" to dispatch the event to every registered callback.
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
    control.should_epoch_stop = False
    return self.call_event("on_epoch_begin", args, state, control)
call_event then calls the on_epoch_begin method of each registered callback; whenever a callback returns a modified control, that value is kept.
def call_event(self, event, args, state, control, **kwargs):
    for callback in self.callbacks:
        result = getattr(callback, event)(
            args,
            state,
            control,
            model=self.model,
            tokenizer=self.tokenizer,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            train_dataloader=self.train_dataloader,
            eval_dataloader=self.eval_dataloader,
            **kwargs,
        )
        # A Callback can skip the return of `control` if it doesn't change it.
        if result is not None:
            control = result
    return control
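To make the dispatch concrete, here is a minimal custom-callback sketch of my own (not part of the Trainer source); passed to the Trainer via its callbacks argument, its on_epoch_begin method is exactly what call_event invokes at the top of every epoch:

from transformers import TrainerCallback

class EpochLoggerCallback(TrainerCallback):
    """Minimal example: react to on_epoch_begin and optionally modify `control`."""

    def on_epoch_begin(self, args, state, control, **kwargs):
        print(f"starting epoch {state.epoch}, global_step={state.global_step}")
        # Returning None keeps the current `control`; returning a modified
        # TrainerControl would replace it inside call_event.
        return None

# Usage sketch (model/args/dataset construction omitted):
# trainer = Trainer(model=model, args=training_args, train_dataset=train_ds,
#                   callbacks=[EpochLoggerCallback()])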
6. Loading the RNG state (self._load_rng_state)
self._load_rng_state is called when the following condition holds:
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
    self._load_rng_state(resume_from_checkpoint)
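For context, _load_rng_state restores the random-number-generator states that _save_checkpoint writes to rng_state.pth (shown later in the _save_checkpoint walkthrough). Below is a simplified sketch of the idea, not the exact Trainer implementation; the real method also handles per-process files such as rng_state_{process_index}.pth and XLA state:

import os
import random
import numpy as np
import torch

def load_rng_state_sketch(checkpoint_dir):
    # Simplified sketch of restoring the RNG states saved in a checkpoint.
    rng_file = os.path.join(checkpoint_dir, "rng_state.pth")
    if not os.path.isfile(rng_file):
        return
    checkpoint_rng_state = torch.load(rng_file)
    random.setstate(checkpoint_rng_state["python"])
    np.random.set_state(checkpoint_rng_state["numpy"])
    torch.random.set_rng_state(checkpoint_rng_state["cpu"])
    cuda_state = checkpoint_rng_state.get("cuda")
    if torch.cuda.is_available() and cuda_state is not None:
        if isinstance(cuda_state, (list, tuple)):
            torch.cuda.random.set_rng_state_all(cuda_state)  # saved with get_rng_state_all (distributed)
        else:
            torch.cuda.random.set_rng_state(cuda_state)      # saved with get_rng_state (single process)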
7. Skipping the already-trained batches of the current epoch with skip_first_batches (used on resume)
Where it is called
When resuming, the batches that have already been trained should not be consumed again. steps_trained_in_current_epoch holds the number of such batches, so the iterator is wrapped with skip_first_batches and the counter is then reset to zero; on a fresh run, or when data skipping is ignored, the counter is already zero and nothing is skipped.
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
    epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
    steps_skipped = steps_trained_in_current_epoch
    steps_trained_in_current_epoch = 0
    rng_to_sync = True
skip_first_batches source code
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if isinstance(dataloader, DataLoaderDispatcher):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        dataloader = DataLoaderDispatcher(
            dataset,
            split_batches=dataloader.split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader._drop_last,
            **kwargs,
        )
    elif isinstance(dataloader, DataLoaderShard):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        elif sampler_is_batch_sampler:
            kwargs["sampler"] = new_batch_sampler
            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        dataloader = DataLoaderShard(
            dataset,
            device=dataloader.device,
            rng_types=dataloader.rng_types,
            synchronized_generator=dataloader.synchronized_generator,
            **kwargs,
        )
    else:
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
        else:
            dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
This is how the data position is inherited when training is resumed.
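The effect is easy to reproduce with a plain DataLoader. Below is a minimal illustration of the skipping behaviour using the running example (48 samples, batch size 3, 8 batches already trained); note that the real accelerate helper above skips at the batch-sampler level, so the skipped batches are never even loaded, while this naive generator still iterates over them:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 48 samples, batch size 3 -> 16 batches per epoch.
dataset = TensorDataset(torch.arange(48))
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

def skip_batches_by_hand(loader, num_batches):
    """Hand-rolled equivalent of skipping the first `num_batches` batches."""
    for i, batch in enumerate(loader):
        if i < num_batches:
            continue
        yield batch

# Resume in the middle of the epoch: the first 8 batches were already trained.
for step, (batch,) in enumerate(skip_batches_by_hand(dataloader, 8)):
    print(step, batch.tolist())  # first printed batch is [24, 25, 26]
    break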
8. The step loop
Next comes the loop over the dataloader. At the start of each epoch, step is initialized to -1, and the loop body follows the usual training-loop pattern. The inner step loop deserves its own discussion, so it is covered in detail in a later article of this series.
step = -1
for step, inputs in enumerate(epoch_iterator):
    total_batched_samples += 1
    if rng_to_sync:
        self._load_rng_state(resume_from_checkpoint)
        rng_to_sync = False
9. Updating the exit condition of the epoch loop (self.control.should_training_stop)
step is set to -1 before the dataloader loop, and enumerate updates it on every batch. If step is still negative after the loop, the iterator produced no samples at all; a warning is logged and self.control.should_training_stop is set to True, which makes the outer epoch loop exit early.
step = -1
...
if step < 0:
    logger.warning(
        "There seems to be not a single sample in your epoch_iterator, stopping training at step"
        f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
        f" num_steps ({max_steps}) higher than the number of available samples."
    )
    self.control.should_training_stop = True
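Why the step < 0 check detects an empty iterator: if enumerate never yields anything, the loop body never runs and step keeps its initial value of -1. A tiny standalone illustration:

empty_iterator = iter([])  # stands in for an epoch_iterator that yields nothing

step = -1
for step, inputs in enumerate(empty_iterator):
    pass  # never executed

if step < 0:
    print("no samples seen this epoch -> should_training_stop = True")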
10. self.callback_handler.on_epoch_end
a. How it is called
This is the post-processing hook once the epoch has been trained; the result ends up in self.control. The call looks like this:
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
b. Return value
The reason for that conclusion is that the call ultimately runs call_event:
def call_event(self, event, args, state, control, **kwargs):
    for callback in self.callbacks:
        result = getattr(callback, event)(
            args,
            state,
            control,
            model=self.model,
            tokenizer=self.tokenizer,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            train_dataloader=self.train_dataloader,
            eval_dataloader=self.eval_dataloader,
            **kwargs,
        )
        # A Callback can skip the return of `control` if it doesn't change it.
        if result is not None:
            control = result
    return control
Whenever a callback returns a non-None result, control is replaced by that result, so the returned object carries all of the relevant state of the run (see the screenshot below).
[Screenshot in the original post: the control object and callback state returned after on_epoch_end.]
c. Notes on the result
As the screenshot shows, the returned state carries our customized components: the tokenizer, the optimizer, the learning-rate scheduler, the training dataloader, and so on. The learning rate was configured as 2e-4, yet the value shown is 1e-4, which is clearly the result of the lr_scheduler adjusting it. The trainer/model state information is the part we care about most here.
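That drop from the configured 2e-4 to 1e-4 is just the scheduler doing its job. A small self-contained sketch with a linear-decay schedule; the optimizer, step counts, and the exact schedule here are assumptions for illustration, and the real value depends on the lr_scheduler_type and warmup settings used during training:

import torch
from transformers import get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2e-4)

# Linear decay over 32 update steps (2 epochs x 16 steps in the running example), no warmup.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=32)

for _ in range(16):  # half-way through the schedule ...
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # -> [0.0001], i.e. half of the configured 2e-4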
11. Evaluation and checkpoint saving (self._maybe_log_save_evaluate)
a. How it is called
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
b. Source code walkthrough
Stepping into the function, its (abridged) source is as follows:
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
    if self.control.should_log:
        ...
        logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
        logs["learning_rate"] = self._get_learning_rate()
        ...
        self.log(logs)

    metrics = None
    if self.control.should_evaluate:
        if isinstance(self.eval_dataset, dict):
            metrics = {}
            ...
        self._report_to_hp_search(trial, self.state.global_step, metrics)

        # Run delayed LR scheduler now that metrics are populated
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            self.lr_scheduler.step(metrics[metric_to_check])

    if self.control.should_save:
        self._save_checkpoint(model, trial, metrics=metrics)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
As the source shows, the function logs, evaluates, and saves depending on the control flags. In particular, self.control.should_save decides whether a checkpoint is written, and that is where the weights and the rest of the training state get saved.
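Whether should_log, should_evaluate, and should_save are True at this point is decided earlier by the DefaultFlowCallback, based on the strategies configured in TrainingArguments. A minimal configuration sketch (the argument values are illustrative):

from transformers import TrainingArguments

# With every strategy set to "epoch", the flow callback flips should_log /
# should_evaluate / should_save to True at the end of each epoch, so
# _maybe_log_save_evaluate logs, evaluates, and saves a checkpoint once per epoch.
training_args = TrainingArguments(
    output_dir="output",        # checkpoints are written to output/checkpoint-<global_step>
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,         # older checkpoints are rotated away by _rotate_checkpoints
)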
c. self._save_checkpoint source code
This function saves the weights and the training state, writing files such as optimizer.pt, scheduler.pt, and trainer_state.json.
def _save_checkpoint(self, model, trial, metrics=None):
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"  # checkpoint folder name, e.g. checkpoint-48

    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)
    if self.is_deepspeed_enabled:
        # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
        # config `stage3_gather_16bit_weights_on_model_save` is True
        self.model_wrapped.save_checkpoint(output_dir)

    # Save optimizer and scheduler
    if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        self.optimizer.consolidate_state_dict()

    if self.fsdp or self.is_fsdp_enabled:
        if self.is_fsdp_enabled:
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
            )
        else:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after Pytorch 2.2!
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

    if is_torch_tpu_available():
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))  # OPTIMIZER_NAME = optimizer.pt
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))  # SCHEDULER_NAME = scheduler.pt
            reissue_pt_warnings(caught_warnings)
    elif is_sagemaker_mp_enabled():
        opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
        smp.barrier()
        if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
            smp.save(
                opt_state_dict,
                os.path.join(output_dir, OPTIMIZER_NAME),  # OPTIMIZER_NAME = optimizer.pt
                partial=True,
                v3=smp.state.cfg.shard_optimizer_state,
            )
        if self.args.should_save:
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
    elif self.args.should_save and not self.is_deepspeed_enabled:
        # deepspeed.save_checkpoint above saves model/optim/sched
        if self.fsdp and not self.is_fsdp_enabled:
            torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
        else:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
        if self.do_grad_scaling:
            torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics[metric_to_check]

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.args.should_save:
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    # Save RNG state in non-distributed training
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
            rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            rng_states["cuda"] = torch.cuda.random.get_rng_state()

    if is_torch_tpu_available():
        rng_states["xla"] = xm.get_rng_state()

    # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
    # not yet exist.
    os.makedirs(output_dir, exist_ok=True)

    if self.args.world_size <= 1:
        torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
    else:
        torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
d. Key points of self._save_checkpoint
Note that what gets saved is the state_dict() of each component, as these lines show:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
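On resume, these files are read back with torch.load and handed to load_state_dict; the real logic lives in the Trainer's _load_optimizer_and_scheduler. Below is a simplified single-GPU sketch of that inverse step (the path and the stand-in optimizer/scheduler are illustrative):

import os
import torch

# Illustrative stand-ins; inside the Trainer these are self.optimizer / self.lr_scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

checkpoint_dir = "output/checkpoint-48"  # illustrative checkpoint path

# Inverse of the torch.save calls above (single process, no deepspeed/FSDP):
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location="cpu"))
lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))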

12. Stopping the outer epoch loop
Leaving the outer loop is triggered by the control flag self.control.should_training_stop, which was set earlier, for example by the empty-iterator check above, by reaching max_steps inside the step loop, or by a callback.
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
    if is_torch_tpu_available():
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())
    else:
        logger.warning(
            "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
            "configured. Check your training configuration if this is unexpected."
        )

if self.control.should_training_stop:
    break
