Advertisement

记录深度学习小白研究生经历(2)

阅读量:

实习期间的主要工作是整理数据集,包括图像格式转换、标签修改以及标签格式调整等操作。

以下代码是我从网上和 GPT 收集来、再经过自己修改的。

目录

1. 数据集格式转换tiff转png/jpg

2. json格式标签转txt格式

3. xml格式标签转换txt格式

4. SARdet-100k数据集的图像及标签处理

1.图像处理

2. 标签处理

5. yolo格式标签处理(删除特定id标签信息)


1. 数据集格式转换tiff转png/jpg

新到的数据集采用tiff格式,在体积上非常庞大,并且在进行训练时需转换成png或jpg格式。

转换代码如下:

复制代码
 import os

    
 from osgeo import gdal
    
 from tqdm import tqdm
    
 """用于tiff格式批量转png/jpg"""
 # Batch-convert every .tif/.tiff image in a folder to PNG (or JPEG) with GDAL.

 file_folder = 'tiff_folder_path'  # folder holding the source .tif/.tiff files
 output_folder = 'jpg/png_folder_path'  # folder that receives the converted images

 # GDAL output driver and the matching file extension.
 # For JPEG output use: driver_name, out_ext = 'JPEG', '.jpg'
 # NOTE(review): GDAL's PNG/JPEG drivers only accept some pixel types (e.g. 8-bit);
 # float SAR rasters may need rescaling before conversion -- confirm for your data.
 driver_name, out_ext = 'PNG', '.png'

 # Create the output folder if needed; no error if it already exists.
 os.makedirs(output_folder, exist_ok=True)

 # The driver is the same for every file, so look it up once.
 driver = gdal.GetDriverByName(driver_name)

 file_list = os.listdir(file_folder)

 for file_name in tqdm(file_list):
     stem, ext = os.path.splitext(file_name)
     # Accept both common TIFF extensions, case-insensitively.
     if ext.lower() not in ('.tif', '.tiff'):
         continue

     file_path = os.path.join(file_folder, file_name)
     ds = gdal.Open(file_path)
     if ds is None:
         print(f"Failed to open {file_path}")
         continue

     # Build the output name from the stem so only the extension changes
     # (str.replace would also rewrite a '.tiff' that appears mid-name).
     output_file_path = os.path.join(output_folder, stem + out_ext)
     dst_ds = driver.CreateCopy(output_file_path, ds)
     if dst_ds is None:
         print(f"Failed to convert {file_path}")

     # Dropping GDAL dataset references flushes and closes them; release
     # both the output and the (potentially large) source dataset.
     dst_ds = None
     ds = None

如果想转成 jpg,只需把代码中的驱动名 'PNG' 和后缀 '.png' 分别换成 'JPEG' 和 '.jpg'。

2. json格式标签转txt格式

由于需要使用 YOLOv8 进行训练,必须将标签转换为 YOLO 格式(.txt 文件)。

格式为每个目标一行:<class_id> <x_center> <y_center> <width> <height>

其中中心坐标及宽高均为归一化值(x 方向除以图像宽度,y 方向除以图像高度,取值范围 0~1)。

目标json的格式如下(部分信息省略)

