Advertisement

YOLOv8训练自己的数据集

阅读量:

1. 配置自己的.yaml文件

当加载这个.yaml文件时,能根据里面的路径找到相应的数据。

复制代码
 # Root directory of the dataset
 path: D:\edge\YOLO\YOLOv8_Official_Version\ultralytics\cfg\datasets

 # Directory of the training data (folder of .jpg images)
 train: D:\edge\YOLO\YOLOv8_Official_Version\ultralytics\cfg\datasets\images\train

 # Directory of the validation data (folder of .jpg images)
 val: D:\edge\YOLO\YOLOv8_Official_Version\ultralytics\cfg\datasets\images\val

 # Directory of the test data (folder of .jpg images)
 test: D:\edge\YOLO\YOLOv8_Official_Version\ultralytics\cfg\datasets\images\test

 # Number of classes
 nc: 6

 # Class names, keyed by class id
 names:
   0: ore carrier
   1: passenger ship
   2: container ship
   3: bulk cargo carrier
   4: general cargo ship  
   5: fishing boat
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/lgxBynJLA7w2pC0KrGT5ZMPtuo8H.png)

2. 标签数据格式转换 xml格式转换成txt格式

复制代码
 # .xml文件转换成.txt文件

    
  
    
 import copy
    
 from lxml.etree import Element, SubElement, tostring, ElementTree
    
 import xml.etree.ElementTree as ET
    
 import pickle
    
 import os
    
 from os import listdir, getcwd
    
 from os.path import join
    
  
    
 # Names of the object classes to detect; convert_annotation() uses a name's
 # position in this list as the class id written to the label file.
 classes = ["ore carrier", "passenger ship",
        "container ship", "bulk cargo carrier",
        "general cargo ship", "fishing boat"]

 # Absolute path of the directory that contains this script
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
    
  
    
 def convert(size, box):
     """Turn a pixel-space corner box into a normalised centre/size box.

     Args:
         size: ``(width, height)`` of the image.
         box: ``(x_min, x_max, y_min, y_max)`` of the bounding box.

     Returns:
         ``(x_center, y_center, width, height)``, where the x values are
         multiplied by ``1 / size[0]`` and the y values by ``1 / size[1]``.
     """
     # Reciprocals of the image dimensions, used as scale factors
     inv_w = 1. / size[0]
     inv_h = 1. / size[1]

     x_min, x_max, y_min, y_max = box[0], box[1], box[2], box[3]

     centre_x = (x_min + x_max) / 2.0
     centre_y = (y_min + y_max) / 2.0
     span_x = x_max - x_min
     span_y = y_max - y_min

     return (centre_x * inv_w, centre_y * inv_h, span_x * inv_w, span_y * inv_h)
    
  
    
 def convert_annotation(image_id, xml_dir='地址1', txt_dir='地址2'):
     """Convert one XML annotation file into a ``.txt`` label file.

     Reads ``<xml_dir>/<image_id>.xml`` (``size/width``, ``size/height`` and
     every ``object`` with its ``name`` and ``bndbox``). For each object whose
     name is listed in the module-level ``classes`` it writes one line
     ``<class_id> <x> <y> <w> <h>`` to ``<txt_dir>/<image_id>.txt``, where the
     four numbers come from ``convert()``. Objects with any other name are
     skipped. The output file is created even when no object matches.

     Args:
         image_id: File name without extension, shared by the XML and TXT files.
         xml_dir: Directory holding the ``.xml`` annotation files.
         txt_dir: Directory that receives the generated ``.txt`` files.
     """
     # os.path.join instead of a hard-coded '\' keeps the paths valid on every OS
     xml_file = os.path.join(xml_dir, '%s.xml' % image_id)
     txt_file = os.path.join(txt_dir, '%s.txt' % image_id)

     # Parse first, so a missing or malformed XML does not leave an empty
     # label file behind.
     with open(xml_file, encoding='UTF-8') as in_file:
         root = ET.parse(in_file).getroot()

     size = root.find('size')
     w = int(size.find('width').text)
     h = int(size.find('height').text)

     lines = []
     for obj in root.iter('object'):
         cls = obj.find('name').text
         if cls not in classes:
             continue
         cls_id = classes.index(cls)

         xmlbox = obj.find('bndbox')
         # convert() expects the order (x_min, x_max, y_min, y_max)
         b = tuple(float(xmlbox.find(tag).text)
                   for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
         bb = convert((w, h), b)
         lines.append(' '.join(str(v) for v in (cls_id,) + tuple(bb)) + '\n')

     # The context manager guarantees the labels are flushed and the handle closed
     with open(txt_file, 'w', encoding='utf-8') as out_file:
         out_file.writelines(lines)
    
  
    
 # Directory that holds the .xml annotation files
 # NOTE(review): convert_annotation() resolves its own input folder
 # independently of xml_path - make sure both point at the same directory.
 xml_path = os.path.join(CURRENT_DIR, '地址1/')

 # Convert every annotation found in that directory
 img_xmls = os.listdir(xml_path)
 for img_xml in img_xmls:
     # splitext keeps dots inside the stem ("ship.v2.xml" -> "ship.v2"),
     # whereas split('.')[0] would truncate it to "ship".
     label_name, ext = os.path.splitext(img_xml)
     if ext.lower() != '.xml':
         # Skip entries that are not .xml annotations
         continue
     print(label_name)
     convert_annotation(label_name)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/OoN39Sjid2ZmHlMKUXWuLknhf8FC.png)

3. 数据集划分 train val test

