Skip to content

MobileSAM

This is the official code for MobileSAM project that makes SAM lightweight for mobile applications and beyond!

MobileSAM vs FastSAM

FastSAM
SAM의 학습 데이터인 SA-1B에서 단 2%만 활용하여 학습하였고 NVIDIA GeForce RTX 3090 기준 약 50배 빠른 런타임 속도로 추론할 수 있다고 합니다.
파라미터 개수는 파격적으로 줄지는 않았지만 실제 속도는 매우 빨라진 것 같습니다.(SAM(Vit-B)-136M, FastSAM-68M)
MobileSAM
경희대학교에서 발표한 논문으로 SAM의 구조에서 이미지 인코더 부분만 TinyViT 로 변경한 모델입니다.
TinyViT(5M)는 ‘Fast Pretraining Distillation for Small Vision Transformers’ 논문에서 제안된 Distillation을 활용해 ViT를 경량화시킨 모델이며 SAM의 이미지 인코더인 ViT-H(632M) 보다 약 120배 적은 파라미터 수를 가지고 있습니다. FastSAM의 총 파라미터 수가 68M인 반면에 MobileSAM은 9.66M 인 것을 보면 얼마나 적은 파라미터인지 체감할 수 있습니다.

Install

CUDA 12.2 에서 설치: (opy 스크립트로 venv 공간 설치 완료 후 진행)

git clone --depth 1 https://github.com/ChaoningZhang/MobileSAM
cd MobileSAM
cp {나의 opy python 단독 스크립트} ./python
./python -m pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
./python -m pip install timm
./python opencv-python

Example

# -*- coding: utf-8 -*-

from mobile_sam import sam_model_registry, SamPredictor
import cv2
import os
import numpy as np
from pathlib import Path
from glob import glob
from datetime import datetime

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device="cuda")
sam.eval()

predictor = SamPredictor(sam)
# files = glob("mpv-shot*.jpg")
# files = glob("*_cam1.bmp")
# files = glob("*.png")
files = glob("./input/*.jpg")
files.sort()

output_dir = Path("output")
if not output_dir.is_dir():
    output_dir.mkdir()


for filename in files:
    file_prefix, ext = os.path.splitext(filename)

    input_point = np.array([[1673, 224]])  # Cam1 (4K)
    # input_point = np.array([[836, 112]])  # Cam1 (FHD)
    # input_point = np.array([[768, 224]])  # Cam1 (4k-cut)
    input_label = np.array([1])

    img = cv2.imread(filename)

    begin = datetime.now()
    predictor.set_image(img)
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    duration = (datetime.now() - begin).total_seconds()
    print(f"Predict duration: {duration:.02f}s")

    for i in range(masks.shape[0]):
        mask_filename = file_prefix + f"-{i}.png"
        mask_filepath = str(output_dir / mask_filename)
        cv2.imwrite(mask_filepath, masks[0] * 255)
        print(f"Save mask image: {mask_filename}")

ONNX Export

python scripts/export_onnx_model.py --checkpoint ./weights/mobile_sam.pt --model-type vit_t --output ./mobile_sam.onnx

ONNX 모델에는 SamPredictor.predict와 다른 입력 서명이 있습니다.

다음 입력이 모두 제공되어야 합니다. 포인트 입력과 마스크 입력 모두에 대한 특별한 경우에 유의하십시오. 모든 입력은 np.float32입니다.

  • image_embeddings - Predictor.get_image_embedding()에서 삽입된 이미지입니다. 길이가 1인 배치 인덱스가 있습니다.
  • point_coords - 포인트 입력과 상자 입력 모두에 해당하는 희소 입력 프롬프트의 좌표입니다. 상자는 두 개의 점(왼쪽 상단 모서리에 하나, 오른쪽 하단 모서리에 하나)을 사용하여 인코딩됩니다. 좌표는 이미 긴 쪽 1024로 변환되어 있어야 합니다. 배치 인덱스 길이는 1입니다.
  • point_labels - 희소 입력 프롬프트에 대한 레이블입니다. 0은 음수 입력 지점, 1은 양수 입력 지점, 2는 왼쪽 위 상자 모서리, 3은 오른쪽 아래 상자 모서리, -1은 패딩 지점입니다. 상자 입력이 없는 경우 레이블이 -1이고 좌표(0.0, 0.0)가 있는 단일 패딩 포인트를 연결해야 합니다.
  • Mask_input - 1x1x256x256 형태의 모델에 입력되는 마스크입니다. 마스크 입력이 없더라도 반드시 제공해야 합니다. 이 경우에는 0일 수도 있습니다.
  • has_mask_input - 마스크 입력에 대한 표시자입니다. 1은 마스크 입력을 나타내고, 0은 마스크 입력이 없음을 나타냅니다.
  • orig_im_size - 변환 전 (H,W) 형식의 입력 이미지 크기입니다.

