YOLOv12 多进程推理

import cv2
import torch
import numpy as np
from ultralytics import YOLO
import argparse
import os
import os.path as osp
import multiprocessing
import simplepyutils as spu
import simplepyutils.argparse as spu_argparse
from simplepyutils.argparse import FLAGS
from collections import defaultdict
import time

def rot_detections_back(det, rotated_imshape, rot=None):
    """Map boxes detected in a rotated image back to the original image frame.

    The arithmetic below inverts a counter-clockwise rotation for ``rot == 90``
    (``np.rot90`` / ``cv2.ROTATE_90_COUNTERCLOCKWISE``) and a clockwise one for
    ``rot == 270``.

    Args:
        det: Array of shape ``(N, K)`` with ``K >= 4``. The first four columns
            are ``x, y, w, h`` (top-left corner plus size) in the rotated image.
            Any further columns (e.g. confidence, class) are carried through
            unchanged.
        rotated_imshape: Shape of the rotated image the boxes were detected in.
            Only the first two entries ``(H, W)`` are used.
        rot: Rotation in degrees that was applied to the original image.
            Defaults to ``FLAGS.rot``. It is taken modulo 360, and 0 means no
            rotation.

    Returns:
        ``float32`` array of shape ``(N, K)``: the boxes in original-image
        coordinates followed by the untouched extra columns.

    Raises:
        ValueError: If ``rot`` is not a multiple of 90 degrees.
    """
    if rot is None:
        rot = FLAGS.rot
    det = np.asarray(det)
    H, W = rotated_imshape[:2]
    x, y, w, h = det[..., 0], det[..., 1], det[..., 2], det[..., 3]

    angle = rot % 360
    if angle == 0:
        new_x, new_y, new_w, new_h = x, y, w, h
    elif angle == 90:
        new_x = H - (y + h)
        new_y = x
        new_w = h
        new_h = w
    elif angle == 180:
        new_x = W - (x + w)
        new_y = H - (y + h)
        new_w = w
        new_h = h
    elif angle == 270:
        new_x = y
        new_y = W - (x + w)
        new_w = h
        new_h = w
    else:
        raise ValueError(f'Invalid rotation {rot}')

    rotated_back = np.stack([new_x, new_y, new_w, new_h], axis=-1)
    # Keep confidence/class (and any other trailing columns). The previous
    # version returned only the four geometry columns and lost them.
    rotated_back = np.concatenate([rotated_back, det[..., 4:]], axis=-1)
    return rotated_back.astype(np.float32)

def _result_to_detections(result, rotated_imshape):
    """Convert one ultralytics result into a ``[x, y, w, h, conf, cls]`` array.

    Args:
        result: One element of the list returned by calling the YOLO model.
        rotated_imshape: ``(H, W)`` of the image that was fed to the model.
            This is the rotated image when ``FLAGS.rot`` is set.

    Returns:
        Array with one ``[x, y, w, h, conf, cls]`` row per box whose confidence
        is strictly above ``FLAGS.threshold``. ``x, y`` is the top-left corner.
        Returns an empty ``(0, 6)`` float32 array when nothing survives. When
        ``FLAGS.rot`` is set, the array is passed through
        ``rot_detections_back`` before being returned.
    """
    empty = np.zeros((0, 6), dtype=np.float32)
    if result is None or len(result) == 0:
        return empty

    boxes = result.boxes
    keep = boxes.conf > FLAGS.threshold
    if not bool(keep.any()):
        return empty

    # xyxy corners -> top-left corner + width/height.
    x1, y1, x2, y2 = boxes.xyxy[keep].unbind(1)
    det_array = torch.stack(
        [x1, y1, x2 - x1, y2 - y1, boxes.conf[keep], boxes.cls[keep]], dim=1
    ).cpu().numpy()

    if FLAGS.rot:
        det_array = rot_detections_back(det_array, rotated_imshape)
    return det_array