复制代码
 # Generate the train / val / test image and label folders

    
  
    
 import os, shutil, random
    
 from tqdm import tqdm
    
  
    
 def split_img(img_path, label_path, split_list, data_root='地址1'):
     """Randomly split an image/label dataset into train, val and test folders.

     Every entry of ``img_path`` is assigned to one split. The image and the
     label file that ``toLabelPath()`` derives for it are copied with
     ``_copy()`` into ``<data_root>/images/<split>`` and
     ``<data_root>/labels/<split>``.

     Args:
         img_path: Directory containing the source images.
         label_path: Directory containing the source label files.
         split_list: ``[train, val, test]`` fractions. ``train`` is the share
             of all images; what is left is divided between val and test in
             the ratio ``val : test``.
         data_root: Directory under which the ``images/`` and ``labels/``
             trees are created (e.g. ``ultralytics/cfg/datasets``).
     """
     train, val, test = split_list
     split_names = ('train', 'val', 'test')

     img_dirs = {name: os.path.join(data_root, 'images', name) for name in split_names}
     label_dirs = {name: os.path.join(data_root, 'labels', name) for name in split_names}
     out_dirs = list(img_dirs.values()) + list(label_dirs.values())

     if any(os.path.isdir(d) for d in out_dirs):
         print('文件目录已存在')
     # exist_ok=True creates each directory independently; wrapping all
     # makedirs calls in one bare try/except would stop at the first existing
     # directory and leave the remaining ones missing.
     for d in out_dirs:
         os.makedirs(d, exist_ok=True)

     all_img_path = [os.path.join(img_path, img) for img in os.listdir(img_path)]

     # Train split: a random sample of the requested share of all images
     train_img = random.sample(all_img_path, int(train * len(all_img_path)))
     taken = set(train_img)
     remaining = [p for p in all_img_path if p not in taken]

     # Val split: sampled from what is left; guard against val == test == 0
     rest = val + test
     val_share = val / rest if rest else 0
     val_img = random.sample(remaining, int(val_share * len(remaining)))
     taken = set(val_img)

     # Test split: everything not picked for train or val
     test_img = [p for p in remaining if p not in taken]

     for name, images in zip(split_names, (train_img, val_img, test_img)):
         for img in tqdm(images, desc=name + ' ', ncols=80, unit='img'):
             _copy(img, img_dirs[name])
             _copy(toLabelPath(img, label_path), label_dirs[name])
    
  
    
  
    
 def _copy(from_path, to_path):
     """Copy the file ``from_path`` to ``to_path`` using ``shutil.copy``."""
     shutil.copy(src=from_path, dst=to_path)
    
  
    
  
    
 def toLabelPath(img_path, label_path):
     """Return the label-file path that corresponds to an image path.

     The image's directory and extension are dropped, and the remaining file
     stem is placed under ``label_path`` with a ``.txt`` extension, e.g.
     ``imgs/a.jpg`` -> ``<label_path>/a.txt``.

     Args:
         img_path: Path of the image file.
         label_path: Directory that holds the label files.

     Returns:
         ``label_path`` joined with ``<image stem>.txt``.
     """
     # basename/splitext use the platform's path rules. Splitting on the
     # literal '\ ' never matches a real path, which would leave the image
     # directory inside the returned label path.
     stem = os.path.splitext(os.path.basename(img_path))[0]
     return os.path.join(label_path, stem + '.txt')
    
  
    
  
    
 if __name__ == '__main__':
     # Train / val / test fractions of the dataset
     split_list = [0.81, 0.09, 0.1]

     # Source folders: the images first, then their label files
     img_path, label_path = '地址2', '地址3'

     split_img(img_path, label_path, split_list)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/kgUpDERVfw12vQqxbB3T8n9iYoAd.png)

4. 数据集存放结构

路径:ultralytics\cfg\datasets

images(文件夹)

train(文件夹)

val(文件夹)

test(文件夹)

labels(文件夹)

train(文件夹)

val(文件夹)

test(文件夹)

标签数据格式转换.py

数据集划分.py

seaships.yaml(数据加载文件,自定义数据集中所有目标对象的类别信息)

5. 训练数据

路径:ultralytics\cfg\models\v8\yolov8.yaml 模型加载文件 只需要修改类别数nc即可

路径:ultralytics 存放一个预训练权重文件 yolov8s.pt

注释:设置模型规模(n、s、m、l、x)时,直接通过 model_yaml 参数指定即可,例如 model_yaml=r"D:\edge\YOLO\YOLOv8_Official_Version\ultralytics\cfg\models\v8\yolov8s.yaml"。虽然该路径下并没有 yolov8s.yaml 这个文件,但 YOLOv8 会根据文件名中的规模字母(此处为 s)识别所选的模型规模,并加载同目录下的 yolov8.yaml。

注释:可以适当地调小batch和workers,不然可能会报错。

注释:路径:ultralytics\cfg\default.yaml 其他参数可以在这里进行设置

复制代码
 from ultralytics import YOLO

    
  
    
 if __name__ == "__main__":
     # Locations of the model definition, the dataset description and the
     # pretrained weights
     model_yaml = r"地址1\yolov8s.yaml"
     data_yaml = r"地址2\seaships.yaml"
     pre_model = r"地址3\yolov8s.pt"

     # Build the detector from the YAML definition, then transfer the
     # pretrained weights into it
     model = YOLO(model_yaml, task='detect')
     model = model.load(pre_model)

     # Train the model
     train_args = dict(data=data_yaml, epochs=100, imgsz=640, batch=4, workers=2)
     results = model.train(**train_args)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/ZgQNeBt0CDEijvnW7AcqdFps4URl.png)

全部评论 (0)

还没有任何评论哟~