SegmentAnything官网demo使用vue+python实现

一、效果&准备工作

1.效果

没啥好说的,低质量复刻SAM官网 https://segment-anything.com/

需要提一点:所有生成embedding和mask的操作都是python后端做的,计算mask不是onnxruntime-web实现的,前端只负责了把rle编码的mask解码后画到canvas上,会有几十毫秒的网络传输延迟。我不会react和typescript,官网F12里的源代码太难懂了,生成的svg总是与期望的不一样

主页

SegmentAnything官网demo使用vue+python实现_第1张图片

鼠标移动动态分割(Hover)

throttle了一下,修改代码里的throttle delay,反应更快些,我觉得没必要已经够了,设置的150ms

SegmentAnything官网demo使用vue+python实现_第2张图片

点选前景背景(Click)

蓝色前景,红色背景,对应clickType分别为1和0

SegmentAnything官网demo使用vue+python实现_第3张图片

分割(Cut out object)

同官网,分割出该区域需要的最小矩形框部分
SegmentAnything官网demo使用vue+python实现_第4张图片

分割所有(Everything)

随便做了下,实在做不出官网的效果,可能模型也有问题 ,我用的vit_b,懒得试了,这功能对我来说没卵用

SegmentAnything官网demo使用vue+python实现_第5张图片

2.准备工作

安装依赖

