yolo训练

发布时间:2024-10-31 07:43:17
修改时间:2025-04-29 15:28:15
总阅读数:68
今日阅读数:0
昨日阅读数:0
字数:3557

使用yolo实现自定义的目标检测

安装yolo依赖

pip3 install -U ultralytics 

新建数据集

目录结构:

├── images
│   ├── train
│   └── val
├── labels
│   ├── train
│   └── val
└── train.yaml

images/train 训练图片

images/train 验证图片

labels/train 训练标签

labels/val 验证标签

train.yaml

path: D:\yolo-test\dataset  # dataset root dir
train: images/train  # train images (relative to 'path')  
val: images/val  # val images (relative to 'path') 
test:  # test images (optional)

# Classes
names:
  0: label0
  1: label1

物品标注

下载标注工具,推荐使用 anylabeling 

https://github.com/vietanhdev/anylabeling 

格式转换

anylabeling标注后的格式是json格式,需要先转换成yolo的标签格式

json2txt.py

"""
将anyLabel生成的标签转为yolo的txt格式的标签
"""
import json
import os

import numpy as np
from ultralytics.utils.ops import xyxy2xywh


def obj2lines(obj):
    """
    将obj对象转为多行文本格式
    :param obj: obj对象
    :return:  标签多行格式
    """
    shapes = obj['shapes']
    lines = []

    height = obj['imageHeight']
    width = obj['imageWidth']

    for shape in shapes:
        label = shape['label']
        points_np = np.array(shape['points'])
        points_np = points_np / np.array([width, height])
        points_np = np.round(points_np, 7)
        points_np.shape = 1, points_np.shape[0] * points_np.shape[1]

        xywh = xyxy2xywh(points_np[0])

        line = typeDict[label] + " " + ' '.join(map(str, xywh))

        print(line)
        lines.append(line)
    return "\n".join(lines)


typeDict = {
    "label0": "0",
    "label1": "1",
}

folder_path = r"D:\yolo-test\dataset\labels\train"

if __name__ == '__main__':
    # 指定要读取的文件夹路径

    # 获取文件夹中所有文件的列表
    file_list = os.listdir(folder_path)

    # 遍历文件列表
    for file_name in file_list:
        if file_name.endswith(".json"):
            file_path = os.path.join(folder_path, file_name)
            with open(file_path, 'r') as file:
                # 读取文件内容
                data = file.read()
                # 解析为对象
                json_data = json.loads(data)

                # 这里可以对json_data进行操作
                print(json_data)
                par_res = obj2lines(json_data)
                print(par_res)

                file_path = folder_path + "\\" + file_name.replace(".json", ".txt")

                # 打开文件,写入模式为 'w',这会覆盖文件内容
                with open(file_path, 'w') as out_file:
                    out_file.write(par_res)

运行python代码进行格式转换

python json2txt.py

训练

train.py

import ultralytics
from ultralytics import YOLO

print("当前版本", ultralytics.__version__)

if __name__ == '__main__':
    model = YOLO('yolo11n.pt')

    model.train(data=r'D:\yolo-test\dataset\train.yaml', epochs=100,
                workers=2)

    print("训练完成")

epochs和workers的值,根据实际情况进行调整

预测

predict.py

import cv2
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo-test\runs\detect\train\weights\last.pt')

    # 使用YOLO进行检测
    results = model.predict(source=r"D:\yolo-test\abc.jpg")

    res_frame = results[0].plot()

    cv2.imshow("YOLO res", res_frame)

    cv2.waitKey(0)