物体検出とは、画像内に写っている物体のカテゴリと物体位置を検出する技術です。身近な例としては、スマートフォンでの顔認証や工場における外観検査、自動運転の歩行者検出にも使われています。また、物体検出の歴史は長く、現在でも新しい物体検出手法が盛んに研究されています。
本連載では、自然言語処理でよく使われる「Transformer」を採用した物体検出モデルDETRについて紹介します。DETR自体は1年前に公開されたため、原理などの説明はGoogle検索すれば沢山出てくると思いますが、推論やFine-Tuningの方法を紹介している日本語記事が少ない印象を受けました。そのため、初心者向けにDETRで推論とFine-Tuningを実行する方法をご紹介したいと思います。
1.始めに
本記事では、物体検出の概要とこれまでの歴史を紹介いたします。それを踏まえて、DETRの原理について数学的な細かい説明を省いてなるべく簡単に解説します。また、公式チュートリアルを参考にDETRを使った推論方法と結果の可視化方法を紹介いたします。DETRとはどのようなもので、今までの物体検出手法とは何が違うのかをざっくりとでも理解していただければ嬉しいです。今回は詳細な解説は省いていますが、第2回以降の記事で解説していく予定です。
2.物体検出の歴史
物体検出とは最初にも記載した通り、画像内に写っている物体のカテゴリと物体位置を検出する技術です。事前にヒトやモノなどの物体を学習で覚えさせたモデルに画像を入力し、何処に既知の物体があるのかを検出し、その物体を矩形(以下、バウンディングボックス)で囲むように出力されます。出力イメージは、下記画像のようなものです。この画像は、物体検出分野では有名なデータセットであるCOCO dataset 1 の猫画像に対して、実際に物体検出を行ったものです。
物体検出のイメージを持っていただいたところで、本章ではDETR以前の物体検出の歴史について簡単に説明します。ディープラーニングが登場する以前は、SIFTやHOGなどの特徴量抽出を使った物体検出手法が主流でした。しかし、2013年に登場したR-CNN 2 以降、物体検出の研究はディープラーニング時代に突入しました。
ディープラーニングによる物体検出手法の時系列をまとめた図がGitHub 3 で公開されていますので、確認してみます。
上図は2013年から2019年までの時系列になりますが、様々な手法が提案されていることが分かるかと思います。赤字で表記されている手法は代表的な手法なので、本記事では紹介しませんが、興味のある方はぜひ調べてみて下さい。
物体検出では様々な手法が提案されてきましたが、今までTransformerを利用した物体検出の手法はありませんでした。しかし、2020年5月に物体検出で初めてTransformerを導入した「DETR」という手法が提案されました。自然言語処理分野でよく使われるTransformerを物体検出分野に適用し、既存の手法と異なるアプローチながらも今までの手法と遜色ない精度が出せています(詳細は後述)ので、今後DETRが代表的な手法と言われるようになる可能性が高いと考えられます。
3.DETRの概要
DETR (End-to-End Object Detection with Transformers) 4 は2020年5月に Facebook の研究チームが公開した論文で、初めてTransformerを採用した物体検出モデルです。ちなみに、DETRは DEtection TRansformer の頭文字をとったものです。
論文タイトルの最初に「End-to-End」と書かれています。「End-to-End」とは、前処理などを行わず入力と出力を単一モデルで学習することです。今までの物体検出では色々な処理を行って物体検出を行っていましたが、DETRではTransformerにインプット画像を入力するだけで物体検出ができる、というような非常にシンプルな作りになっています。手動でのハイパーパラメータ設計 (Anchor boxの数、アスペクト比、バウンディングボックスのデフォルト座標、NMSの閾値など) も大幅に減っています。
そのため、DETRはシンプルなモデルかつ、自然言語界隈のようにこれから色々な発展形の手法が生み出されるであろう、非常に可能性を感じるモデルです。
3-1.既存の物体検出との違い
ここでは、Transformerを使わない既存の物体検出と、DETRの物体検出の違いを説明します。下図はFacebook AIのブログ 5 から引用したものですが、上側にFaster R-CNN 6 の物体検出処理、下側にDETRの物体検出処理が描かれています。
Faster R-CNNを見てみると、画像をCNNに通して特徴マップに変換した後にNMS(同じクラスとして判定された重なっているバウンディングボックスを除去する処理)やRoIAlign(元画像と特徴マップとの座標ずれを補正する処理)など、色々な処理を施していることが分かります。対してDETRでは、CNNで特徴マップに変換した後はTransformerを適用することで物体検出ができます。非常にシンプルな「End-to-End」の作りになっていることが分かるかと思います。
また、既存手法であるFaster R-CNNとRetinaNet 7 と、DETRの性能比較を行った結果が論文に記載されていますので、下記に示します。
Faster R-CNNとRetinaNetの末尾の “+” は、通常の9倍のepoch数で学習させていることを表しています。また、DETR-DC5は通常のDETRの特徴マップの解像度を2倍にしたモデルになります。AP(Average Precision) が物体検出の性能の指標になりますが、既存手法と同程度の性能を出せていることが分かるかと思います。今までの物体検出で使われてきたNMSなどの手法を使わなくても、Transformerだけで同程度の性能であることは非常にすごいことです。
しかし、DETRにも弱点があります。APS(小さいオブジェクトに対するAP)とAPL(大きいオブジェクトに対するAP)の値を見ると、APLに関しては既存手法と比べて良い結果が得られていますが、APSに関しては値が低下しています。そのため、DETRは大きいオブジェクトの検出には強いですが、小さいオブジェクトに対しては弱いことが分かります。今後、この問題点が解消されたDETRベースのモデルが開発されることが期待されます。
また、Transformerについては、元は自然言語処理で提案された手法になります。本記事では詳しく説明はしませんが、詳細を知りたい方は弊社記事の 「はじめての自然言語処理 第3回 BERTを用いた自然言語処理における転移学習」 をぜひご覧下さい。
3-2.DETRのアーキテクチャ
DETRのアーキテクチャについて簡単に説明します。今回は後述の推論処理を理解できるレベルに、説明は最小限に留めます。詳細の説明は次回以降の記事で解説する予定です。
下の図は、DETRの論文 4 に記載されているアーキテクチャの図になります。説明のために①~⑤を追記しています。
①backbone
インプット画像に対してCNNでの畳み込みを行いd次元の特徴マップに変換します。また、Transformerのinputにするために次元削減も行います。文字だけでは伝わりにくいと思いますので、実際の処理イメージを下図に示します。
バッチサイズをB、入力画像の幅・高さ・チャネル数をそれぞれW、H、C、CNN適用後の特徴マップの幅・高さ・チャネル数をそれぞれW'、H'、dで表しています。
※CNNでの畳み込み後の特徴量dは、ハイパーパラメータになります。
②positional encoding
Transformerの前処理として、自然言語処理でも使われる処理です。TransformerはCNNやRNNを使わないので位置情報を保持できません。そのため、この処理で位置情報を付与します。
③encoder、④decoder
TransformerのEncoder、Decoderになります。Encoder側は②をインプットに、Decoder側は「object queries」(N個) をインプットにしています。最終的に、N個のd次元特徴量がアウトプットとして出力されます。object queriesについて少し補足しておくと、object queriesとはネットワークの重みのような、学習して決まるパラメータのことです。物体検出とクラス分類をするために学習時に学習されます。また、初期値はランダムなベクトル値で、任意のN個を設定します。
※object queriesの数Nは、ハイパーパラメータになります。
⑤prediction heads
FFN(Feed Forward Network)を通して、物体の位置座標とクラスラベルをデコードします。 (class, box)
の組み合わせがN個出力されます。ここで注意ですが、結果は必ずN個出力されます。例えばN=100とすると、アーキテクチャの図のようにカモメが2羽しかいない場合は、2個はカモメ2羽を判定し、残りの98個は 「no object(分類すべき対象が存在しない)」
という判定になります。
以上、ざっくりとですがDETRのアーキテクチャの説明でした。DETRについてのイメージは持てたでしょうか。
4.推論
それでは、実際にDETRでの推論を試してみます。今回は学習は行わず、公式で用意されている学習済みモデルを利用した推論方法について説明します。
推論に関しては、DETRのGitHubページ 8 にチュートリアル用のノートブック「DETR’s hands on Colab Notebook」、「Standalone Colab Notebook」 が公開されています。そちらを参考にしながら、Google Colaboratory(以下、Colab) 9 での推論方法を解説していきます。Colabのハードウェアアクセラレータですが、今回は推論のみなのでCPUでもGPUでもどちらでも構いません。(当然、GPUの方が処理速度は速いです。)
4-1. では、前章で説明したDETRのアーキテクチャと紐付けながら、モデルを自分で設計する方法を紹介いたします。今回は学習しませんので、学習済み重みパラメータをダウンロードして設計したモデルに適用します。そうすることで、推論を行うことができます。しかし、チュートリアル用ノートブックで公開されているモデルはGitHubで公開されているオリジナルのモデルから一部簡略化されているため、オリジナルよりも精度が低くなってしまうと考えられます。そのため、 4-2. ではオリジナルの学習済みモデル全体をダウンロードして推論する方法も紹介いたします。4-1. でDETRのアーキテクチャとソースコードの対応付けを理解していただき、4-2. で本気モードの推論を行う、といった流れになります。そして最後の 4-3. では、推論結果を画像+バウンディングボックスで可視化する方法も説明いたします。
とりあえずDETRで推論してみたい、という方は、4-2. と 4-3. のソースコードを上から順にColabにコピーすれば推論を試すことができます。
4-1.重みパラメータのみをダウンロードして推論
この章ではモデル全体をロードせず、モデルは自分で定義して重みパラメータのみをロードする方法を説明します。前章で説明したアーキテクチャと紐付けながら説明しますので、DETRについてより理解を深めることができると思います。
DETRの論文 4 に30行程度のモデルのコードが記載されていますが、positional encodingが省略されているなど、実際のアーキテクチャとの差異があります。また、論文モデルの重みパラメータは公開されていないので、推論を試す場合には一から学習を行う必要があります。ですので、チュートリアル用ノートブック 「Standalone Colab Notebook」 を参考にしながら説明していきます。こちらの方が実際のアーキテクチャに近く、専用の重みパラメータが公開されているので推論を試すには丁度よいです。ただし、それでもGitHubで公開されている実際の実装との差異がありますので注意が必要です。
つまり、「論文記載のソースコード」「チュートリアル用ノートブックの実装」「GitHubで公開されているオリジナルの実装」の3種類の実装がありますが、それぞれ差異があるというわけです。「チュートリアル用ノートブックの実装」と「GitHubで公開されているオリジナルの実装」との差異がある箇所を以下に列挙します。詳細な説明は割愛しますが、実際の実装よりも簡略化しているんだな、というところだけご理解いただければと思います。
- positional encodingの初期値がランダム値 (本来はsine positional encodingを使う)
- Attention時ではなく、Transformer入力時にpositional encodingを適用
- 全結合層(アーキテクチャ的にはFFN)にMLPを使わず、nn.Linear(全結合層)1つのみで構成
それでは早速コーディングしていきます。まずは必要なパッケージをインポートします。DETRはPyTorchで書かれていますが、Colabには既にインストール済みなので、別途インストールは不要です。
import torch, torchvision import torchvision.transforms as T from torch import nn from torchvision.models import resnet50 import requests import matplotlib.pyplot as plt from PIL import Image print(torch.__version__) # 1.9.0+cu102 print(torchvision.__version__) # 0.10.0+cu102
次に、DETRのモデルクラスを定義します。理解を深めるために、3-2. のアーキテクチャと紐づけて①~⑤をコメントとして記載しています。また上述の通り、実際のアーキテクチャと差異がある箇所がありますので、その箇所もコメントとして記述しています。
class DETRdemo(nn.Module): """ DETRのモデルクラス """ def __init__(self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6): super().__init__() # backboneとしてResNet-50を利用 self.backbone = resnet50() # 最終層である全結合層(Fully Connected Layer)を削除 del self.backbone.fc self.conv = nn.Conv2d(2048, hidden_dim, 1) # Transformerは標準的なTransformerを利用 self.transformer = nn.Transformer( hidden_dim, nheads, num_encoder_layers, num_decoder_layers) # FFN # ※全結合層(アーキテクチャ的にはFFN)にMLPを使わず、nn.Linear(全結合層)1つのみで構成 self.linear_class = nn.Linear(hidden_dim, num_classes + 1) self.linear_bbox = nn.Linear(hidden_dim, 4) # Object Queries self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # Positional Encoding # ※positional encodingの初期値がランダム値 (本来はsine positional encodingを使う) self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) def forward(self, inputs): # ①backbone x = self.backbone.conv1(inputs) x = self.backbone.bn1(x) x = self.backbone.relu(x) x = self.backbone.maxpool(x) x = self.backbone.layer1(x) x = self.backbone.layer2(x) x = self.backbone.layer3(x) x = self.backbone.layer4(x) # hidden_dim次元に削減 h = self.conv(x) # ②positional encoding H, W = h.shape[-2:] pos = torch.cat([ self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), ], dim=-1).flatten(0, 1).unsqueeze(1) # ③encoder、④decoder # ※Attention時ではなく、Transformer入力時にpositional encodingを適用 h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1)).transpose(0, 1) # ⑤prediction heads return {'pred_logits': self.linear_class(h), 'pred_boxes': self.linear_bbox(h).sigmoid()}
PyTorchでモデルを定義する際は、nn.Module
クラスを継承します。__init__
関数にモデルの構成要素を定義して、forward
関数に順伝播の処理を記載します。
backboneのCNNには、ResNet-50を使っています。最終層のfc層を削除し、代わりにconv層を追加して hidden_dim
次元(前章のアーキテクチャの説明ではd次元で説明しました)に圧縮します。次にpositional encodingを定義していますが、行方向と列方向をそれぞれ hidden_dim//2
次元で定義します。それを torch.cat
関数で結合して、hidden_dim
次元にしているわけです。今回は簡略化のために初期値がランダム値ですが、実際は正弦波で定義されますので、半分は行方向、半分は列方向の位置情報を覚えさせています。正弦波を使ったpositional encodingの実装を確認したい場合は、公式の実装 10 をご確認下さい。次はtransformerの定義です。Encoderにはbackboneの出力値を平坦化したもの、Decoderにはobject queriesを入力します。今回はN=100を想定しているため、引数に100を決め打ちに入力しています(本記事は公式チュートリアル準拠ですが、__init__
関数の引数にした方がスマートかもしれません)。最後は、Transformerの出力をFFNに通して、各クラスの予測値とバウンディングボックスの位置を出力します。
文字だけでは中々理解することが難しいかもしれません。本来は図などを利用して細かく説明すべきですが、今回はあくまで初心者向けの記事になりますので、アーキテクチャの①~⑤がソースコードのこの辺りと対応しているんだな、程度に理解していただければ幸いです。詳細を説明するのは次回以降にする予定です。
ともあれ、これでDETRのモデルは定義できました。次は、DETRdemo
モデルクラスに対応した学習済み重みパラメータをダウンロードして適用してみましょう。公式の学習済みモデルや学習済み重みパラメータはModel Zooにアップロードされており、Torch Hubからダウンロードすることができます。
model = DETRdemo(num_classes=91) state_dict = torch.hub.load_state_dict_from_url( url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth', map_location='cpu', check_hash=True) model.load_state_dict(state_dict) model.eval()
DETRdemo
の引数 num_classes
は、COCO datasetで学習されているパラメータをロードするため、91を設定します。torch.hub.load_state_dict_from_url
の引数 map_location
は、Torch Hubからロードしたモデルやパラメータをどのデバイスに展開するかどうかです。今回はCPU上に展開するため cpu
にしています。
次は、推論する画像をPIL Imageとして読み込みます。画像は独自に用意しても大丈夫ですが、今回はCOCO dataset 1 のvalidation画像をダウンロードして使います。
url = 'http://images.cocodataset.org/val2017/000000079229.jpg' im = Image.open(requests.get(url, stream=True).raw)
どのような画像なのかも確認してみましょう。
plt.imshow(im)
上記の画像が表示されれば成功です。
次は、インプット画像の前処理を行う transform
メソッドを用意します。具体的にはリサイズ(T.Resize
)、Tensor型に変換(T.ToTensor
)、Tensorを平均値と標準偏差で正規化(T.Normalize
)の3つの処理を実行します。
transform = T.Compose([ T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])
もう少し transform
メソッドの中身を細かく解説します。
T.Compose
… 引数で渡された複数の処理を、先頭から順に処理。T.Resize
… 画像のリサイズ。引数がint型の場合、画像の縦横で小さい方が引数のサイズになるように、縦横比固定でリサイズされます。例えば、(640, 480)の画像にT.Resize(800)
を実行した場合、(1066, 800)にリサイズされます。T.ToTensor
… PIL Image(W, H, C)を、Tensor(C, W, H)に変換。また、[0~255]から[0~1]にスケーリング。詳細はtorchvisionの公式マニュアル 11 をご参照下さい。T.Normalize
… Tensorを平均値と標準偏差で正規化。引数の数値は、torchvisionで提供されている学習済みモデル(今回はResNet-50)を使う場合は、上記ソースコードの引数の値となります。torchvisionで提供されるモデルはImageNetデータセットで学習したモデルになるので、そのデータセット全体の平均値と標準偏差が[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
であるためです。もし自作のデータセットで一から学習する場合は、別途平均値と標準偏差を計算し、それを引数に設定する必要があります。詳細はPyTorchの公式マニュアル 12 をご参照下さい。
では最後に、PIL Imageに前処理を適用し、モデルの入力として渡します。PyTorchのモデルの入力はミニバッチでの入力を前提としているため、unsqueeze(0)
を使い3次元から4次元に変換しています。
img = transform(im).unsqueeze(0) outputs = model(img)
これでDETRによる推論は完了です。outputs
が推論結果になります。ただし、本章では簡略化している処理があるため、後述の 4-2. よりも精度は高くないと思われます。本章はあくまでアーキテクチャと紐付けて理解を深めてもらうことが目的ですので、実際に推論処理をやりたい場合は 4-2. の処理で行う方が良いでしょう。
また、推論結果がどのような構造になっているのか気になるかと思いますので、outputs
の中身を確認します。
print(outputs) # ⇒{'pred_logits': tensor([[[-18.3012, -0.1064, -4.2331, ..., -14.3635, -9.1968, 11.6259], # [-18.5986, -0.5623, -4.8728, ..., -12.7535, -9.2969, 11.7771], # [-18.1445, 0.1714, -4.0015, ..., -13.6043, -8.5909, 11.3812], # ..., # [-19.3517, -1.9142, -1.6362, ..., -15.5901, -11.3513, 11.7548], # [-19.3750, 2.9712, -7.0269, ..., -8.0860, -6.6140, 12.1833], # [-19.4573, -0.9082, -4.8223, ..., -15.2501, -11.0872, 12.1151]]], # grad_fn=<SelectBackward>), # 'pred_boxes': tensor([[[0.6088, 0.4622, 0.0395, 0.0446], # [0.5544, 0.4723, 0.1244, 0.0816], # [0.6013, 0.4728, 0.0484, 0.0717], # ..., # [0.3543, 0.7365, 0.3642, 0.1271], # [0.6205, 0.4677, 0.2931, 0.3890], # [0.4045, 0.6042, 0.4610, 0.3697]]], grad_fn=<SelectBackward>)}
outputs
の中身は、pred_logits
と pred_boxes
の2つの要素で構成されています。pred_logits
は各クラスの予測値です。92クラス(COCO datasetの91クラス+“no object”)の予測値が、object queries
の数N個(学習済みモデルではN=100)返されます。pred_boxes
はバウンディングボックスの位置です。(center_x, center_y, width, height)
の4つの座標がN個(N=100)返されます。
4-2.学習済みモデルをダウンロードして推論
さて、前章でDETRのアーキテクチャと紐付けながらソースコードを解説しましたので、DETRについてある程度理解できたかと思います。しかし、前章では一部簡略化された処理がありました。本章ではオリジナルの学習済みモデル全体をダウンロードして推論する方法を紹介いたします。
まずは 4-1. と同様に、必要なパッケージをインポートします。
import torch, torchvision import torchvision.transforms as T import requests import matplotlib.pyplot as plt from PIL import Image print(torch.__version__) # 1.9.0+cu102 print(torchvision.__version__) # 0.10.0+cu102
次に、学習済みモデルをダウンロードします。今回は推論を高速化するために、最もシンプルなモデルである DETR-R50 を使います。もし他のモデルを使用したい場合は、第2引数を変更して下さい。他モデルの種類は、DETRのGitHubページ 8 の「Model Zoo」の箇所に記載されています。DETR-DC5は性能比較でも紹介しましたが、特徴マップの解像度が2倍のDETRになります(最初の層のstrideを2→1に変更したもの)。性能は向上していますが、計算量は約2倍になっています。
# DETR_R101の場合は'detr_resnet101'、DETR-DC5_R50の場合は'detr_resnet50_dc5'のように指定 model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_dc5', pretrained=True) model.eval()
上記ソースコードを実行してエラーが出なければ、モデルの準備は完了です。モデル全体をダウンロードしているので、これだけで推論処理の準備が終わりました。
次に推論処理を行っていきますが、ソースコードは 4-1. と同様なので、説明は割愛いたします。
url = 'http://images.cocodataset.org/val2017/000000079229.jpg' im = Image.open(requests.get(url, stream=True).raw) transform = T.Compose([ T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) img = transform(im).unsqueeze(0) outputs = model(img)
これで推論処理は完了です。4-1. と同様に outputs
が得られました。それでは、推論結果がどのようなものか、次章で可視化してみましょう。
4-3.推論結果の可視化
本章では、 4-1. または 4-2. で得られた outputs
を使い、画像とバウンディングボックスを表示する方法について説明いたします。基本はチュートリアル用のノートブック「DETR’s hands on Colab Notebook」 そのままですので、詳細を確認したい方はそちらをご覧下さい。
まず、COCOのクラスとカラーマップを定義します。学習済みモデルは COCO 2017 Dataset
を使って学習しているので、クラスは COCO 2017 Dataset
のものを決め打ちで設定します。カラーマップは好きな値に変更しても問題ありません。
CLASSES = [ 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' ] COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
次に、outputs
を表示する前処理として、座標変換関数とバウンディングボックスのリスケール関数を定義します。4-1. の最後にも記載しましたが、outputs['pred_boxes']
に座標が格納されていますが、(center_x, center_y, width, height)
の形で保持されています。そのままでは表示する際に使いにくいので、(xmin, ymin, xmax, ymax)
の形に座標変換します。バウンディングボックスのリスケールでは、モデル適用時にインプット画像をリサイズしたので、当然 outputs['pred_boxes']
もリサイズ後のスケールになっています。そのため、元画像のサイズに合うようにリスケールします。細かい処理内容についてはコメントでも記載しましたので、より理解を深めたい場合はコメントをご参照下さい。
def box_cxcywh_to_xyxy(x): """ (center_x, center_y, width, height)から(xmin, ymin, xmax, ymax)に座標変換 """ # unbind(1)でTensor次元を削除 # (center_x, center_y, width, height)*N → (center_x*N, center_y*N, width*N, height*N) x_c, y_c, w, h = x.unbind(1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] # (center_x, center_y, width, height)*N の形に戻す return torch.stack(b, dim=1) def rescale_bboxes(out_bbox, size): """ バウンディングボックスのリスケール """ img_w, img_h = size b = box_cxcywh_to_xyxy(out_bbox) # バウンディングボックスの[0~1]から元画像の大きさにリスケール b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) return b
これで前処理を行う関数の定義は完了です。ついでにMatplotlibを使った画像とバウンディングボックス表示関数も定義しましょう。
def plot_results(pil_img, prob, boxes): """ 画像とバウンディングボックスの表示 """ plt.figure(figsize=(16, 10)) plt.imshow(pil_img) ax = plt.gca() colors = COLORS * 100 if prob is not None and boxes is not None: for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors): # バウンディングボックスの表示 ax.add_patch(plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, fill=False, color=c, linewidth=3)) # 最大の予測値を持っているクラスのindexを取得 cl = p.argmax() # クラス名と予測値の表示 text = f'{CLASSES[cl]}: {p[cl]:0.2f}' ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5)) plt.axis('off') plt.show()
これで関数の準備は完了です。それでは実際に、4-1. または 4-2. で得られた outputs
を使って画像とバウンディングボックスの表示をしてみましょう。また、表示するクラスの予測値の閾値は0.9で実施します。
threshold = 0.9 # no-objectを除いた91クラスでsoftmax probas = outputs['pred_logits'].softmax(-1)[0, :, :-1] # 91クラスの中で一番大きい予測値を取得*N個して、閾値を超えればTrue、それ以下だとFalse keep = probas.max(-1).values > threshold # バウンディングボックスの前処理 bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size) # 画像とバウンディングボックスの表示 plot_results(im, probas[keep], bboxes_scaled)
上記のように画像が表示されたでしょうか。DETRのライブラリ内には表示メソッドを持っていないので自分で実装する必要がありますが、上記で説明したソースコードで表示することができます。
5.おわりに
今回はなるべく簡単に、既存の物体検出とDETRの違い、DETRの原理について解説しました。また、DETRでの推論処理方法についてソースコードを交えながらご紹介しました。今回はインプット画像はCOCO datasetのものを使いましたが、自前で用意した画像をインプットにして物体検出してみて、どの程度精度が出るのか見てみるのも面白いかもしれません。
次回は、もう少し踏み込んでDETRの原理を詳細に解説しながら、自然言語処理のTransformerとの比較をする予定です。また、今回のようにColabを使ったFine-Tuning方法も解説する予定です。
-
https://ai.facebook.com/blog/end-to-end-object-detection-with-transformers/ ↩
-
https://colab.research.google.com/notebooks/welcome.ipynb?hl=ja ↩
-
https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py ↩
-
https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor ↩