前端使用了Vue3+ElementPlus(https://element-plus.org/zh-CN/#/zh-CN)+axios+lz-string,npm安装一下。

后端是fastapi(https://fastapi.tiangolo.com/),FastAPI 依赖 Python 3.8 及更高版本。

安装 FastAPI

pip install fastapi

另外我们还需要一个 ASGI 服务器,生产环境可以使用 Uvicorn 或者 Hypercorn:

pip install "uvicorn[standard]"
要用的js文件

@/util/request.js

import axios from "axios";
import { ElMessage } from "element-plus";

axios.interceptors.request.use(
    config => {
        return config;
    },
    error => {
        return Promise.reject(error);
    }
);

axios.interceptors.response.use(
    response => {
        if (response.data.success != null && !response.data.success) {
            return Promise.reject(response.data)
        }
        return response.data;
    },
    error => {
        console.log('error: ', error)
        ElMessage.error(' ');
        return Promise.reject(error);
    }
);

export default axios;

然后在main.js中绑定

import axios from './util/request.js'
axios.defaults.baseURL = 'http://localhost:9000'
axios.defaults.headers.post['Content-Type'] = 'application/x-www-form-urlencoded';
app.config.globalProperties.$http = axios

@/util/throttle.js

function throttle(func, delay) {
    let timer = null; // 定时器变量

    return function() {
        const context = this; // 保存this指向
        const args = arguments; // 保存参数列表

        if (!timer) {
            timer = setTimeout(() => {
                func.apply(context, args); // 调用原始函数并传入上下文和参数
                clearTimeout(timer); // 清除计时器
                timer = null; // 重置计时器为null
            }, delay);
        }
    };
}
export default throttle

@/util/mask_utils.js

/**
 * Parses RLE from compressed string
 * @param {Array} input
 * @returns array of integers
 */
export const rleFrString = (input) => {
    let result = [];
    let charIndex = 0;
    while (charIndex < input.length) {
        let value = 0,
            k = 0,
            more = 1;
        while (more) {
            let c = input.charCodeAt(charIndex) - 48;
            value |= (c & 0x1f) << (5 * k);
            more = c & 0x20;
            charIndex++;
            k++;
            if (!more && c & 0x10) value |= -1 << (5 * k);
        }
        if (result.length > 2) value += result[result.length - 2];
        result.push(value);
    }
    return result;
};

/**
 * Parse RLE to mask array
 * @param rows
 * @param cols
 * @param counts
 * @returns {Uint8Array}
 */
export const decodeRleCounts = ([rows, cols], counts) => {
    let arr = new Uint8Array(rows * cols)
    let i = 0
    let flag = 0
    for (let k of counts) {
        while (k-- > 0) {
            arr[i++] = flag
        }
        flag = (flag + 1) % 2
    }
    return arr
};

/**
 * Parse Everything mode counts array to mask array
 * @param rows
 * @param cols
 * @param counts
 * @returns {Uint8Array}
 */
export const decodeEverythingMask = ([rows, cols], counts) => {
    let arr = new Uint8Array(rows * cols)
    let k = 0;
    for (let i = 0; i < counts.length; i += 2) {
        for (let j = 0; j < counts[i]; j++) {
            arr[k++] = counts[i + 1]
        }
    }
    return arr;
};

/**
 * Get globally unique color in the mask
 * @param category
 * @param colorMap
 * @returns {*}
 */
export const getUniqueColor = (category, colorMap) => {
    // 该种类没有颜色
    if (!colorMap.hasOwnProperty(category)) {
        // 生成唯一的颜色
        while (true) {
            const color = {
                r: Math.floor(Math.random() * 256),
                g: Math.floor(Math.random() * 256),
                b: Math.floor(Math.random() * 256)
            }
            // 检查颜色映射中是否已存在相同的颜色
            const existingColors = Object.values(colorMap);
            const isDuplicateColor = existingColors.some((existingColor) => {
                return color.r === existingColor.r && color.g === existingColor.g && color.b === existingColor.b;
            });
            // 如果不存在相同颜色,结束循环
            if (!isDuplicateColor) {
                colorMap[category] = color;
                break
            }
        }
        console.log("生成唯一颜色", category, colorMap[category])
        return colorMap[category]
    } else {
        return colorMap[category]
    }
}

/**
 * Cut out specific area of image uncovered by mask
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param callback function to solve the image blob
 */
export const cutOutImage = ({w, h}, image, canvas, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true}),
        originalCtx = canvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const maskDataArray = originalCtx.getImageData(0, 0, w, h).data;
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data
    // 将mask的部分去掉
    for (let i = 0; i < maskDataArray.length; i += 4) {
        const alpha = maskDataArray[i + 3];
        if (alpha !== 0) { // 不等于0,是mask区域
            imageDataArray[i + 3] = 0;
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

/**
 * Cut out specific area of image covered by target color mask
 * PS: 我写的这代码有问题,比较color的时候tmd明明mask canvas中有这个颜色,
 * 就是说不存在这颜色,所以不用这个函数,改成下面的了
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param canvas mask canvas
 * @param color target color
 * @param callback function to solve the image blob
 */
export const cutOutImageWithMaskColor = ({w, h}, image, canvas, color, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true}),
        originalCtx = canvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const maskDataArray = originalCtx.getImageData(0, 0, w, h).data;
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data

    let find = false

    // 比较mask的color和目标color
    for (let i = 0; i < maskDataArray.length; i += 4) {
        const r = maskDataArray[i],
            g = maskDataArray[i + 1],
            b = maskDataArray[i + 2];
        if (r != color.r || g != color.g || b != color.b) { // 颜色与目标颜色不相同,是mask区域
            // 设置alpha为0
            imageDataArray[i + 3] = 0;
        } else {
            find = true
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    // console.log(`矩形宽度:${width}`);
    // console.log(`矩形高度:${height}`);
    // console.log(`起点坐标:(${startX}, ${startY})`);
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

/**
 * Cut out specific area whose category is target category
 * @param w image's natural width
 * @param h image's natural height
 * @param image source image
 * @param arr original mask array that stores all pixel's category
 * @param category target category
 * @param callback function to solve the image blob
 */
export const cutOutImageWithCategory = ({w, h}, image, arr, category, callback) => {
    const resultCanvas = document.createElement('canvas'),
        resultCtx = resultCanvas.getContext('2d', {willReadFrequently: true});
    resultCanvas.width = w;
    resultCanvas.height = h;
    resultCtx.drawImage(image, 0, 0, w, h)
    const imageData = resultCtx.getImageData(0, 0, w, h);
    const imageDataArray = imageData.data
    // 比较mask的类别和目标类别
    let i = 0
    for(let y = 0; y < h; y++){
        for(let x = 0; x < w; x++){
            if (category != arr[i++]) { // 类别不相同,是mask区域
                // 设置alpha为0
                imageDataArray[3 + (w * y + x) * 4] = 0;
            }
        }
    }
    // 计算被分割出来的部分的矩形框
    let minX = w;
    let minY = h;
    let maxX = 0;
    let maxY = 0;
    for (let y = 0; y < h; y++) {
        for (let x = 0; x < w; x++) {
            const alpha = imageDataArray[(y * w + x) * 4 + 3];
            if (alpha !== 0) {
                minX = Math.min(minX, x);
                minY = Math.min(minY, y);
                maxX = Math.max(maxX, x);
                maxY = Math.max(maxY, y);
            }
        }
    }
    const width = maxX - minX + 1;
    const height = maxY - minY + 1;
    const startX = minX;
    const startY = minY;
    resultCtx.putImageData(imageData, 0, 0)
    // 创建一个新的canvas来存储特定区域的图像
    const croppedCanvas = document.createElement("canvas");
    const croppedContext = croppedCanvas.getContext("2d");
    croppedCanvas.width = width;
    croppedCanvas.height = height;
    // 将特定区域绘制到新canvas上
    croppedContext.drawImage(resultCanvas, startX, startY, width, height, 0, 0, width, height);
    croppedCanvas.toBlob(blob => {
        if (callback) {
            callback(blob)
        }
    }, "image/png");
}

二、后端代码

1.SAM下载

首先从github上下载SAM的代码https://github.com/facebookresearch/segment-anything

然后下载模型文件,保存到项目根目录/checkpoints中,

  • default or vit_h: ViT-H SAM model.
  • vit_l: ViT-L SAM model.
  • vit_b: ViT-B SAM model.

2.后端代码

在项目根目录下创建main.py

main.py

import os
import time

from PIL import Image
import numpy as np
import io
import base64
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from pycocotools import mask as mask_utils
import lzstring


def init():
    # your model path
    checkpoint = "checkpoints/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device='cuda')
    predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)
    return predictor, mask_generator


predictor, mask_generator = init()

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins="*",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

last_image = ""
last_logit = None


@app.post("/segment")
def process_image(body: dict):
    global last_image, last_logit
    print("start processing image", time.time())
    path = body["path"]
    is_first_segment = False
    # 看上次分割的图片是不是该图片
    if path != last_image:  # 不是该图片,重新生成图像embedding
        pil_image = Image.open(path)
        np_image = np.array(pil_image)
        predictor.set_image(np_image)
        last_image = path
        is_first_segment = True
        print("第一次识别该图片,获取embedding")
    # 获取mask
    clicks = body["clicks"]
    input_points = []
    input_labels = []
    for click in clicks:
        input_points.append([click["x"], click["y"]])
        input_labels.append(click["clickType"])
    print("input_points:{}, input_labels:{}".format(input_points, input_labels))
    input_points = np.array(input_points)
    input_labels = np.array(input_labels)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        mask_input=last_logit[None, :, :] if not is_first_segment else None,
        multimask_output=is_first_segment  # 第一次产生3个结果,选择最优的
    )
    # 设置mask_input,为下一次做准备
    best = np.argmax(scores)
    last_logit = logits[best, :, :]
    masks = masks[best, :, :]
    # print(mask_utils.encode(np.asfortranarray(masks))["counts"])
    # numpy_array = np.frombuffer(mask_utils.encode(np.asfortranarray(masks))["counts"], dtype=np.uint8)
    # print("Uint8Array([" + ", ".join(map(str, numpy_array)) + "])")
    source_mask = mask_utils.encode(np.asfortranarray(masks))["counts"].decode("utf-8")
    # print(source_mask)
    lzs = lzstring.LZString()
    encoded = lzs.compressToEncodedURIComponent(source_mask)
    print("process finished", time.time())
    return {"shape": masks.shape, "mask": encoded}


@app.get("/everything")
def segment_everything(path: str):
    start_time = time.time()
    print("start segment_everything", start_time)
    pil_image = Image.open(path)
    np_image = np.array(pil_image)
    masks = mask_generator.generate(np_image)
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
    img = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]), dtype=np.uint8)
    for idx, ann in enumerate(sorted_anns, 0):
        img[ann['segmentation']] = idx
    #看一下mask是什么样
    #plt.figure(figsize=(10,10))
	#plt.imshow(img) 
	#plt.show()
    # 压缩数组
    result = my_compress(img)
    end_time = time.time()
    print("finished segment_everything", end_time)
    print("time cost", end_time - start_time)
    return {"shape": img.shape, "mask": result}