复制代码
 {

    
     "shapes": [
    
       {
    
     "label": "name1",
    
     "line_color": [
    
       0,
    
       0,
    
       0,
    
       255
    
     ],
    
     "fill_color": [
    
       0,
    
       0,
    
       0,
    
       255
    
     ],
    
     "points": [
    
       [
    
         5916.970277568591,
    
         1897.6849325148487
    
       ],
    
       [
    
         6001.932538131099,
    
         1816.4436468674849
    
       ]
    
     ],
    
     "probability": 10.0,
    
     "shape_type": "rectangle"
    
       },
    
       {
    
     "label": "name2",
    
     "line_color": [
    
       0,
    
       0,
    
       0,
    
       255
    
     ],
    
     "fill_color": [
    
       0,
    
       0,
    
       0,
    
       255
    
     ],
    
     "points": [
    
       [
    
         5820.845092260643,
    
         1957.8406936430474
    
       ],
    
       [
    
         5918.830765026164,
    
         1873.498595566396
    
       ]
    
     ],
    
     "probability": 10.0,
    
     "shape_type": "rectangle"
    
       }],
    
       "lineColor": [
    
     0,
    
     255,
    
     0,
    
     128
    
       ],
    
       "fillColor": [
    
     255,
    
     0,
    
     0,
    
     128
    
       ],
    
   "imagePath": "image_name",
    
       "imageData": null,
    
       "imageHeight": 12643,
    
       "imageWidth": 12997,
    
       "geoTrans": [
    
     0,
    
     1,
    
     0,
    
     12643,
    
     0,
    
     -1
    
       ]
    
     }

此 JSON 文件中 bbox 的两个点依次为左下角点和右上角点(即 points[0] 的 y 值大于 points[1] 的 y 值)。起初我以为 bbox 是由左上角点到右下角点定义的,结果计算出的高度全部为负值。

转换代码:

复制代码
 import json

    
 import os
    
 from tqdm import tqdm
    
 """用于此数据集标签批量转换"""
 # Folder containing the per-image JSON label files (read with os.listdir,
 # so this must be a directory despite its .json-like name).
 json_folder_path = '/home/lyr/中电科SAR/SAR.json'
 # Folder that receives one YOLO .txt label file per JSON file.
 txt_folder_path = '/home/lyr/中电科SAR/txt'

 # Create the output folder if needed; no error if it already exists.
 os.makedirs(txt_folder_path, exist_ok=True)

 # Only pick up .json files: any other file in the folder (e.g. .DS_Store)
 # would make json.load fail. Sorting gives a reproducible processing order.
 json_filenames = sorted(
     name for name in os.listdir(json_folder_path)
     if name.lower().endswith('.json')
 )
    
  
    
 def convert(img_size, box):
     """Convert a two-corner box into a normalized YOLO box.

     Args:
         img_size: ``(image_width, image_height)`` in pixels.
         box: ``(x1, y1, x2, y2)`` -- two opposite corners of the rectangle.
             Any corner order is accepted (this dataset lists the
             bottom-left corner first and the top-right corner second).

     Returns:
         ``(x_center, y_center, width, height)``; x values are divided by the
         image width and y values by the image height.
     """
     img_w, img_h = img_size[0], img_size[1]
     x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

     # The centre is the midpoint of the two corners, whatever their order.
     x_center = (x1 + x2) / 2.0
     y_center = (y1 + y2) / 2.0

     # abs() keeps width/height positive for any corner order; for the
     # bottom-left -> top-right order these equal x2 - x1 and y1 - y2.
     width = abs(x2 - x1)
     height = abs(y1 - y2)

     return (x_center / img_w, y_center / img_h, width / img_w, height / img_h)
    
  
    
  
    
 def decode_json(json_floder_path, txt_folder_path, json_name, class_list):
     """Convert one rectangle-annotation JSON file into a YOLO .txt label file.

     The output file has the same stem as ``json_name`` and is written to
     ``txt_folder_path``. Each entry of ``data['shapes']`` becomes one line
     ``<class_id> <x_center> <y_center> <width> <height>`` (normalized by
     ``convert``).

     Args:
         json_floder_path: Folder containing the JSON file (the misspelled
             parameter name is kept so existing callers keep working).
         txt_folder_path: Folder that receives the .txt file.
         json_name: File name of the JSON annotation, e.g. ``'abc.json'``.
         class_list: Ordered class names; a label's index in this list is
             written as its class id.

     Raises:
         ValueError: If a shape's ``label`` is not in ``class_list``.
     """
     stem = os.path.splitext(json_name)[0]
     txt_file_path = os.path.join(txt_folder_path, stem + '.txt')
     json_path = os.path.join(json_floder_path, json_name)

     # Parse the annotation first and close it immediately, so a malformed
     # JSON file does not leave an empty .txt file behind.
     with open(json_path, 'r', encoding='utf-8') as json_file:
         data = json.load(json_file)

     img_size = (data['imageWidth'], data['imageHeight'])

     # `with` guarantees the label file is flushed and closed, even if an
     # unknown label raises part-way through.
     with open(txt_file_path, 'w') as txt_file:
         for shape in data['shapes']:
             points = shape['points']
             # Two opposite corners of the rectangle.
             bb = (float(points[0][0]), float(points[0][1]),
                   float(points[1][0]), float(points[1][1]))
             bbox = convert(img_size, bb)
             obj_num = str(class_list.index(shape['label']))
             txt_file.write(obj_num + " " + " ".join(str(v) for v in bbox) + '\n')
    
  
    
  
    
 # Class names in id order: a label's position in this list is its YOLO class id.
 class_list = [
     'name1', 'name2', 'name3',
     'name4', 'name5', 'name6',
 ]

 # Convert every collected annotation file into a YOLO label file.
 for label_file_name in tqdm(json_filenames):
     decode_json(json_folder_path, txt_folder_path, label_file_name, class_list)

