
Handling Multimodal Data with AutoGluon, with a Worked Example: Multimodal Data Tables: Tabular, Text, and Image


Multimodal Data Tables: Tabular, Text, and Image

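Note: This tutorial requires a GPU in order to train the image and text models. In addition, MXNet and Torch must be installed in their GPU builds, with a matching CUDA version, so that GPU acceleration is available.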

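The PetFinder Dataset

We will use the PetFinder dataset from Kaggle. It contains information about shelter animals that appear on adoption profiles, and the task is to predict how likely each animal is to be adopted. The end goal is to use the predicted adoption rate to find animals whose adoption profiles could be improved so that they find a new home more easily.

Each animal's adoption profile contains several kinds of information: pictures of the animal, a text description, and tabular features such as age, breed, name, and color.

First, we need to download the dataset. A dataset that contains images needs more than a CSV file, so this one is packaged as a zip file hosted on S3. We download it and unzip the contents: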

    download_dir = './ag_petfinder_tutorial'
    zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip'

    from autogluon.core.utils.loaders import load_zip
    load_zip.unzip(zip_file, unzip_dir=download_dir)
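For readers who want to see what a download-and-unzip step involves, here is a minimal standard-library sketch. It is not how load_zip is implemented; the function names extract_zip and download_and_extract are hypothetical, and the default archive name 'file.zip' is only an assumption chosen for illustration.

```python
import os
import urllib.request
import zipfile


def extract_zip(zip_path, out_dir):
    """Extract zip_path into out_dir and return the sorted top-level entries."""
    os.makedirs(out_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(out_dir)
    return sorted(os.listdir(out_dir))


def download_and_extract(url, out_dir, archive_name="file.zip"):
    """Download url to out_dir/archive_name, then extract it next to the archive.

    Hypothetical helper: archive_name is an assumption, not an AutoGluon setting.
    """
    os.makedirs(out_dir, exist_ok=True)
    zip_path = os.path.join(out_dir, archive_name)
    urllib.request.urlretrieve(url, zip_path)
    return extract_zip(zip_path, out_dir)
```

Calling download_and_extract(zip_file, download_dir) would need network access and should leave both the archive and the extracted directory in download_dir.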

Now that the data is downloaded and unzipped, let's take a look at the contents:

    import os
    os.listdir(download_dir)

    ['petfinder_processed', 'file.zip']

'file.zip' is the original zip file we downloaded, and 'petfinder_processed' is the directory that contains the dataset files.

    dataset_path = download_dir + '/petfinder_processed'
    os.listdir(dataset_path)

    ['test.csv', 'dev.csv', 'test_images', 'train_images', 'train.csv']

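Here we can see the train, test, and dev CSV files, as well as two directories, 'test_images' and 'train_images', which contain the JPG image files.

Note: We will use the dev data as our test data, because dev contains the ground-truth labels needed to show scores via predictor.leaderboard.

Let's take a look at the first 10 files inside the train_images directory: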

    os.listdir(dataset_path + '/train_images')[:10]

    ['ca587cb42-1.jpg',
     'ae00eded4-4.jpg',
     '6e3457b81-2.jpg',
     'acb248693-1.jpg',
     '0bd867d1b-1.jpg',
     'fa53dd6cd-1.jpg',
     '9726ab93e-1.jpg',
     '39818f12c-2.jpg',
     '90ce48a71-2.jpg',
     '2ece6b26b-1.jpg']

Next, we load the train and dev CSV files:

    import pandas as pd

    train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
    test_data = pd.read_csv(f'{dataset_path}/dev.csv', index_col=0)

    train_data.head(3)

           Type  Name     Age  Breed1  Breed2  Gender  Color1  Color2  Color3  MaturitySize  ...  Quantity  Fee  State  RescuerID                         VideoAmt  Description                                        PetID      PhotoAmt  AdoptionSpeed  Images
    10721  1     Elbi     2    307     307     2       5       0       0       3             ...  1         0    41336  e9a86209c54f589ba72c345364cf01aa  0         I'm looking for people to adopt my dog             e4b90955c  4.0       4              train_images/e4b90955c-1.jpg;train_images/e4b9...
    13114  2     Darling  4    266     0       1       1       0       0       2             ...  1         0    41401  01f954cdf61526daf3fbeb8a074be742  0         Darling was born at the back lane of Jalan Alo...  a0c1384d1  5.0       3              train_images/a0c1384d1-1.jpg;train_images/a0c1...
    13194  1     Wolf     3    307     0       1       1       2       0       2             ...  1         0    41332  6e19409f2847326ce3b6d0cec7e42f81  0         I found Wolf about a month ago stuck in a drai...  cf357f057  7.0       4              train_images/cf357f057-1.jpg;train_images/cf35..

    3 rows × 25 columns

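Looking at the first three examples, we can see a variety of tabular features, a text description ('Description'), and image paths ('Images').

For the PetFinder dataset, we will try to predict the speed of adoption for each animal (AdoptionSpeed), which is grouped into five categories. This means we are dealing with a multi-class classification problem.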

    label = 'AdoptionSpeed'
    image_col = 'Images'
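Since this is a five-class problem, it can help to check how the classes are distributed before training. The sketch below is a plain-Python illustration; label_distribution is a hypothetical helper and the toy labels are made up, not taken from the PetFinder data.

```python
from collections import Counter


def label_distribution(labels):
    """Return {label: (count, fraction)} with keys in sorted order."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {k: (counts[k], counts[k] / total) for k in sorted(counts)}


# Toy labels for illustration only; not taken from the PetFinder data.
toy_labels = [4, 3, 4, 2, 2, 4, 1, 0, 2, 4]
for cls, (count, frac) in label_distribution(toy_labels).items():
    print(f"class {cls}: {count} rows ({frac:.0%})")
```

With the tutorial's DataFrame, the same information should be available from label_distribution(train_data[label]) or from pandas directly via train_data[label].value_counts(normalize=True).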

Let's take a look at what a value in the image column looks like:

    train_data[image_col].iloc[0]

    'train_images/e4b90955c-1.jpg;train_images/e4b90955c-2.jpg;train_images/e4b90955c-3.jpg;train_images/e4b90955c-4.jpg'

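AutoGluon currently supports only one image per row. Because each row of the PetFinder dataset contains one or more images, we first preprocess the image column so that it holds only the first image of each row.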

    train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0])
    test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])

    train_data[image_col].iloc[0]

    'train_images/e4b90955c-1.jpg'
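The split(';')[0] expression assumes that every cell is a non-empty string. Whether the PetFinder CSV contains rows without images is not stated here, so the following is only a defensive sketch under that assumption; first_image is a hypothetical helper, not part of AutoGluon.

```python
def first_image(paths, sep=";"):
    """Return the first non-empty path in a sep-joined string, or None."""
    if not isinstance(paths, str):
        return None  # e.g. NaN produced by an empty CSV cell
    for part in paths.split(sep):
        part = part.strip()
        if part:
            return part
    return None


print(first_image("train_images/e4b90955c-1.jpg;train_images/e4b90955c-2.jpg"))
print(first_image(float("nan")))
```

It could be applied the same way as the lambda, for example train_data[image_col].apply(first_image), after which rows that come back as None can be inspected or dropped.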

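AutoGluon loads images based on the file paths given in the image column.

Here we update the paths so that they point to the correct location on disk: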

    def path_expander(path, base_folder):
        path_l = path.split(';')
        return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])

    train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

    train_data[image_col].iloc[0]

    '/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/ag_petfinder_tutorial/petfinder_processed/train_images/e4b90955c-1.jpg'
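Because the image models read files from these paths, it may be worth confirming that the expanded paths actually exist before starting a long training run. A minimal sketch, where find_missing_files is a hypothetical helper rather than an AutoGluon function:

```python
import os


def find_missing_files(paths):
    """Return the entries that are not strings or do not point to an existing file."""
    return [p for p in paths if not (isinstance(p, str) and os.path.isfile(p))]
```

For example, find_missing_files(train_data[image_col]) should return an empty list when every row points at a readable image file.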
    example_row = train_data.iloc[1]

    example_row

    Type                                                             2
    Name                                                       Darling
    Age                                                              4
    Breed1                                                         266
    Breed2                                                           0
    Gender                                                           1
    Color1                                                           1
    Color2                                                           0
    Color3                                                           0
    MaturitySize                                                     2
    FurLength                                                        1
    Vaccinated                                                       2
    Dewormed                                                         2
    Sterilized                                                       2
    Health                                                           1
    Quantity                                                         1
    Fee                                                              0
    State                                                        41401
    RescuerID                         01f954cdf61526daf3fbeb8a074be742
    VideoAmt                                                         0
    Description      Darling was born at the back lane of Jalan Alo...
    PetID                                                    a0c1384d1
    PhotoAmt                                                       5.0
    AdoptionSpeed                                                    3
    Images           /var/lib/jenkins/workspace/workspace/autogluon...
    Name: 13114, dtype: object
    example_row['Description']

    'Darling was born at the back lane of Jalan Alor and was foster by a feeder. All his siblings had died of accident. His mother and grandmother had just been spayed. Darling make a great condo/apartment cat. He love to play a lot. He would make a great companion for someone looking for a cat to love.'
    example_image = example_row['Images']

    from IPython.display import Image, display
    pil_img = Image(filename=example_image)
    display(pil_img)

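The PetFinder dataset is fairly large. For the purposes of this tutorial, we will sample 500 rows for training.

Training on large multimodal datasets can be extremely computationally intensive, especially with AutoGluon's higher-quality presets. When prototyping, sample your data first to get an idea of which models are worth training, then gradually train on more data with longer time limits, as you would with any other machine learning algorithm.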

    train_data = train_data.sample(500, random_state=0)
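The tutorial uses plain random sampling. If some classes are rare, a stratified sample is one alternative that keeps each class's share roughly intact. The sketch below is a swapped-in technique, not what the tutorial does; stratified_sample_indices is a hypothetical helper, and because of rounding it returns about n indices rather than exactly n.

```python
import random
from collections import defaultdict


def stratified_sample_indices(labels, n, seed=0):
    """Pick about n row positions while keeping class proportions roughly intact."""
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)

    rng = random.Random(seed)
    total = len(labels)
    chosen = []
    for lab in sorted(by_class):
        members = by_class[lab]
        # Proportional quota, with at least one row kept per class.
        quota = max(1, round(n * len(members) / total))
        chosen.extend(rng.sample(members, min(quota, len(members))))
    return sorted(chosen)
```

With a DataFrame, the returned positions can be used as train_data.iloc[stratified_sample_indices(list(train_data[label]), 500)].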

Constructing the FeatureMetadata

Next, let's see which feature types AutoGluon infers from the training data by constructing a FeatureMetadata object:

    from autogluon.tabular import FeatureMetadata
    feature_metadata = FeatureMetadata.from_df(train_data)

    print(feature_metadata)

    ('float', [])        :  1 | ['PhotoAmt']
    ('int', [])          : 19 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
    ('object', [])       :  4 | ['Name', 'RescuerID', 'PetID', 'Images']
    ('object', ['text']) :  1 | ['Description']

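Notice that FeatureMetadata automatically identified the 'Description' column as text, so we do not need to mark it as text ourselves.

To make use of the images, we need to tell AutoGluon which column contains the image paths. We do this by adding the 'image_path' special type to the image column of the FeatureMetadata object. We will later pass this custom FeatureMetadata to TabularPredictor.fit.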

    feature_metadata = feature_metadata.add_special_types({image_col: ['image_path']})

    print(feature_metadata)

    ('float', [])              :  1 | ['PhotoAmt']
    ('int', [])                : 19 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
    ('object', [])             :  3 | ['Name', 'RescuerID', 'PetID']
    ('object', ['image_path']) :  1 | ['Images']
    ('object', ['text'])       :  1 | ['Description']

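Specifying the hyperparameters

Next, we need to tell AutoGluon which models to train. This is done through the hyperparameters argument of TabularPredictor.fit.

AutoGluon ships with a predefined configuration named 'multimodal' that is designed for multimodal datasets. It can be retrieved as follows: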

    from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config
    hyperparameters = get_hyperparameter_config('multimodal')

    hyperparameters

    {'NN_TORCH': {},
     'GBM': [{},
      {'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}},
      'GBMLarge'],
     'CAT': {},
     'XGB': {},
     'AG_TEXT_NN': {'presets': 'medium_quality_faster_train'},
     'AG_IMAGE_NN': {},
     'VW': {}}

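This hyperparameter configuration trains a variety of tabular models, fine-tunes a pretrained Electra BERT language model on the text, and fine-tunes a ResNet image model on the images.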
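Since the configuration is an ordinary dictionary keyed by model type, one way to prototype faster is to drop the expensive entries before training. The sketch below copies the dictionary printed in this tutorial and removes the text and image models; drop_models is a hypothetical helper, and the assumption is that the trimmed dictionary is passed to TabularPredictor.fit through the same hyperparameters argument.

```python
import copy

# Copied from the 'multimodal' configuration printed in this tutorial.
multimodal_config = {
    'NN_TORCH': {},
    'GBM': [{}, {'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, 'GBMLarge'],
    'CAT': {},
    'XGB': {},
    'AG_TEXT_NN': {'presets': 'medium_quality_faster_train'},
    'AG_IMAGE_NN': {},
    'VW': {},
}


def drop_models(config, keys_to_drop):
    """Return a deep copy of config without the given model keys."""
    unknown = set(keys_to_drop) - set(config)
    if unknown:
        raise KeyError(f"not in config: {sorted(unknown)}")
    return {k: copy.deepcopy(v) for k, v in config.items() if k not in keys_to_drop}


# Keep only the tabular models for a quick first pass.
tabular_only = drop_models(multimodal_config, ['AG_TEXT_NN', 'AG_IMAGE_NN'])
print(sorted(tabular_only))
```

The deep copy keeps the original configuration untouched, so the full multimodal run can still be launched later from the same object.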

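Fitting with TabularPredictor

Now we train a TabularPredictor on the dataset, using the feature metadata and hyperparameters defined above. This predictor uses the tabular, text, and image features together.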

    from autogluon.tabular import TabularPredictor
    predictor = TabularPredictor(label=label).fit(
        train_data=train_data,
        hyperparameters=hyperparameters,
        feature_metadata=feature_metadata,
        time_limit=900,
    )

 No path specified. Models will be saved in: "AutogluonModels/ag-20220315_003808/"

    
 Beginning AutoGluon training ... Time limit = 900s
    
 AutoGluon will save models to "AutogluonModels/ag-20220315_003808/"
    
 AutoGluon Version:  0.4.0b20220315
    
 Python Version:     3.9.10
    
 Operating System:   Linux
    
 Train Data Rows:    500
    
 Train Data Columns: 24
    
 Label Column: AdoptionSpeed
    
 Preprocessing data ...
    
 AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
    
     5 unique label values:  [2, 3, 4, 0, 1]
    
     If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    
 Train Data Class Count: 5
    
 Using Feature Generators to preprocess the data ...
    
 Fitting AutoMLPipelineFeatureGenerator...
    
     Available Memory:                    22403.51 MB
    
     Train Data (Original)  Memory Usage: 0.51 MB (0.0% of available memory)
    
     Stage 1 Generators:
    
         Fitting AsTypeFeatureGenerator...
    
                 Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
    
     Stage 2 Generators:
    
         Fitting FillNaFeatureGenerator...
    
     Stage 3 Generators:
    
         Fitting IdentityFeatureGenerator...
    
         Fitting IdentityFeatureGenerator...
    
                 Fitting RenameFeatureGenerator...
    
         Fitting CategoryFeatureGenerator...
    
                 Fitting CategoryMemoryMinimizeFeatureGenerator...
    
         Fitting TextSpecialFeatureGenerator...
    
                 Fitting BinnedFeatureGenerator...
    
                 Fitting DropDuplicatesFeatureGenerator...
    
         Fitting TextNgramFeatureGenerator...
    
                 Fitting CountVectorizer for text features: ['Description']
    
                 CountVectorizer fit with vocabulary size = 170
    
         Fitting IdentityFeatureGenerator...
    
         Fitting IsNanFeatureGenerator...
    
     Stage 4 Generators:
    
         Fitting DropUniqueFeatureGenerator...
    
     Unused Original Features (Count: 1): ['PetID']
    
         These features were not used to generate any of the output features. Add a feature generator compatible with these features to utilize them.
    
         Features can also be unused if they carry very little information, such as being categorical but having almost entirely unique values or being duplicates of other features.
    
         These features do not need to be present at inference time.
    
         ('object', []) : 1 | ['PetID']
    
     Types of features in original data (raw dtype, special dtypes):
    
         ('float', [])              :  1 | ['PhotoAmt']
    
         ('int', [])                : 18 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
    
         ('object', [])             :  2 | ['Name', 'RescuerID']
    
         ('object', ['image_path']) :  1 | ['Images']
    
         ('object', ['text'])       :  1 | ['Description']
    
     Types of features in processed data (raw dtype, special dtypes):
    
         ('category', [])                    :   2 | ['Name', 'RescuerID']
    
         ('category', ['text_as_category'])  :   1 | ['Description']
    
         ('float', [])                       :   1 | ['PhotoAmt']
    
         ('int', [])                         :  17 | ['Age', 'Breed1', 'Breed2', 'Gender', 'Color1', ...]
    
         ('int', ['binned', 'text_special']) :  24 | ['Description.char_count', 'Description.word_count', 'Description.capital_ratio', 'Description.lower_ratio', 'Description.digit_ratio', ...]
    
         ('int', ['bool'])                   :   1 | ['Type']
    
         ('int', ['text_ngram'])             : 171 | ['__nlp__.about', '__nlp__.active', '__nlp__.active and', '__nlp__.adopt', '__nlp__.adopted', ...]
    
         ('object', ['image_path'])          :   1 | ['Images']
    
         ('object', ['text'])                :   1 | ['Description_raw_text']
    
     0.5s = Fit runtime
    
     23 features in original data used to generate 219 features in processed data.
    
     Train Data (Processed) Memory Usage: 0.58 MB (0.0% of available memory)
    
 Data preprocessing and feature engineering runtime = 0.57s ...
    
 AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
    
     To change this, specify the eval_metric parameter of Predictor()
    
 Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100
    
 Fitting 9 L1 models ...
    
 Fitting model: LightGBM ... Training model for up to 899.43s of the 899.43s of remaining time.
    
     0.34     = Validation score   (accuracy)
    
     1.29s    = Training   runtime
    
     0.01s    = Validation runtime
    
 Fitting model: LightGBMXT ... Training model for up to 898.13s of the 898.13s of remaining time.
    
     0.34     = Validation score   (accuracy)
    
     0.82s    = Training   runtime
    
     0.01s    = Validation runtime
    
 Fitting model: CatBoost ... Training model for up to 897.28s of the 897.28s of remaining time.
    
     0.3      = Validation score   (accuracy)
    
     2.88s    = Training   runtime
    
     0.01s    = Validation runtime
    
 Fitting model: XGBoost ... Training model for up to 894.38s of the 894.38s of remaining time.
    
     0.35     = Validation score   (accuracy)
    
     1.64s    = Training   runtime
    
     0.01s    = Validation runtime
    
 Fitting model: NeuralNetTorch ... Training model for up to 892.73s of the 892.72s of remaining time.
    
     0.35     = Validation score   (accuracy)
    
     1.8s     = Training   runtime
    
     0.02s    = Validation runtime
    
 Fitting model: VowpalWabbit ... Training model for up to 890.9s of the 890.9s of remaining time.
    
     0.24     = Validation score   (accuracy)
    
     0.73s    = Training   runtime
    
     0.03s    = Validation runtime
    
 Fitting model: LightGBMLarge ... Training model for up to 889.81s of the 889.81s of remaining time.
    
     0.37     = Validation score   (accuracy)
    
     2.48s    = Training   runtime
    
     0.01s    = Validation runtime
    
 Fitting model: TextPredictor ... Training model for up to 887.3s of the 887.3s of remaining time.
    
 Global seed set to 0
    
 Using 16bit native Automatic Mixed Precision (AMP)
    
 GPU available: True, used: True
    
 TPU available: False, using: 0 TPU cores
    
 IPU available: False, using: 0 IPUs
    
 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    
  
    
   | Name              | Type                | Params
    
 ----------------------------------------------------------
    
 0 | model             | MultimodalFusionMLP | 13.7 M
    
 1 | validation_metric | Accuracy            | 0
    
 2 | loss_func         | CrossEntropyLoss    | 0
    
 ----------------------------------------------------------
    
 13.7 M    Trainable params
    
 0         Non-trainable params
    
 13.7 M    Total params
    
 27.305    Total estimated model params size (MB)
    
 Global seed set to 0
    
 Epoch 0, global step 1: val_accuracy reached 0.24000 (best 0.24000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=0-step=1.ckpt" as top 3
    
 Epoch 0, global step 3: val_accuracy reached 0.28000 (best 0.28000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=0-step=3.ckpt" as top 3
    
 Epoch 1, global step 5: val_accuracy reached 0.25000 (best 0.28000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=1-step=5.ckpt" as top 3
    
 Epoch 1, global step 7: val_accuracy reached 0.27000 (best 0.28000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=1-step=7.ckpt" as top 3
    
 Epoch 2, global step 9: val_accuracy reached 0.30000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=2-step=9.ckpt" as top 3
    
 Epoch 2, global step 11: val_accuracy reached 0.28000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=2-step=11.ckpt" as top 3
    
 Epoch 3, global step 13: val_accuracy was not in top 3
    
 Epoch 3, global step 15: val_accuracy was not in top 3
    
 Epoch 4, global step 17: val_accuracy was not in top 3
    
 Epoch 4, global step 19: val_accuracy was not in top 3
    
 Epoch 5, global step 21: val_accuracy reached 0.30000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/TextPredictor/epoch=5-step=21.ckpt" as top 3
    
 Epoch 5, global step 23: val_accuracy was not in top 3
    
 Epoch 6, global step 25: val_accuracy was not in top 3
    
 Epoch 6, global step 27: val_accuracy was not in top 3
    
 Epoch 7, global step 29: val_accuracy was not in top 3
    
     0.25     = Validation score   (accuracy)
    
     52.65s   = Training   runtime
    
     0.62s    = Validation runtime
    
 Fitting model: ImagePredictor ... Training model for up to 833.92s of the 833.92s of remaining time.
    
 ImagePredictor sets accuracy as default eval_metric for classification problems.
    
 The number of requested GPUs is greater than the number of available GPUs.Reduce the number to 1
    
 modified configs(<old> != <new>): {
    
 root.misc.seed       42 != 716
    
 root.misc.num_workers 4 != 8
    
 root.train.epochs    200 != 15
    
 root.train.early_stop_max_value 1.0 != inf
    
 root.train.batch_size 32 != 16
    
 root.train.early_stop_baseline 0.0 != -inf
    
 root.train.early_stop_patience -1 != 10
    
 root.img_cls.model   resnet101 != resnet50
    
 }
    
 Saved config to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/ImagePredictor/78a64d2a/.trial_0/config.yaml
    
 Model resnet50 created, param count:                                         23518277
    
 AMP not enabled. Training in float32.
    
 Disable EMA as it is not supported for now.
    
 Start training from [Epoch 0]
    
 [Epoch 0] training: accuracy=0.182500
    
 [Epoch 0] speed: 84 samples/sec     time cost: 4.555561
    
 [Epoch 0] validation: top1=0.230000 top5=1.000000
    
 [Epoch 0] Current best top-1: 0.230000 vs previous -inf, saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/ImagePredictor/78a64d2a/.trial_0/best_checkpoint.pkl
    
 [Epoch 1] training: accuracy=0.280000
    
 [Epoch 1] speed: 93 samples/sec     time cost: 4.089249
    
 [Epoch 1] validation: top1=0.220000 top5=1.000000
    
 [Epoch 2] training: accuracy=0.310000
    
 [Epoch 2] speed: 94 samples/sec     time cost: 4.081534
    
 [Epoch 2] validation: top1=0.220000 top5=1.000000
    
 [Epoch 3] training: accuracy=0.332500
    
 [Epoch 3] speed: 94 samples/sec     time cost: 4.083061
    
 [Epoch 3] validation: top1=0.230000 top5=1.000000
    
 [Epoch 4] training: accuracy=0.330000
    
 [Epoch 4] speed: 93 samples/sec     time cost: 4.095332
    
 [Epoch 4] validation: top1=0.280000 top5=1.000000
    
 [Epoch 4] Current best top-1: 0.280000 vs previous 0.230000, saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-tabular-v3/docs/_build/eval/tutorials/tabular_prediction/AutogluonModels/ag-20220315_003808/models/ImagePredictor/78a64d2a/.trial_0/best_checkpoint.pkl
    
 [Epoch 5] training: accuracy=0.355000
    
 [Epoch 5] speed: 93 samples/sec     time cost: 4.091876
    
 [Epoch 5] validation: top1=0.270000 top5=1.000000
    
 [Epoch 6] training: accuracy=0.377500
    
 [Epoch 6] speed: 93 samples/sec     time cost: 4.109509
    
 [Epoch 6] validation: top1=0.280000 top5=1.000000
    
 [Epoch 7] training: accuracy=0.370000
    
 [Epoch 7] speed: 93 samples/sec     time cost: 4.090519
    
 [Epoch 7] validation: top1=0.230000 top5=1.000000
    
 [Epoch 8] training: accuracy=0.390000
    
 [Epoch 8] speed: 93 samples/sec     time cost: 4.102693
    
 [Epoch 8] validation: top1=0.260000 top5=1.000000
    
 [Epoch 9] training: accuracy=0.400000
    
 [Epoch 9] speed: 93 samples/sec     time cost: 4.096608
    
 [Epoch 9] validation: top1=0.240000 top5=1.000000
    
 [Epoch 10] training: accuracy=0.365000
    
 [Epoch 10] speed: 93 samples/sec    time cost: 4.091880
    
 [Epoch 10] validation: top1=0.240000 top5=1.000000
    
 [Epoch 11] training: accuracy=0.417500
    
 [Epoch 11] speed: 93 samples/sec    time cost: 4.099170
    
 [Epoch 11] validation: top1=0.230000 top5=1.000000
    
 [Epoch 12] training: accuracy=0.417500
    
 [Epoch 12] speed: 93 samples/sec    time cost: 4.108267
    
 [Epoch 12] validation: top1=0.190000 top5=1.000000
    
 [Epoch 13] training: accuracy=0.430000
    
 [Epoch 13] speed: 93 samples/sec    time cost: 4.096501
    
 [Epoch 13] validation: top1=0.260000 top5=1.000000
    
 [Epoch 14] training: accuracy=0.447500
    
 [Epoch 14] speed: 93 samples/sec    time cost: 4.105119
    
 [Epoch 14] validation: top1=0.240000 top5=1.000000
    
 Applying the state from the best checkpoint...
    
     0.28     = Validation score   (accuracy)
    
     72.02s   = Training   runtime
    
     0.93s    = Validation runtime
    
 Fitting model: WeightedEnsemble_L2 ... Training model for up to 360.0s of the 756.98s of remaining time.
    
     0.37     = Validation score   (accuracy)
    
     0.21s    = Training   runtime
    
     0.0s     = Validation runtime
    
 AutoGluon training complete, total runtime = 143.24s ... Best model: "WeightedEnsemble_L2"
    
 TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20220315_003808/")

Once the predictor is fit, we can look at the leaderboard to see how the various models performed:

    leaderboard = predictor.leaderboard(test_data)

                     model  score_test  score_val  pred_time_test  pred_time_val   fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
    0        LightGBMLarge    0.323775       0.37        0.016176       0.009152   2.483743                 0.016176                0.009152           2.483743            1       True          7
    1  WeightedEnsemble_L2    0.323775       0.37        0.418150       0.009574   2.690200                 0.401975                0.000422           0.206457            2       True         10
    2       NeuralNetTorch    0.319773       0.35        0.067532       0.020741   1.798623                 0.067532                0.020741           1.798623            1       True          5
    3             CatBoost    0.319106       0.30        0.020695       0.012886   2.882545                 0.020695                0.012886           2.882545            1       True          3
    4           LightGBMXT    0.315772       0.34        0.040475       0.007016   0.820213                 0.040475                0.007016           0.820213            1       True          2
    5              XGBoost    0.292431       0.35        0.044200       0.007139   1.641712                 0.044200                0.007139           1.641712            1       True          4
    6             LightGBM    0.289763       0.34        0.023030       0.006547   1.285891                 0.023030                0.006547           1.285891            1       True          1
    7        TextPredictor    0.285428       0.25       11.946029       0.617158  52.650396                11.946029                0.617158          52.650396            1       True          8
    8         VowpalWabbit    0.278760       0.24        0.823174       0.033238   0.729900                 0.823174                0.033238           0.729900            1       True          6
    9       ImagePredictor    0.271757       0.28       10.863207       0.932101  72.022450                10.863207                0.932101          72.022450            1       True          9
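The training log reports 'accuracy' as the evaluation metric, so the score columns can be read as the fraction of rows classified correctly. To put such numbers in context, it can help to compare them with a trivial baseline that always predicts the most common class. The following is a plain-Python sketch; accuracy and majority_baseline are hypothetical helpers, not AutoGluon functions.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def majority_baseline(train_labels, test_labels):
    """Accuracy obtained by always predicting the most common training label."""
    most_common = Counter(train_labels).most_common(1)[0][0]
    return accuracy(test_labels, [most_common] * len(test_labels))
```

For instance, majority_baseline(list(train_data[label]), list(test_data[label])) gives a floor against which the leaderboard scores can be compared; a model that barely beats it has learned little from the 500 sampled rows.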
