segment-anything:画像内の全ての物体を画素単位で切り出す事が出来るMetaの基盤モデル

アプリケーション

1.segment-anything:画像内の全ての物体を画素単位で切り出す事が出来るMetaの基盤モデルまとめ

・画像内にある物体を自動で認識して、切り出せるようにセグメンテーションマスクを作ってくれるSegment-Anything ModelをMetaがオープンソースで公開
・写真画像の性能がとても良い事は既にわかっているのでイラストなど二次元画像を対象にして評価をしてみたところ人型を認識する精度はかなり高かった
・指定した部分の境界抽出する使い方などもColabで公開されているため画像生成モデルや物体認識モデルの性能の底上げにかなり貢献する事になりそう

2.SAM(segment-anything model)とは?

アイキャッチ画像は「様々な物体に囲まれた場面でもセグメンテーションマスクができる事を表現した宮崎駿スタイルのイメージ画像を作りたいんですよね」とchatGPT先生に無茶ぶりして作成したプロンプトでカスタムStable Diffusion先生に作って貰ったイラスト。

宮崎駿スタイルにインスパイアされて、豊かな自然に囲まれた若い女性がモチーフとなったようなのですが、セグメンテーションマスクは普通に顔マスクになってしまったので、そこは私の方で修正した結果、何の画像かよくわからなくなってしまってます。

基盤モデル(Foundation Model)と言う概念は最近、日本でも本が出版されたためか、よく目にするようになった気がしますが、やや曖昧な概念のような気がしています。

私の理解するところによれば「事前学習済み大規模モデルの事。微調整して様々な固有タスクに応用ができるので非常にありがたい存在」です。

しかし、定義によっては「様々な種類のデータ(マルチモーダル)で学習した、様々なタスク(マルチタスク)に適用できるモデル」と、マルチである事が要件になっているように読める事があります。

いや、何が気になったかと言うと、今回、紹介するMetaのsegment-anythingはYann LeCun教授がTwitterで
「SAM: Segment Anything Model from FAIR.Foundation model for image segmentation.」
っと「イメージセグメンテーション向けの基盤モデル」とツイートしていたので、イメージセグメンテーションは(対象が静止画と動画などはあるとはいえ)は、広義で言えば下流タスクの一つだと思っていたので、やっぱりマルチは要件ではなくて良いのかなっと。

脱線しましたが、画像内にある物体を自動で認識して、切り出せるようにセグメンテーションマスクを作ってくれる、SAM(Segment-Anything Model)をMetaがオープンソースとして公開してくれていたので、以下で評価してみました。

写真画像の性能がとても良いであろうことはわかりきっているので、イラストなど二次元画像を対象にしています。Deep-MARCと比較してもだいぶ進化していると感じます。

スタジオジブリの素材集からお借りした画像

難しそうな素材を敢えて選んでおり、且つSegment-Anything Modelはディフォルト設定です。そのため、もっと精度を上げる事はできると思います。学習済みモデルも3種公開されていますが、ディフォルトのもののみを使って評価しています。

以下は、あくまでも第一印象とポテンシャルを確認する位置づけの比較画像とお考え下さい。人型の切り出しはかなり精度が高いように思います。



Stable Diffusionで作成した画像を評価

Stable Diffusion生成イラストは奥行や境界があり得ない構造になる事があるので、境界認識は難しそうですが、かなり対応できています。しかし、絵柄によってはチューニングが必要そうです。




以下、更に難度を上げて元の生成画像をOutpaintingで拡張した画像

かなり対応できていますが、ここまでくると2段階で抽出する等の工夫が必要かもしれません。

 

なお、画像内の抜き出して欲しい部分にマークをつけてその部分を抽出するデモのColabも公開されているので、上記で出来る事が全てではないです。これを使って学習元画像をより丁寧に切り出すようにすると、画像生成モデルや物体認識モデルの性能が底上げされる事になりそうだな、と思っています。

ローカル環境でのSetup情報

2種類のColabが公開されているので、さっと試すだけでしたらColabで動かした方が良いです。

ローカルでのセットアップはそれほど苦戦しませんでした。RTX 3060 メモリ12GBでも十分動きます。ただ、大きな画像だと一枚処理するのに10秒以上かかるのでそれなりに重いです。

conda create -n segment-anything
conda activate segment-anything

conda install python=3.10 pyparsing

# 以下は公式サイト(https://pytorch.org/get-started/locally/)を見て自分の環境に合わせて変えてください。
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

git clone https://github.com/facebookresearch/segment-anything.git

cd segment-anything

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install opencv-python pycocotools matplotlib onnxruntime onnx

wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

mkdir sample
mkdir out

#sample 配下に対象となる画像をコピーしてください

python3 scripts/amg.py --checkpoint sam_vit_h_4b8939.pth --input sample/ --output out/

上記の付属サンプルスクリプトだと、マスク部分だけ抽出した細かいpngファイルが大量に出来る or 座標のjsonファイルが出来るだけなので、全景がわかりにくくなってしまうので、Colabのコードを雑にコピーして作ったサンプルは以下です。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_anns(anns):
    if len(anns) == 0:
        return

    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
            ax.imshow(np.dstack((img, m*0.35)))

# 以下を自分のファイル名に書き換えてください
image = cv2.imread('sample/test1.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')

sam_checkpoint = "sam_vit_h_4b8939.pth"

device = "cuda"
model_type = "default"

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

print(len(masks))
print(masks[0].keys())
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

3.segment-anything:画像内の全ての物体を画素単位で切り出す事が出来るMetaの基盤モデル関連リンク

1)ai.facebook.com
Segment Anything

2)github.com
facebookresearch / segment-anything

3)colab.research.google.com
automatic_mask_generator_example.ipynb (気軽に試してみたい方はこちら)
predictor_example.ipynb (画像内の抜き出して欲しい部分にマークをつけてその部分を抽出するデモ)

タイトルとURLをコピーしました