也是做过这个工作之后才认识到,不同数据集的 json 文件差别很大。进行坐标转换时,一定要仔细检查 convert 函数的计算过程。

3. xml格式标签转换txt格式

这里直接给出一个 VOC 格式的 MSTAR 数据集 xml 标签示例:

复制代码
 <annotation>

    
 	<folder>VOC2007</folder>
    
 	<filename>000002.jpg</filename>
    
 	<source>
    
 		<database>The VOC2007 Database</database>
    
 		<annotation>PASCAL VOC2007</annotation>
    
 		<image>flickr</image>
    
 		<flickrid>NULL</flickrid>
    
 	</source>
    
 	<owner>
    
 		<flickrid>NULL</flickrid>
    
 		<name>hunterlew</name>
    
 	</owner>
    
 	<size>
    
 		<width>491</width>
    
 		<height>594</height>
    
 		<depth>1</depth>
    
 	</size>
    
 	<segmented>0</segmented>
    
 	<object>
    
 		<name>3</name>
    
 		<pose>Unspecified</pose>
    
 		<truncated>0</truncated>
    
 		<difficult>0</difficult>
    
 		<bndbox>
    
 			<xmin>393</xmin>
    
 			<ymin>140</ymin>
    
 			<xmax>474</xmax>
    
 			<ymax>220</ymax>
    
 		</bndbox>
    
 	</object>
    
 	<object>
    
 		<name>0</name>
    
 		<pose>Unspecified</pose>
    
 		<truncated>0</truncated>
    
 		<difficult>0</difficult>
    
 		<bndbox>
    
 			<xmin>149</xmin>
    
 			<ymin>377</ymin>
    
 			<xmax>242</xmax>
    
 			<ymax>471</ymax>
    
 		</bndbox>
    
 	</object>
    
 </annotation>

借助这个工作机会,可以顺便学习一些 XML 文件的操作。处理 XML 文件时,需要先正确解析文件并获取根节点:

复制代码
 import xml.etree.ElementTree as ET

    
 # Minimal parsing example: xml_file_path is a placeholder for the path of one
 # .xml annotation file; root is its top-level element (<annotation> in VOC labels).
 tree = ET.parse(xml_file_path)
 root = tree.getroot()