@app.get('/automatic_masks')
def automatic_masks(path: str):
    pil_image = Image.open(path)
    np_image = np.array(pil_image)
    mask = mask_generator.generate(np_image)
    sorted_anns = sorted(mask, key=(lambda x: x['area']), reverse=True)
    lzs = lzstring.LZString()
    res = []
    for ann in sorted_anns:
        m = ann['segmentation']
        source_mask = mask_utils.encode(m)['counts'].decode("utf-8")
        encoded = lzs.compressToEncodedURIComponent(source_mask)
        r = {
            "encodedMask": encoded,
            "point_coord": ann['point_coords'][0],
        }
        res.append(r)
    return res


# 就是将连续的数字统计个数,然后把[个数,数字]放到result中,类似rle算法
# 比如[[1,1,1,2,3,2,2,4,4],[3,3,4...]]
# result是[3,1,  1,2,  1,3,  2,2,  2,4,  2,3,...]
def my_compress(img):
    result = []
    last_pixel = img[0][0]
    count = 0
    for line in img:
        for pixel in line:
            if pixel == last_pixel:
                count += 1
            else:
                result.append(count)
                result.append(int(last_pixel))
                last_pixel = pixel
                count = 1
    result.append(count)
    result.append(int(last_pixel))
    return result

3.原神启动

