Tearing ultralytics Apart: Training YOLO Models with Lightning Instead

This article explains in detail how to move away from the ultralytics library's model.train() function and train YOLO models with a standard PyTorch training loop instead. It first lays out the limitations of training directly through ultralytics, then shows how to plug a YOLO model into a standard PyTorch training pipeline via a custom Dataset and DataLoader plus a thin wrapper around the model. It also covers a weight-performance issue encountered with torchvision data augmentation, stressing the importance of using ultralytics' own augmentations, and introduces the Lightning library to simplify the training code. The end goal is more flexible, more controllable model training that can accommodate heavier customization.

🎯 **Custom data preparation**: For more flexible training, you need to build a Dataset that matches ultralytics' expectations. Each sample is a dict containing keys such as 'img', 'bboxes' (xywh format), 'cls', 'bbox_format', 'normalized', 'ori_shape', and 'ratio_pad', and the Dataset implements `__init__` (including the augmentation logic), `__len__`, and `__getitem__`. The article stresses using ultralytics' own data augmentations rather than torchvision's, to avoid degrading the performance of pretrained weights.

📦 **DataLoader setup**: The article provides a `collate_fn()` copied from the ultralytics library that correctly merges samples produced by the custom Dataset into batches for PyTorch's DataLoader. Instantiating the DataLoader with this custom `collate_fn` ensures the data reaches the model in the right format.

🛠️ **Model wrapping and training flow**: A `YOLOModule` class inheriting from ultralytics.nn.tasks.DetectionModel wraps the YOLO model so that pretrained weights can be loaded conveniently, and defines the core methods `forward`, `get_loss`, and `get_bboxes`. A `LightningModel` built on Lightning's `LightningModule` then plugs the wrapped YOLO model into the Lightning training framework, cleanly separating the training and validation steps and spelling out how the final loss is obtained.

⚡ **Simplified training with Lightning**: The Lightning framework is introduced to simplify writing PyTorch training code. The `training_step` and `validation_step` methods of `LightningModel` cleanly encapsulate the data handling, forward pass, loss computation, and logging for training and validation, making the whole training flow easier to manage and understand.

YOLO's status as a pinnacle of object detection needs no elaboration: it is fast and easy to use. The usual way to run and train YOLO models is through a Python library called ultralytics. The library is very well written, and a single function call kicks off model training.

```python
from ultralytics import YOLO

model = YOLO("yolo12l.pt")
results = model.train(
    data="/mnt/sda/data/20250312_SARDet100K/sar100k.yaml",
    epochs=100,
    imgsz=640,
)
```

But once you have heavier customization needs, this one-click training is no longer enough. If the training code could instead be written in the standard PyTorch form below, adding custom modifications would be much easier.

```python
train_loader = DataLoader(train_dataset, batch_size=..., shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=...)

for epoch in range(...):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            val_loss = criterion(pred, y)
```

After several weeks of wrestling with it, I have mostly torn ultralytics apart and broken free of model.train(). I can now train object detectors freely.

Overview

The internals of the ultralytics library are written very tightly, so rewriting it completely would be quite difficult. A more realistic approach is to borrow from and subclass some of the original library's classes and methods, and to feed it data in the shapes ultralytics expects.

Along the way I hit a strange problem: torchvision's data augmentations damage the performance of the pretrained YOLO weights, so you must use ultralytics' own augmentations. Even after carefully controlling variables and picking only augmentations that exist in both libraries, with no difference in the images or bounding boxes visible to the naked eye, every experiment came out the same. So be it.

Lightning is a library that helps you write PyTorch training code; it wraps things like the training loop into a handful of methods, which makes the code much easier both to write and to read. If you have never touched Lightning, don't worry: the function names in the code below make the logic clear.

The code in this article is simplified as far as possible and is mainly meant as an example.

Data Preparation

Dataset

We need to build a Dataset whose samples match ultralytics' conventions. Each sample is a dict containing these keys: 'img', 'bboxes' (xywh format), 'cls', 'bbox_format', 'normalized', 'ori_shape', and 'ratio_pad'.

See the code for the concrete implementation.

```python
from typing import Dict

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from ultralytics.data.augment import (
    Compose,
    Format,
    LetterBox,
    RandomPerspective,
    RandomHSV,
    RandomFlip,
)
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.instance import Instances


class MyDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        # Augmentation pipeline built entirely from ultralytics' own transforms
        # (torchvision transforms degraded the pretrained weights).
        pre_transform = RandomPerspective(
            degrees=0.0,
            translate=0.0,
            scale=0.5,
            shear=0.0,
            perspective=0.0,
            pre_transform=LetterBox(new_shape=(512, 512), scaleup=False),
        )
        self.transforms = Compose(
            [
                pre_transform,
                RandomHSV(hgain=0.015, sgain=0.7, vgain=0.4),
                RandomFlip(direction="vertical", p=0.0),
                RandomFlip(direction="horizontal", p=0.5),
            ]
        )
        # Format converts the sample into the tensor layout ultralytics expects.
        self.transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=False,
                return_keypoint=False,
                return_obb=False,
                batch_idx=True,
                mask_ratio=4,
                mask_overlap=True,
                bgr=0.0,
            )
        )

    def __len__(self):
        return len(self.dataset)

    def update_labels_info(self, label: Dict) -> Dict:
        """
        Update label format for different tasks.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 1000
        if len(segments) > 0:
            # make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        bboxes = bboxes if bboxes.size else np.zeros((0, 4), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    def __getitem__(self, idx):
        image_path, annotations = self.dataset[idx]
        with Image.open(image_path) as img:
            this_img = img.convert("RGB")
        # PIL's .size is (width, height); ultralytics expects ori_shape as (height, width).
        original_size = this_img.size[::-1]

        boxes = []
        classes = []
        for one_box in annotations:
            bbox = one_box["bbox"]
            category_id = one_box["category_id"]
            x, y, w, h = bbox
            boxes.append([x, y, w, h])
            classes.append([category_id])
        bboxes = np.array(boxes, dtype=np.float32)
        cls = np.array(classes, dtype=np.float32)

        label = {
            'img': np.array(this_img),
            'bboxes': bboxes,
            'cls': cls,
            'bbox_format': 'xywh',
            'normalized': True,
            'ori_shape': original_size,
            'ratio_pad': None,
        }
        label = self.update_labels_info(label)
        label = self.transforms(label)
        label["img"] = label["img"] / 255.0  # scale pixels to [0, 1]
        return label

DataLoader

The `collate_fn()` below is copied from ultralytics; it will later be passed to the DataLoader in place of the default collator.

```python
from typing import Dict, List

import torch


def collate_fn(batch: List[Dict]) -> Dict:
    """
    Collate data samples into batches.

    Args:
        batch (List[dict]): List of dictionaries containing sample data.

    Returns:
        (dict): Collated batch with stacked tensors.
    """
    new_batch = {}
    batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
    keys = batch[0].keys()
    values = list(zip(*[list(b.values()) for b in batch]))
    for i, k in enumerate(keys):
        value = values[i]
        if k in {"img", "text_feats"}:
            value = torch.stack(value, 0)
        elif k == "visuals":
            value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
        if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
            value = torch.cat(value, 0)
        new_batch[k] = value
    new_batch["batch_idx"] = list(new_batch["batch_idx"])
    for i in range(len(new_batch["batch_idx"])):
        new_batch["batch_idx"][i] += i  # add target image index for build_targets()
    new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
    return new_batch
```

Instantiate the DataLoader.

```python
from torch.utils.data import DataLoader

train_dataset = MyDataset(train_dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)
```
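
Before training it is worth sanity-checking one batch. A quick sketch (not from the original article): images are stacked along the batch dimension, while the boxes of all images are concatenated, and `batch_idx` records which image each box belongs to.

```python
# Hypothetical sanity check: inspect one collated batch.
batch = next(iter(train_loader))
print(batch["img"].shape)       # e.g. torch.Size([16, 3, 512, 512])
print(batch["bboxes"].shape)    # (num_boxes_in_batch, 4), all images concatenated
print(batch["cls"].shape)       # (num_boxes_in_batch, 1)
print(batch["batch_idx"][:10])  # index of the image each box belongs to
```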

Model Definition

This step wraps a thin shell around the original YOLO model to make it easier to use later.

```python
from types import SimpleNamespace

from ultralytics import YOLO
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import ops


class YOLOModule(DetectionModel):
    def __init__(self, num_class, channels, model="yolo11n.pt", pretrained=False):
        model = YOLO(model)
        cfg = model.yaml
        args = model.args
        # Loss gains, same defaults as ultralytics
        args.update(
            {
                "box": 7.5,
                "cls": 0.5,
                "dfl": 1.5,
            }
        )
        self.args = SimpleNamespace(**args)
        self.overrides = args
        super().__init__(cfg, nc=num_class, ch=channels, verbose=False)
        if pretrained:
            self.load(model.model)

    def forward(self, x):
        preds = self.predict(x)
        return preds

    def get_loss(self, batch, preds):
        return self.loss(batch, preds)[0]

    def get_bboxes(self, preds):
        preds = ops.non_max_suppression(
            preds,
            conf_thres=0.25,
            iou_thres=0.7,
            max_det=300,
            return_idxs=False,
        )
        return preds
```
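
A minimal usage sketch (the argument values here, such as `num_class=6`, are illustrative assumptions): load pretrained weights and run a dummy forward pass in eval mode.

```python
import torch

# Hypothetical instantiation; num_class and input size are placeholders.
module = YOLOModule(num_class=6, channels=3, model="yolo11n.pt", pretrained=True)

# In eval mode the detection head returns (inference_out, raw_feats);
# get_bboxes() runs NMS on the inference output.
module.eval()
with torch.no_grad():
    dummy = torch.zeros(1, 3, 512, 512)
    inference_out, loss_out = module(dummy)
    detections = module.get_bboxes(inference_out)  # one (n, 6) tensor per image
```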

Training Code / Lightning Module Definition

In the code below, focus on the logic of `training_step()` and `validation_step()` and on how the final loss is obtained (Lightning takes care of calling `loss.backward()` and related functions).

```python
import lightning as L
import torch


class LightningModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch['img']
        batch_size = x.shape[0]
        # In train mode the model returns the raw head outputs used by the loss.
        loss_out = self(x)
        loss = self.model.get_loss(
            batch=batch,
            preds=loss_out,
        )
        box_loss, cls_loss, dfl_loss = loss / batch_size
        loss = box_loss + cls_loss + dfl_loss
        self.log('train/loss', loss, on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch['img']
        batch_size = x.shape[0]
        # In eval mode the model returns (inference_out, loss_out).
        inference_out, loss_out = self(x)
        loss = self.model.get_loss(
            batch=batch,
            preds=loss_out,
        )
        box_loss, cls_loss, dfl_loss = loss / batch_size
        loss = box_loss + cls_loss + dfl_loss
        self.log('val/loss', loss, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)
        return loss

    def configure_optimizers(self):
        # Lightning requires an optimizer; AdamW with a placeholder lr as an example.
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
```
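
To actually launch training, hand the module and the dataloaders to a Lightning `Trainer`. A minimal sketch, assuming a `val_loader` built the same way as `train_loader`; the Trainer arguments and `num_class` are placeholders, not from the original article.

```python
import lightning as L

yolo = YOLOModule(num_class=6, channels=3, model="yolo11n.pt", pretrained=True)
lit_model = LightningModel(yolo)

trainer = L.Trainer(
    max_epochs=100,       # placeholder values
    accelerator="auto",
    devices=1,
    log_every_n_steps=10,
)
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
```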