复制代码
 import xml.etree.ElementTree as ET

    
 import os
    
 from tqdm import tqdm
    
  
    
 # Convert VOC-style .xml labels into YOLO .txt labels (one .txt per .xml).
 xml_folder_path = 'path_to_xml_folder'  # folder of .xml label files
 txt_folder_path = 'path_to_txt_folder'  # folder receiving the .txt labels
 xmls = os.listdir(xml_folder_path)
 os.makedirs(txt_folder_path, exist_ok=True)

 for xml_name in tqdm(xmls):
     prefix, suffix = os.path.splitext(xml_name)
     # Skip anything that is not an XML label; ET.parse would raise on it.
     if suffix.lower() != '.xml':
         continue

     root = ET.parse(os.path.join(xml_folder_path, xml_name)).getroot()

     # The image size belongs to the file, not to each object: read it once.
     # float() also accepts sizes written as e.g. "491.0".
     image_width = float(root.find('size/width').text)
     image_height = float(root.find('size/height').text)

     txt_file_path = os.path.join(txt_folder_path, prefix + '.txt')
     # `with` guarantees the label file is flushed and closed.
     with open(txt_file_path, 'w') as txt_file:
         for obj in root.findall('object'):
             # <name> is written verbatim as the class id, so it must already
             # be numeric (as in the MSTAR labels).
             clsid = obj.find('name').text.strip()
             xmin = float(obj.find('bndbox/xmin').text)
             ymin = float(obj.find('bndbox/ymin').text)
             xmax = float(obj.find('bndbox/xmax').text)
             ymax = float(obj.find('bndbox/ymax').text)

             # VOC corners (pixels) -> normalized YOLO centre/size.
             x_center = (xmin + xmax) / 2 / image_width
             y_center = (ymin + ymax) / 2 / image_height
             width = (xmax - xmin) / image_width
             height = (ymax - ymin) / image_height
             bbox = (x_center, y_center, width, height)
             txt_file.write(clsid + " " + " ".join(str(v) for v in bbox) + '\n')

4. SARdet-100k数据集的图像及标签处理

1.图像处理

SARdet-100k 数据集中的图像有 png、jpg、bmp 三种格式,下面是将它们统一转换为 png 格式的代码:

复制代码
 from PIL import Image

    
 import os
    
 from tqdm import tqdm
    
 # Source folder holding the mixed-format (png/jpg/bmp) images.
 original_folder_path = 'dataset_path'
 # Destination folder for the PNG copies.
 new_folder_path = 'output_path'

 # Create the destination folder on first run.
 if not os.path.exists(new_folder_path):
     os.makedirs(new_folder_path)

 # File extensions treated as images (matched case-insensitively below).
 extensions = ['.jpg', '.jpeg', '.png', '.bmp']
 suffixes = tuple(extensions)  # str.endswith accepts a tuple of candidates

 for filename in tqdm(os.listdir(original_folder_path)):
     if not filename.lower().endswith(suffixes):
         continue  # not an image format we handle

     stem = os.path.splitext(filename)[0]
     src_path = os.path.join(original_folder_path, filename)
     # NOTE(review): files sharing a stem (e.g. a.jpg and a.png) map to the
     # same output name; the one processed later overwrites the earlier.
     dst_path = os.path.join(new_folder_path, stem + '.png')

     # Re-encode whatever PIL opened as PNG.
     with Image.open(src_path) as img:
         img.save(dst_path, 'PNG')

 print('All supported images have been converted to PNG format in the new folder.')

2. 标签处理

此数据集的标签为 COCO 风格的 json 格式,整个数据集的全部标注信息都存储在同一个 json 文件中(bbox 为 [x_min, y_min, width, height])。

转为yolo格式标签的代码

