SAM2 Image Segmentation Model: Complete ONNX Export and Deployment Workflow

Table of Contents

1. References

2. Model information

SAM 2 checkpoints

3. Environment setup

3.1 Installing facebookresearch/sam2

3.1.1. Prerequisites

3.1.2. Installation steps

Clone the repository

Install dependencies

3.1.3. Run a test

Basic inference example

3.2 Installing ONNX-SAM2-Segment-Anything

3.2.1. Prerequisites

3.2.2. Installation steps

(1) Clone the repository

(2) Install dependencies

4. ONNX export (switch to the sam2 directory)

Step 1: Download the model you want

Step 2: Write the export code

Notes:

Step 3: Operator simplification (very important, do not skip)

5. Running SAM2 on ONNX Runtime GPU


1. References

GitHub - ibaiGorordo/ONNX-SAM2-Segment-Anything: Python scripts for the Segment Anything 2 (SAM2) model in ONNX

SAM2 ONNX Export.ipynb - Colab

GitHub - facebookresearch/sam2: The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

2. Model information

SAM 2 checkpoints

The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:

Model                | Size (M) | Speed (FPS) | SA-V test (J&F) | MOSE val (J&F) | LVOS v2 (J&F)
sam2_hiera_tiny      | 38.9     | 91.5        | 75.0            | 70.9           | 75.3
sam2_hiera_small     | 46       | 85.6        | 74.9            | 71.5           | 76.4
sam2_hiera_base_plus | 80.8     | 64.8        | 74.7            | 72.8           | 75.8
sam2_hiera_large     | 224.4    | 39.7        | 76.0            | 74.6           | 79.8

(Config and checkpoint download links for each model are listed in the facebookresearch/sam2 README.)

3. Environment setup

Python 3.10 or 3.9 is recommended (newer versions are not well supported by the CUDA / TensorRT toolchain). It is best to create a fresh virtual environment.

Base environment

My environment is listed below for reference. Most errors come from mismatched library versions; you can use an AI assistant such as DeepSeek to recommend compatible versions of the remaining components for your specific setup, otherwise you may run into conflicts.

python==3.9
onnx==1.16.1
onnxsim==0.4.36
onnxruntime==1.18.0
torch==2.1.2
cuda==11.8
protobuf==5.27.3
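
A minimal sketch of installing these pinned versions with pip, assuming a CUDA 11.8 build of PyTorch; the torchvision pin and index URL are my assumptions, so adjust them to your own setup:

pip install onnx==1.16.1 onnxsim==0.4.36 onnxruntime==1.18.0 protobuf==5.27.3
# CUDA 11.8 wheel of torch 2.1.2; torchvision 0.16.2 is the matching release
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118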

3.1 Installing facebookresearch/sam2

Installing SAM2 (Segment Anything Model 2) on Windows takes a few steps, and dependency versions must be compatible. The following installation guide is based on collected references:

3.1.1. Prerequisites

  • Python version: Python 3.10 is recommended.

  • PyTorch version: avoid PyTorch 2.5 or later, which can raise a cuDNN Frontend error. torch>=2.3.1 and torchvision>=0.18.1 are recommended.

  • CUDA support: make sure compatible CUDA and cuDNN versions are installed (a quick check follows below).
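
A quick sanity check of the PyTorch/CUDA pairing, run inside the target virtual environment (plain PyTorch calls, nothing SAM2-specific):

import torch

print(torch.__version__)           # should satisfy the version constraint above
print(torch.version.cuda)          # CUDA version the wheel was built against
print(torch.cuda.is_available())   # True if the GPU and driver are visible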

3.1.2. Installation steps

Clone the repository
git clone https://github.com/facebookresearch/sam2.git
cd sam2

Alternatively, download the project manually from the GitCode mirror of facebookresearch/sam2.

Install dependencies
  • Install the dependencies with:

pip install -e .
pip3 install onnx onnxscript onnxsim onnxruntime
  • This installs all required Python packages and may take a while.

  • Extra dependencies

    • If you hit the error cannot import name 'initialize_config_module' from 'hydra', install hydra_core:

      pip install hydra_core

Manually download the model files such as sam2.1_hiera_large.pt.
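
One way to fetch the checkpoints, sketched below. The repository ships a checkpoints/download_ckpts.sh script; the direct URL follows the pattern used at the time of the SAM 2.1 release and is an assumption on my part, so check the README if it no longer resolves:

cd checkpoints
# Either grab everything with the bundled script ...
./download_ckpts.sh
# ... or download a single checkpoint directly (URL pattern assumed from the SAM 2.1 release)
wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
cd ..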

3.1.3. Run a test

Basic inference example
import cv2
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Load the model
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2 = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Generate masks (the generator expects an RGB image, so convert from OpenCV's BGR)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
image = cv2.imread("your_image.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image_rgb)
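
The generator returns a list of SAM-style mask records (each a dict with a boolean "segmentation" array plus metadata such as "area"); assuming that format, a minimal sketch for overlaying the largest mask on the original image:

import numpy as np

# Sort by area and paint the largest mask in red on the BGR image read above
masks = sorted(masks, key=lambda m: m["area"], reverse=True)
overlay = image.copy()
overlay[masks[0]["segmentation"]] = (0, 0, 255)
blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
cv2.imwrite("mask_overlay.jpg", blended)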

3.2 Installing ONNX-SAM2-Segment-Anything

3.2.1. Prerequisites

  • CUDA and cuDNN (GPU users):

    • Make sure the CUDA toolkit matching your GPU is installed (e.g. CUDA 11.8 or 12.x).

    • Download the corresponding cuDNN version and configure the environment variables.

  • ONNX Runtime: install the GPU or CPU build depending on your hardware (commands below).
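
The corresponding pip commands, with versions chosen to match the environment listed in section 3 (adjust if your CUDA version requires a different build):

# GPU build (needs the CUDA/cuDNN setup described above)
pip install onnxruntime-gpu==1.18.0
# or, CPU-only fallback
pip install onnxruntime==1.18.0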


3.2.2. Installation steps

(1) Clone the repository
git clone https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything.git
cd ONNX-SAM2-Segment-Anything
(2) Install dependencies
pip install -r requirements.txt
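
A quick check that the GPU execution provider is actually available after installation:

import onnxruntime as ort

print(ort.__version__)
# Should include 'CUDAExecutionProvider' when the GPU build and CUDA/cuDNN are set up correctly
print(ort.get_available_providers())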

4. ONNX export (switch to the sam2 directory)

Step 1: Download the model you want

The checkpoints can be downloaded from the facebookresearch/sam2 project page (or its GitCode mirror), as described in section 3.1.

Step 2: Write the export code

The complete code is:

from typing import Optional, Tuple, Any
import torch
from torch import nn
import torch.nn.functional as F


from sam2.modeling.sam2_base import SAM2Base

class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
        backbone_out = self.image_encoder(x)
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
                 for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

        return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):
    def __init__(
            self,
            sam_model: SAM2Base,
            multimask_output: bool
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output

    @torch.no_grad()
    def forward(
            self,
            image_embed: torch.Tensor,
            high_res_feats_0: torch.Tensor,
            high_res_feats_1: torch.Tensor,
            point_coords: torch.Tensor,
            point_labels: torch.Tensor,
            mask_input: torch.Tensor,
            has_mask_input: torch.Tensor,
            img_size: torch.Tensor
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        self.sparse_embedding = sparse_embedding
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        high_res_feats = [high_res_feats_0, high_res_feats_1]

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)

        masks = torch.clamp(masks, -32.0, 32.0)
        print(masks.shape, iou_predictions.shape)

        masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)

        return masks, iou_predictions

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:

        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (
                point_labels == -1
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
                1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding


model_type = 'sam2_hiera_l'  # options: sam2_hiera_tiny / sam2_hiera_small / sam2_hiera_base_plus / sam2_hiera_large
# input_size = 768  # other resolutions currently give bad outputs
input_size = 1024  # Bad output if anything else (for now)
multimask_output = False

if model_type == "sam2_hiera_tiny":
    model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
    model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
    model_cfg = "sam2_hiera_b+.yaml"
else:
    model_cfg = "sam2_hiera_l.yaml"


import torch
from sam2.build_sam import build_sam2

sam2_checkpoint = "sam2_hiera_large.pt"  # use the checkpoint that matches model_type
# The device here must be "cpu" for the export
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")

img=torch.randn(1, 3, input_size, input_size).cpu()

sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
print(high_res_feats_0.shape)
print(high_res_feats_1.shape)
print(image_embed.shape)

torch.onnx.export(sam2_encoder,
                  img,
                  f"{model_type}_encoder.onnx",
                  export_params=True,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names = ['image'],
                  output_names = ['high_res_feats_0', 'high_res_feats_1', 'image_embed']
                )

sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).cpu()

embed_dim = sam2_model.sam_prompt_encoder.embed_dim
embed_size = (sam2_model.image_size // sam2_model.backbone_stride, sam2_model.image_size // sam2_model.backbone_stride)
mask_input_size = [4 * x for x in embed_size]
print(embed_dim, embed_size, mask_input_size)

point_coords = torch.randint(low=0, high=input_size, size=(1, 10, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(1, 10), dtype=torch.float)
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([0], dtype=torch.float)
orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32)

masks, scores = sam2_decoder(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)

torch.onnx.export(sam2_decoder,
                  (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size),
                  "decoder.onnx",
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names = ['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size'],
                  output_names = ['masks', 'iou_predictions'],
                  dynamic_axes = {"point_coords": {0: "num_labels", 1: "num_points"},
                                  "point_labels": {0: "num_labels", 1: "num_points"},
                                  "mask_input": {0: "num_labels"},
                                  "has_mask_input": {0: "num_labels"}
                  }
                )
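
An optional parity check of the exported encoder before simplification, comparing ONNX Runtime outputs against the PyTorch outputs computed above (a sketch; the acceptable error margin is an assumption and varies by model and opset):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(f"{model_type}_encoder.onnx", providers=["CPUExecutionProvider"])
ort_outs = sess.run(None, {"image": img.numpy()})
for torch_out, ort_out in zip((high_res_feats_0, high_res_feats_1, image_embed), ort_outs):
    # Maximum absolute difference; values around 1e-4 or smaller are usually fine
    print(np.abs(torch_out.detach().numpy() - ort_out).max())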

Notes:

1. Model type and configuration

  • Select a different SAM2 variant at the top of the script (e.g. sam2_hiera_tiny, sam2_hiera_large):

    model_type = 'sam2_hiera_large'  # options: sam2_hiera_tiny/sam2_hiera_small/sam2_hiera_base_plus
    • The matching config file is loaded automatically (e.g. sam2_hiera_l.yaml); a small helper for keeping config and checkpoint consistent is sketched after this list.

    • Note: compute and GPU-memory requirements differ considerably between variants, so choose according to your hardware.

  • Input resolution
    The default input_size = 1024 cannot be changed for now (other resolutions may produce abnormal output).
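
A small helper for keeping model_type, config and checkpoint consistent. The filenames follow the July 2024 SAM 2 release; the mapping itself is my assumption, so adjust it if you use the sam2.1 checkpoints:

# Hypothetical convenience map; checkpoint filenames assumed from the July 2024 release
checkpoint_map = {
    "sam2_hiera_tiny": "sam2_hiera_tiny.pt",
    "sam2_hiera_small": "sam2_hiera_small.pt",
    "sam2_hiera_base_plus": "sam2_hiera_base_plus.pt",
    "sam2_hiera_large": "sam2_hiera_large.pt",
}
sam2_checkpoint = checkpoint_map.get(model_type, "sam2_hiera_large.pt")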

Step 3: Operator simplification (very important, do not skip)

Run from the command line:

onnxsim {model_type}_encoder.onnx {model_type}_encoder.onnx
onnxsim decoder.onnx decoder.onnx

Replace {model_type} with the actual model name.
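
After simplification it is worth confirming that the graphs are still valid and that the input/output names match what the runtime code in section 5 expects (standard onnx / onnxruntime calls; filenames follow the export example above):

import onnx
import onnxruntime as ort

for path in ("sam2_hiera_l_encoder.onnx", "decoder.onnx"):
    onnx.checker.check_model(onnx.load(path))  # raises if the simplified graph is invalid
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    print(path, [i.name for i in sess.get_inputs()], [o.name for o in sess.get_outputs()])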

5. Running SAM2 on ONNX Runtime GPU

Switch to the ONNX-SAM2-Segment-Anything project and run the test code below. It implements a SAM2-based image segmentation tool that segments a target region of an image from user-supplied point coordinates. The code wraps SAM2 model loading, image preprocessing, point-coordinate scaling, and post-processing of the segmentation result.



from sam2 import SAM2Image
import cv2
import numpy as np

# ONNX models exported and simplified in section 4
# (adjust the decoder filename if yours differs, e.g. decoder.onnx)
encoder_model_path = "sam2_hiera_l_encoder.onnx"
decoder_model_path = "decoder_fixed.onnx"
sam2_model = SAM2Image(encoder_model_path, decoder_model_path)


def scale_points(point_coords, original_size, new_size):
    # Width and height of the original image
    original_width, original_height = original_size
    # Width and height of the resized image
    new_width, new_height = new_size

    # Scaling factors
    scale_x = new_width / original_width
    scale_y = new_height / original_height

    # Scaled point coordinates
    scaled_points = []

    for points in point_coords:
        scaled_points.append(np.round(points * np.array([scale_x, scale_y])).astype(int))
    scaled_points = np.array(scaled_points)
    return scaled_points


class SamInfer():
    def __init__(self):

        # Original image size (width, height)
        self.original_size = (1920, 1080)
        # Model input size
        self.new_size = (1024, 1024)

    def __call__(self, image, box_points):
        if all(x == 0 and y == 0 for (x, y) in box_points):
            return None

        img_source = image.copy()
        img_source = cv2.resize(img_source, (1024, 1024))
        sam2_model.set_image(img_source)

        input_point = scale_points(box_points, self.original_size, self.new_size)
        input_point = [input_point]
        label_one = [1]*len(box_points)
        label_one = np.array(label_one)
        label_one = [label_one]


        point_coords = [arr.astype(np.float32) for arr in input_point]
        point_labels = [arr.astype(np.float32) for arr in label_one]

        for label_id, (point_coord, point_label) in enumerate(zip(point_coords, point_labels)):
            for i in range(point_label.shape[0]):
                sam2_model.add_point((point_coord[i][0], point_coord[i][1]), point_label[i], label_id)
            masks = sam2_model.get_masks()

        mask = masks[0]
        mask = (mask * 255).astype(np.uint8)
        mask = cv2.resize(mask, (1920, 1080))

        # Blend the mask with one channel of the original image for visualization
        imgO = (0.5 * mask + 0.5 * image[:, :, 0]).astype(np.uint8)

        return imgO

infer = SamInfer()

image = cv2.imread("example.jpg")

# Point coordinates inside the target region (example)
points = [(100, 200), (300, 400), (500, 600)]

# Run the segmentation
result = infer(image, points)

# Show the result
cv2.imshow("Result", result)
cv2.waitKey(0)
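
A small follow-up to persist the blended result and release the display window (result can be None when all prompt points are (0, 0), so guard the write):

if result is not None:
    cv2.imwrite("segmentation_result.jpg", result)
cv2.destroyAllWindows()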