def process_subdir(args):
    """Run YOLO detection on every ``.jpg`` under one directory and pickle the results.

    This is meant to be a ``multiprocessing.Pool`` worker, hence the single
    packed argument.

    Args:
        args: ``((subdir_path, output_path), gpu_id)``. ``subdir_path`` is the
            directory scanned recursively for ``.jpg`` files, ``output_path``
            is the pickle file to write, and ``gpu_id`` is the CUDA device
            index to run on.

    Returns:
        ``output_path``, after ``spu.dump_pickle`` has written a dict that maps
        each processed image path to its detection array (see
        ``_result_to_detections``). Unreadable or tiny images are skipped.
        Per-image errors are printed, not raised.

    Raises:
        ValueError: If ``FLAGS.rot`` is set to something other than 90, 180 or 270.
    """
    (subdir_path, output_path), gpu_id = args
    device = f'cuda:{gpu_id}'

    # Maps FLAGS.rot to the cv2 rotation that rot_detections_back inverts
    # (90 = counter-clockwise, 270 = clockwise).
    rotate_codes = {
        90: cv2.ROTATE_90_COUNTERCLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_CLOCKWISE,
    }
    if FLAGS.rot and FLAGS.rot not in rotate_codes:
        raise ValueError(f'Invalid rotation {FLAGS.rot}')

    # Deliberately no detector.eval(). Ultralytics' Model overrides train(),
    # and nn.Module.eval() is implemented as self.train(False), so it would
    # call the training entry point. Prediction switches to inference itself.
    detector = YOLO(FLAGS.model_path, verbose=False).to(device)

    # Keys keep the original "<dir with forward slashes>/<name>" format.
    img_files = sorted(
        '%s/%s' % (root.replace("\\", "/"), name)
        for root, _, names in os.walk(subdir_path)
        for name in names
        if name.lower().endswith('.jpg')
    )

    subdir_results = {}
    for img_file in img_files:
        try:
            img = cv2.imread(img_file)
            # img.size counts array elements (H*W*C), not bytes on disk.
            if img is None or img.size < 100:
                continue
            if FLAGS.rot:
                # rot_detections_back expects boxes found in the rotated image.
                img = cv2.rotate(img, rotate_codes[FLAGS.rot])

            start_time = time.time()
            with torch.no_grad():
                # Ultralytics treats numpy input as BGR (what cv2.imread returns),
                # so the image is passed as-is. conf= makes the predictor's own
                # confidence cut equal FLAGS.threshold; otherwise its default
                # would silently drop lower-scoring boxes. device= pins the
                # predictor to this worker's GPU instead of its default device.
                outputs = detector(img, conf=FLAGS.threshold, device=device, verbose=False)

            subdir_results[img_file] = _result_to_detections(outputs[0], img.shape[:2])
            print(f'Processed {img_file} in {time.time()-start_time:.2f}s')

        except Exception as e:  # best effort: one bad image must not kill the worker
            print(f"Error in {img_file}: {str(e)}")

    # Save this subdirectory's results to its own pickle.
    print('save path', output_path)
    spu.dump_pickle(subdir_results, output_path)
    return output_path

def main():
    """Parse CLI flags, find extracted zip folders and run detection on them in parallel.

    For every ``*.zip`` under ``--image_root`` whose sibling directory (same
    path without ``.zip``) exists, one ``process_subdir`` task is queued. Its
    detections are pickled next to the zip as ``<zip path without .zip>.pkl``.
    Tasks are assigned to the GPUs in ``--gpu_ids`` round-robin.
    """
    root_dir = '/vepfs-cnbjc5abe70e7537/lbg/human_pose'
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_root', default=f"{root_dir}/arctic/images_zips")
    # Only created here; the .pkl files are written next to each zip, not into out_dir.
    parser.add_argument('--out_dir', default=f"{root_dir}/agora/detections")
    parser.add_argument('--model_path', default='/shared_disk/models/others/yolov12x.pt')
    parser.add_argument('--threshold', type=float, default=0.2)
    # Currently unused: process_subdir always scans for .jpg files.
    parser.add_argument('--file-pattern', default='**/*.png')
    parser.add_argument('--rot', type=int, default=0)
    parser.add_argument('--gpu_ids', default='3,5,6,7',
                        help='Comma-separated CUDA device indices to spread tasks over')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='Max concurrent worker processes (0 = one per task)')
    spu_argparse.initialize(parser)

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    gpu_ids = [int(g) for g in FLAGS.gpu_ids.split(',') if g.strip()]
    if not gpu_ids:
        raise ValueError('--gpu_ids must list at least one GPU index')

    zip_files = sorted(
        '%s/%s' % (root.replace("\\", "/"), name)
        for root, _, names in os.walk(FLAGS.image_root)
        for name in names
        if name.lower().endswith('.zip')
    )

    # Only zips that have already been extracted next to themselves are processed.
    task_args = []
    for zip_file in zip_files:
        subdir = zip_file[:-4]
        if not osp.exists(subdir):
            continue
        output_path = subdir + '.pkl'
        # Round-robin over queued tasks (not over all zips), so skipped
        # zips do not unbalance the GPUs.
        gpu_id = gpu_ids[len(task_args) % len(gpu_ids)]
        task_args.append(((subdir, output_path), gpu_id))

    if not task_args:
        # Pool(processes=0) would raise ValueError.
        print(f'No extracted zip directories found under {FLAGS.image_root}; nothing to do.')
        return

    # Each worker loads its own copy of the model, so allow capping concurrency.
    num_workers = FLAGS.num_workers if FLAGS.num_workers > 0 else len(task_args)
    num_workers = min(num_workers, len(task_args))
    # NOTE(review): workers presumably rely on the fork start method to inherit
    # the FLAGS set by spu_argparse.initialize; confirm before switching to spawn.
    with multiprocessing.Pool(processes=num_workers) as pool:
        # chunksize=1: every task is a whole directory, so hand them out one at a time.
        results = pool.map(process_subdir, task_args, chunksize=1)

    print(f'Processing complete. Wrote {len(results)} result file(s) next to their zips under: {FLAGS.image_root}')

if __name__ == "__main__":
    # Script entry point: parse the flags and dispatch the per-directory workers.
    main()

你可能感兴趣的:(python基础,YOLO)