Table of Contents
1. References
2. Model information
SAM 2 checkpoints
3. Environment setup
3.1 Installing facebookresearch/sam2
3.1.1. Prerequisites
3.1.2. Installation steps
Clone the repository
Install dependencies
3.1.3. Run a test
Basic inference example
3.2 Installing ONNX-SAM2-Segment-Anything
3.2.1. Prerequisites
3.2.2. Installation steps
(1) Clone the repository
(2) Install dependencies
4. ONNX export (switch to the sam2 directory)
Step 1: Download the desired model
Step 2: Write the export code
Notes:
Step 3: Operator modification (very important, do not skip)
5. Running SAM2 on ONNX Runtime GPU
GitHub - ibaiGorordo/ONNX-SAM2-Segment-Anything: Python scripts for the Segment Anything 2 (SAM2) model in ONNX
SAM2 ONNX Export.ipynb - Colab
GitHub - facebookresearch/sam2: The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:
| Model | Size (M) | Speed (FPS) | SA-V test (J&F) | MOSE val (J&F) | LVOS v2 (J&F) |
|---|---|---|---|---|---|
| sam2_hiera_tiny (config, checkpoint) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
| sam2_hiera_small (config, checkpoint) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
| sam2_hiera_base_plus (config, checkpoint) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
| sam2_hiera_large (config, checkpoint) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
Python 3.10 or 3.9 is recommended (newer versions are not yet supported by CUDA/TensorRT). It is best to create a fresh virtual environment.
Base environment setup
My base environment is listed below for reference. Most errors come from mismatched library versions; you can use an AI tool such as DeepSeek to recommend compatible versions of the other components for your setup, otherwise they will conflict. A quick way to check the installed versions is shown after the list below.
python == 3.9
onnx == 1.16.1
onnxsim == 0.4.36
onnxruntime == 1.18.0
torch == 2.1.2
cuda == 11.8
protobuf == 5.27.3
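To quickly confirm what is actually installed in the active environment, a short check like the following helps (a minimal sketch; it only prints the versions of the packages listed above and whether CUDA is visible to PyTorch):
# Minimal environment check: print installed versions and CUDA visibility
import torch
import onnx
import onnxruntime
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("onnx:", onnx.__version__)
print("onnxruntime:", onnxruntime.__version__, "| providers:", onnxruntime.get_available_providers())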
Installing SAM2 (Segment Anything Model 2) on Windows requires following certain steps and paying attention to version compatibility of the dependencies. The guide below is based on search results:
Python version: Python 3.10 is recommended.
PyTorch version: avoid PyTorch 2.5 and above, otherwise you may hit a cuDNN Frontend error. Installing torch>=2.3.1 and torchvision>=0.18.1 is recommended (an example install command is given after this list).
CUDA support: make sure compatible CUDA and cuDNN versions are installed.
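For example, on a CUDA 11.8 machine the recommended PyTorch build can be installed from the official wheel index (an illustrative command; adjust the cu118 tag and the versions to your own CUDA setup):
pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118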
git clone https://github.com/facebookresearch/sam2.git
cd sam2
Alternatively, download the project manually from the GitCode/Gitee mirror of sam2.
Install the dependencies with the following commands:
pip install -e .
pip3 install onnx onnxscript onnxsim onnxruntime
This command installs all the necessary Python packages and may take a while.
Additional dependencies:
If you hit the error cannot import name 'initialize_config_module' from 'hydra', install hydra_core:
pip install hydra_core
Manually download the model files, e.g. sam2.1_hiera_large.pt.
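The official checkpoints can also be fetched with the download script that ships in the repository's checkpoints/ directory (assuming the default layout of the facebookresearch/sam2 repo):
cd checkpoints
./download_ckpts.sh
cd ..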
import cv2
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Load the model
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2 = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Generate masks (the mask generator expects an RGB image, so convert from OpenCV's BGR)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
image = cv2.cvtColor(cv2.imread("your_image.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
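SAM2AutomaticMaskGenerator returns a list of dicts in the same style as the original SAM automatic mask generator (each entry contains a boolean "segmentation" array, an "area", etc.); a quick way to eyeball the result is to overlay the largest mask on the image, as in this sketch:
# Overlay the largest generated mask on the image (sketch; assumes the SAM-style output format)
import numpy as np
largest = max(masks, key=lambda m: m["area"])
overlay = image.copy()
overlay[largest["segmentation"]] = (0, 255, 0)  # paint the mask region green
blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
cv2.imwrite("mask_overlay.jpg", cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))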
CUDA and cuDNN (GPU users):
Make sure the CUDA toolkit matching your GPU is installed (e.g., CUDA 11.8 or 12.x).
Download the corresponding cuDNN version and configure the environment variables.
ONNX Runtime: install the GPU or CPU build depending on your hardware (see the check below).
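For GPU inference, install the onnxruntime-gpu package instead of the CPU one:
pip install onnxruntime-gpu
Then confirm that the CUDA execution provider is actually available (a quick sanity check):
import onnxruntime as ort
# "CUDAExecutionProvider" should appear in this list if CUDA/cuDNN are configured correctly
print(ort.get_available_providers())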
git clone https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything.git
cd ONNX-SAM2-Segment-Anything
pip install -r requirements.txt
The complete export code is as follows:
from typing import Optional, Tuple, Any
import torch
from torch import nn
import torch.nn.functional as F
from sam2.modeling.sam2_base import SAM2Base
class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
        backbone_out = self.image_encoder(x)
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
                 for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

        return feats[0], feats[1], feats[2]
class SAM2ImageDecoder(nn.Module):
    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output

    @torch.no_grad()
    def forward(
        self,
        image_embed: torch.Tensor,
        high_res_feats_0: torch.Tensor,
        high_res_feats_1: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        img_size: torch.Tensor
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        self.sparse_embedding = sparse_embedding
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        high_res_feats = [high_res_feats_0, high_res_feats_1]
        image_embed = image_embed

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)

        masks = torch.clamp(masks, -32.0, 32.0)
        print(masks.shape, iou_predictions.shape)

        masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)

        return masks, iou_predictions

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding
model_type = 'sam2_hiera_l' #@param ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
# input_size = 768 #@param {type:"slider", min:160, max:4102, step:8}
input_size = 1024 # Bad output if anything else (for now)
multimask_output = False
if model_type == "sam2_hiera_tiny":
    model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
    model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
    model_cfg = "sam2_hiera_b+.yaml"
else:
    model_cfg = "sam2_hiera_l.yaml"
import torch
from sam2.build_sam import build_sam2
sam2_checkpoint = "sam2_hiera_large.pt"
# The device here must be "cpu" for the ONNX export
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
img=torch.randn(1, 3, input_size, input_size).cpu()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
print(high_res_feats_0.shape)
print(high_res_feats_1.shape)
print(image_embed.shape)
torch.onnx.export(sam2_encoder,
img,
f"{model_type}_encoder.onnx",
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names = ['image'],
output_names = ['high_res_feats_0', 'high_res_feats_1', 'image_embed']
)
sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).cpu()
embed_dim = sam2_model.sam_prompt_encoder.embed_dim
embed_size = (sam2_model.image_size // sam2_model.backbone_stride, sam2_model.image_size // sam2_model.backbone_stride)
mask_input_size = [4 * x for x in embed_size]
print(embed_dim, embed_size, mask_input_size)
point_coords = torch.randint(low=0, high=input_size, size=(1, 10, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(1, 10), dtype=torch.float)
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([0], dtype=torch.float)
orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32)
masks, scores = sam2_decoder(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)
torch.onnx.export(sam2_decoder,
(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size),
"decoder.onnx",
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names = ['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size'],
output_names = ['masks', 'iou_predictions'],
dynamic_axes = {"point_coords": {0: "num_labels", 1: "num_points"},
"point_labels": {0: "num_labels", 1: "num_points"},
"mask_input": {0: "num_labels"},
"has_mask_input": {0: "num_labels"}
}
)
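Before simplifying or using the exported files, it is worth a quick sanity check that the encoder ONNX model loads and runs; a minimal check with ONNX Runtime could look like this (a sketch; the file name and input name follow the export code above):
# Sanity-check the exported encoder with ONNX Runtime (sketch)
import numpy as np
import onnxruntime as ort
sess = ort.InferenceSession(f"{model_type}_encoder.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.randn(1, 3, input_size, input_size).astype(np.float32)
outputs = sess.run(None, {"image": dummy})
for name, out in zip(["high_res_feats_0", "high_res_feats_1", "image_embed"], outputs):
    print(name, out.shape)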
1. Model type and configuration
At the top of the code, choose the SAM2 variant you want (e.g., sam2_hiera_tiny, sam2_hiera_large):
model_type = 'sam2_hiera_large'  # options: sam2_hiera_tiny / sam2_hiera_small / sam2_hiera_base_plus
The corresponding config file is selected automatically (e.g., sam2_hiera_l.yaml).
Note: the compute and VRAM requirements differ considerably between variants, so choose one that fits your hardware.
Input resolution:
The default is input_size = 1024 and should not be changed for now (other resolutions may produce bad outputs).
Run the following from the command line:
onnxsim {model_type}_encoder.onnx {model_type}_encoder.onnx
onnxsim decoder.onnx decoder.onnx
Replace model_type with the actual model name (e.g., sam2_hiera_l).
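If the onnxsim command-line tool is not on your PATH, the same simplification can be done from Python through the library's simplify() API (a sketch; the file names match the export step above):
# Simplify the exported models via onnxsim's Python API (sketch)
import onnx
from onnxsim import simplify
for path in ["sam2_hiera_l_encoder.onnx", "decoder.onnx"]:
    model = onnx.load(path)
    model_simp, ok = simplify(model)
    assert ok, f"onnxsim could not validate {path}"
    onnx.save(model_simp, path)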
Switch to the sam2-onnx project (ONNX-SAM2-Segment-Anything) and run the test code below. It implements a SAM2 (Segment Anything Model 2) based image segmentation tool that segments a specific region of an image from user-supplied point coordinates, wrapping SAM2 model loading, image preprocessing, point coordinate conversion, and post-processing of the segmentation results.
from sam2 import SAM2Image
import cv2
import numpy as np
encoder_model_path = "sam2_hiera_l_encoder.onnx"
decoder_model_path = "decoder_fixed.onnx"
sam2_model = SAM2Image(encoder_model_path, decoder_model_path)
def scale_points(point_coords, original_size, new_size):
    # Width and height of the original image
    original_width, original_height = original_size
    # Width and height of the resized image
    new_width, new_height = new_size

    # Scaling factors
    scale_x = new_width / original_width
    scale_y = new_height / original_height

    # Scale the point coordinates
    scaled_points = []
    for points in point_coords:
        scaled_points.append(np.round(points * np.array([scale_x, scale_y])).astype(int))
    scaled_points = np.array(scaled_points)

    return scaled_points
class SamInfer():
    def __init__(self):
        # Size of the original image (width, height); adjust if your input is not 1920x1080
        self.original_size = (1920, 1080)
        # Size of the resized image fed to SAM2
        self.new_size = (1024, 1024)

    def __call__(self, image, box_points):
        if all(x == 0 and y == 0 for (x, y) in box_points):
            return None
        img_source = image.copy()
        img_source = cv2.resize(img_source, (1024, 1024))
        sam2_model.set_image(img_source)

        input_point = scale_points(box_points, self.original_size, self.new_size)
        input_point = [input_point]
        label_one = [1] * len(box_points)
        label_one = np.array(label_one)
        label_one = [label_one]

        point_coords = [arr.astype(np.float32) for arr in input_point]
        point_labels = [arr.astype(np.float32) for arr in label_one]

        for label_id, (point_coord, point_label) in enumerate(zip(point_coords, point_labels)):
            for i in range(point_label.shape[0]):
                sam2_model.add_point((point_coord[i][0], point_coord[i][1]), point_label[i], label_id)

        masks = sam2_model.get_masks()
        mask = masks[0]
        mask = mask * 255
        mask = mask.astype(np.uint8)
        mask = cv2.resize(mask, (1920, 1080))
        # Blend the mask with one image channel; cast to uint8 so cv2.imshow displays it correctly
        imgO = (0.5 * mask + 0.5 * image[:, :, 0]).astype(np.uint8)
        return imgO
infer = SamInfer()
image = cv2.imread("example.jpg")  # this example assumes a 1920x1080 image
# Point coordinates of the target region (example)
points = [(100, 200), (300, 400), (500, 600)]
# Run segmentation
result = infer(image, points)
# Show the result
cv2.imshow("Result", result)
cv2.waitKey(0)
cv2.destroyAllWindows()