또한 ONNX 모델은 출력 마스크 로짓의 임계값을 지정하지 않습니다. 바이너리 마스크를 얻으려면 sam.mask_threshold의 임계값(0.0과 동일)입니다.

session.get_inputs():

Expression: session.get_inputs()
- 0: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d4170>
  + special variables:
  - name: 'image_embeddings'
  + shape: [1, 256, 64, 64]
  - type: 'tensor(float)'
- 1: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d41b0>
  + special variables:
  - name: 'point_coords'
  + shape: [1, 'num_points', 2]
  - type: 'tensor(float)'
- 2: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d41f0>
  + special variables:
  - name: 'point_labels'
  + shape: [1, 'num_points']
  - type: 'tensor(float)'
- 3: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d4230>
  + special variables:
  - name: 'mask_input'
  + shape: [1, 1, 256, 256]
  - type: 'tensor(float)'
- 4: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d4270>
  + special variables:
  - name: 'has_mask_input'
  + shape: [1]
  - type: 'tensor(float)'
- 5: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d42b0>
  + special variables:
  - name: 'orig_im_size'
  + shape: [2]
  - type: 'tensor(float)'

session.get_outputs():

Expression: session.get_outputs()
- 0: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d42f0>
  + special variables:
  - name: 'masks'
  + shape: ['Resizemasks_dim_0', 'Resizemasks_dim_1', 'Resizemasks_dim_2', 'Resizemasks_dim_3']
  - type: 'tensor(float)'
- 1: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d4330>
  + special variables:
  - name: 'iou_predictions'
  + shape: ['Gemmiou_predictions_dim_0', 4]
  - type: 'tensor(float)'
- 2: <onnxruntime.capi.onnxruntime_pybind11_state.NodeArg object at 0x7f064d2d4370>
  + special variables:
  - name: 'low_res_masks'
  + shape: ['Reshapelow_res_masks_dim_0', 'Reshapelow_res_masks_dim_1', 'Reshapelow_res_masks_dim_2', 'Reshapelow_res_masks_dim_3']
  - type: 'tensor(float)'

ONNX Example

# -*- coding: utf-8 -*-

from mobile_sam import sam_model_registry, SamPredictor
import cv2
import os
import sys
import numpy as np
from pathlib import Path
from glob import glob
from datetime import datetime

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from mobile_sam import sam_model_registry, SamPredictor
from mobile_sam.utils.onnx import SamOnnxModel

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

files = glob("./input/*.jpg")
if not files:
    print("Not found any files", file=sys.stderr)
    exit(1)

files.sort()

output_dir = Path("output")
if not output_dir.is_dir():
    output_dir.mkdir()

onnx_model_path = "mobile_sam.onnx"
session = onnxruntime.InferenceSession(onnx_model_path, providers=['CUDAExecutionProvider'])
iobinding = session.io_binding()

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device="cuda")
sam.eval()

predictor = SamPredictor(sam)

for filename in files:
    file_basename = os.path.basename(filename)
    file_prefix, ext = os.path.splitext(file_basename)

    input_point = np.array([[1673, 224]])  # Cam1 (4K)
    # input_point = np.array([[836, 112]])  # Cam1 (FHD)
    # input_point = np.array([[768, 224]])  # Cam1 (4k-cut)
    input_label = np.array([1])

    image = cv2.imread(filename)
    begin = datetime.now()

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()
    print(f"image_embedding.shape: {image_embedding.shape}")

    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)
    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
    }
    masks, _, low_res_logits = session.run(None, ort_inputs)
    masks = masks > predictor.model.mask_threshold

    # masks, scores, logits = predictor.predict(
    #     point_coords=input_point,
    #     point_labels=input_label,
    #     multimask_output=True,
    # )

    duration = (datetime.now() - begin).total_seconds()
    print(f"Predict duration: {duration:.02f}s")

    for i in range(masks.shape[0]):
        mask_filename = file_prefix + f"-{i}.png"
        mask_filepath = str(output_dir / mask_filename)
        cv2.imwrite(mask_filepath, masks[0] * 255)
        print(f"Save mask image: {mask_filepath}")

내부 테스트 결과

NVIDIA GeForce RTX 3070 Ti
3840x2160 (4K Original) 이미지 - 약 0.12s (약 120ms)
1600x2160 (ROI Crop) 이미지 - 약 0.07s (약 70ms)
NVIDIA GeForce RTX 4090
3840x2160 (4K Original) 이미지 - 0.0391s ~ 0.0458s (약 45ms)
1660x1926 (ROI Crop) 이미지 - 0.026s ~ 0.031s (약 31ms)

TODO

See also

Favorite site