在cmd或者pycharm终端,cd到项目根目录下,输入uvicorn main:app --port 8006,启动服务器

三、前端代码

1.页面代码

template

script

style

2.代码说明

  • 本项目没做上传图片分割,就是简单的选择本地图片分割,data中url是img的src,path是绝对路径用来传给python后端进行分割,我是从我项目的系统获取的,请自行修改代码成你的图片路径,如src: “/assets/test.jpg”, path:“D:/project/segment/assets/test.jpg”
  • 由于pycocotools的rle encode是从上到下进行统计连续的0和1,为了方便,我在【@/util/mask_utils.js:decodeRleCounts】解码Click点选产生的mask时将(H,W)的矩阵转成了(W,H)顺序存储的Uint8array;而在Everything分割所有时,我没有使用pycocotools的encode,而是main.py中的my_compress函数编码的,是从左到右进行压缩,因此矩阵解码后仍然是(H,W)的矩阵,所以在drawCanvasdrawEverythingCanvas中的二层循环xy的顺序不一样,我实在懒得改了,就这样就可以了。

关于上面所提rle,可以在项目根目录/notebooks/predictor_example.ipynb中产生mask的位置添加代码自行观察他编码的rle,他只支持矩阵元素为0或1,result的第一个位置是0的个数,不管矩阵是不是0开头。

  • [0,0,1,1,0,1,0],rle counts是[2(两个0), 2(两个1), 1(一个0), 1(一个1), 1(一个0)];

  • [1,1,1,1,1,0],rle counts是[0(零个0),5(五个1),1(一个0)]

def decode_rle(rle_string): # 这是将pycocotools的counts编码的字符串转成counts数组,而非转成原矩阵
    result = []
    char_index = 0
    
    while char_index < len(rle_string):
        value = 0
        k = 0
        more = 1
        
        while more:
            c = ord(rle_string[char_index]) - 48
            value |= (c & 0x1f) << (5 * k)
            more = c & 0x20
            char_index += 1
            k += 1
            if not more and c & 0x10:
                value |= -1 << (5 * k)
        
        if len(result) > 2:
            value += result[-2]
        result.append(value)
    return result

from pycocotools import mask as mask_utils
import numpy as np
mask = np.array([[1,1,0,1,1,0],[1,1,1,1,1,1],[0,1,1,1,0,0],[1,1,1,1,1,1]])
mask = np.asfortranarray(mask, dtype=np.uint8)
print("原mask:\n{}".format(mask))
res = mask_utils.encode(mask)
print("encode:{}".format(res))
print("rle counts:{}".format(decode_rle(res["counts"].decode("utf-8"))))
# 转置后好看
print("转置:{}".format(mask.transpose()))
# flatten后更好看
print("flatten:{}".format(mask.transpose().flatten()))
#numpy_array = np.frombuffer(res["counts"], dtype=np.uint8)
# 打印numpy数组作为uint8array的格式
#print("Uint8Array([" + ", ".join(map(str, numpy_array)) + "])")

输出:

SegmentAnything官网demo使用vue+python实现_第6张图片

你可能感兴趣的:(vue.js,python,前端)