复制代码
 import json

    
 import os
    
 from tqdm import tqdm
    
 # Input/output locations for the SARdet-100k label conversion.
 """用于SARdet100K数据集标签批量转换"""
 # Path of the single JSON file holding all images and annotations.
 json_path = 'json_file_path'
 # Folder that receives one YOLO .txt label file per image.
 txt_folder_path = 'output_path'
    
  
    
 def convert(img_size, box):
     """Turn a COCO-style ``[x_min, y_min, w, h]`` box into a normalized YOLO box.

     Args:
         img_size: ``(image_width, image_height)`` in pixels.
         box: ``[x_min, y_min, width, height]`` in pixels.

     Returns:
         ``(x_center, y_center, width, height)``; x values divided by the image
         width and y values by the image height.
     """
     img_w = img_size[0]
     img_h = img_size[1]
     left, top, box_w, box_h = box[0], box[1], box[2], box[3]

     # Shift the top-left corner by half the box size to get the centre.
     return (
         (left + box_w / 2.0) / img_w,
         (top + box_h / 2.0) / img_h,
         box_w / img_w,
         box_h / img_h,
     )
    
  
    
 with open(json_path, 'r', encoding='utf-8') as f:
     data = json.load(f)

 data_items = data['images']
 data_bbox_name = data['annotations']

 # Make sure the output folder exists before writing label files.
 os.makedirs(txt_folder_path, exist_ok=True)

 # Index annotations by image id in one pass. Scanning the whole annotation
 # list for every image is O(images x annotations), which is very slow at this
 # dataset's size. Insertion order is kept, so each file's lines come out in
 # the same order as in the JSON.
 anns_by_image = {}
 for ann in data_bbox_name:
     anns_by_image.setdefault(ann['image_id'], []).append(ann)

 for item in tqdm(data_items):
     prefix, suffix = os.path.splitext(item['file_name'])
     txt_file_path = os.path.join(txt_folder_path, prefix + '.txt')
     img_size = (item['width'], item['height'])
     # Images without annotations still get an (empty) label file.
     with open(txt_file_path, 'w') as txt_file:
         for ann in anns_by_image.get(item['id'], []):
             bbox = convert(img_size, ann['bbox'])  # [x, y, w, h] -> YOLO
             # NOTE(review): category_id is written as-is; YOLO expects 0-based
             # contiguous ids -- confirm against the dataset's category list.
             txt_file.write(str(ann['category_id']) + " " + " ".join(str(v) for v in bbox) + '\n')

5. yolo格式标签处理(删除特定id标签信息)

我的工作需要将多个数据集整合在一起,并且标签的一致性至关重要。在此过程中,应删除一些在训练过程中没有用处的目标类别信息。

复制代码
 import os

    
  
    
 def remove_specific_classes(label_file, output_file, classes_to_remove):
     """Copy a YOLO label file, dropping lines whose class id is in ``classes_to_remove``.

     Blank and whitespace-only lines are dropped too; every kept line is
     written unchanged, including its original line ending.

     Args:
         label_file: Path of the source YOLO .txt label file.
         output_file: Path to write the filtered labels to. It may equal
             ``label_file``, because the source is read completely first.
         classes_to_remove: Container of integer class ids to delete.

     Raises:
         ValueError: If a non-blank line does not start with an integer class id.
     """
     with open(label_file, 'r') as src:
         all_lines = src.readlines()

     with open(output_file, 'w') as dst:
         for raw in all_lines:
             tokens = raw.split()
             if not tokens:
                 continue  # blank / whitespace-only line
             if int(tokens[0]) in classes_to_remove:
                 continue  # unwanted class
             dst.write(raw)
    
 # Usage example: filter every YOLO label file in a folder into another folder.
 txt_folder_path = 'your_txt_folder_path'  # folder of YOLO .txt labels to filter
 output_file = 'output_paht'  # output *folder* for the filtered label files
 # Class ids to delete -- replace with your own list of ints, e.g. [3, 7].
 classes_to_remove = []

 os.makedirs(output_file, exist_ok=True)

 # A set gives O(1) membership tests for every label line.
 remove_ids = set(classes_to_remove)

 txt_files = os.listdir(txt_folder_path)
 for txt in txt_files:
     # Only .txt files are label files. NOTE: a class-name list such as
     # classes.txt in this folder would fail the int() parse -- keep it elsewhere.
     if not txt.endswith('.txt'):
         continue
     txt_file_path = os.path.join(txt_folder_path, txt)  # source label file
     txt_correct_path = os.path.join(output_file, txt)  # filtered output
     remove_specific_classes(txt_file_path, txt_correct_path, remove_ids)

全部评论 (0)

还没有任何评论哟~