
Segment Anything

Meta's AI model that can extract any object from an image

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

Features

  • "이미지 세그멘테이션"을 위한 첫번째 파운데이션 모델
    • 픽셀이 어떤 객체에 속해있는지를 식별하는 것
  • SAM 과 10억개의 파라미터 데이터셋 (SA-1B) 를 공개
  • SAM은 물체가 무엇인지에 대한 일반적인 개념을 학습 했고, 훈련중에 만나지 못한 물체 및 이미지 유형에 대해서도 이미지/비디오의 모든 객체에 대해서 마스크를 생성 가능
    • 추가적인 훈련 없이도 새로운 이미지 도메인(물속 사진이나, 세포 현미경 사진등)에도 적용 가능
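
To illustrate the "masks for every object" behavior described above, here is a minimal sketch using the repository's SamAutomaticMaskGenerator. It assumes the sam_vit_h_4b8939.pth checkpoint referenced later on this page and a local image file named example.jpg (an illustrative filename, not part of the repository):

import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Load the ViT-H SAM model (checkpoint path is an assumption; see the curl command below).
sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

# The generator expects an HxWx3 uint8 RGB image; OpenCV loads BGR, so convert.
image = cv2.imread("example.jpg")  # illustrative filename
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

masks = mask_generator.generate(image)  # one dict per detected object
print(f"{len(masks)} masks generated")
print(masks[0].keys())  # includes 'segmentation', 'area', 'bbox', 'predicted_iou', ...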

Predict example

The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8.

Install Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. jupyter is also required to run the example notebooks.

pip install opencv-python pycocotools matplotlib onnxruntime onnx
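
Since the optional dependencies above include pycocotools for saving masks in COCO format, here is a minimal, hedged sketch of how a single boolean mask could be encoded as COCO RLE. The mask variable is illustrative (e.g. one of the masks returned by SamPredictor.predict below), not something provided by the repository:

import numpy as np
from pycocotools import mask as mask_utils

# Dummy (H, W) boolean mask standing in for a SAM output mask.
mask = np.zeros((480, 640), dtype=bool)
mask[100:200, 150:300] = True

# pycocotools expects a Fortran-ordered uint8 array.
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")  # make the RLE JSON-serializable
print(rle["size"], rle["counts"][:40], "...")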

Download Model Checkpoints:

Three versions of the model are available, with different backbone sizes. These models can be instantiated by running

from segment_anything import sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")

Download the checkpoint for the desired model type (default/vit_h, vit_l, or vit_b); the example code below fetches the vit_h checkpoint with curl. Once a checkpoint is downloaded, the model can produce masks from input prompts in just a few lines:

from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
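
Prompts are not limited to foreground points; SamPredictor.predict also accepts a bounding-box prompt. A minimal sketch, assuming the predictor has already been created and set_image has been called as above, with illustrative box coordinates:

import numpy as np

# Box prompt in XYXY pixel coordinates (illustrative values).
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False,  # a box prompt is usually unambiguous, so one mask is enough
)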

Example code

# -*- coding: utf-8 -*-
# curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

from segment_anything import SamPredictor, sam_model_registry
import cv2
import os
import numpy as np
from datetime import datetime

use_gpu = False  # set to True to run inference on CUDA

# "default" maps to the ViT-H backbone, matching the checkpoint downloaded above.
sam = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
if use_gpu:
    sam.to(device="cuda")
predictor = SamPredictor(sam)
files = (
    "mpv-shot0001.jpg",
    "mpv-shot0002.jpg",
    "mpv-shot0003.jpg",
    "mpv-shot0004.jpg",
    "mpv-shot0005.jpg",
    "mpv-shot0006.jpg",
    "mpv-shot0007.jpg",
    "mpv-shot0008.jpg",
    "mpv-shot0009.jpg",
    "mpv-shot0010.jpg",
    "mpv-shot0011.jpg",
    "mpv-shot0012.jpg",
    "mpv-shot0013.jpg",
    "mpv-shot0014.jpg",
    "mpv-shot0015.jpg",
    "mpv-shot0016.jpg",
)

for filename in files:
    file_prefix, ext = os.path.splitext(filename)
    # Single foreground point prompt at pixel (x=1950, y=378).
    input_point = np.array([[1950, 378]])
    input_label = np.array([1])  # 1 = foreground, 0 = background

    img = cv2.imread(filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # SamPredictor.set_image expects RGB

    begin = datetime.now()
    predictor.set_image(img)
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    duration = (datetime.now() - begin).total_seconds()
    print(f"Predict duration: {duration:.02f}s")

    # Save each candidate mask (multimask_output=True returns 3 masks per prompt).
    for i in range(masks.shape[0]):
        mask_filename = file_prefix + f"-{i}.png"
        cv2.imwrite(mask_filename, (masks[i] * 255).astype(np.uint8))
        print(f"Save mask image: {mask_filename}")

On an NVIDIA GeForce RTX 3070 Ti, the predict duration comes out to about 0.72 s.

Internal test results

NVIDIA GeForce RTX 3070 Ti
3840x2160 (4K original) image - about 0.7 s per prediction
