従来のAI技術では、画像処理分野では画像(視覚情報)、音声認識分野では音声(聴覚情報)など、1つの情報(モダリティ)を使うことが一般的です。しかし、人間など動物は単一のモダリティに基づいて思考や判断をしておらず、人間のように複数モダリティを利用する技術はAI技術の発展に繋がると考えられます。そのため、複数のモダリティを統合して扱うAI技術として登場したのがマルチモーダルAIです。複数のモダリティを活用することで、より正確に予測や推論ができるようになったり、今まで解決できなかった問題を解くことができるようになることが期待されており、様々な研究や論文が発表されている分野です。今回は、『Transformerを使った初めての物体検出「DETR」』で紹介した物体検出を行うDETRと、自然言語処理のRoBERTaを組み合わせた「MDETR」について解説します。また、MDETRで解決タスクとGoogle Colaboratory(以下Colab)での推論方法を紹介します。
1.はじめに
今回の記事では、2021年4月に公開されたマルチモーダル推論モデルである MDETR (Modulated Detection for End-to-End Multi-Modal Understanding) 1 について紹介していきます。MDETRでは画像処理を行うDETR 2 と、自然言語処理を行うRoBERTa 3 を組み合わせたモデルとなっており、MDETRでは様々なタスクを解くことができます。例えば、テキストで記述した内容のみを検出する物体検出タスクである ①Referring Expression Comprehension(参照表現理解) 、同じくテキストの物体を検出してセグメンテーションするタスクである ②Referring Expression Segmentation(参照表現セグメンテーション) 、画像に対する質問応答タスクである ③Visual Question Answering(視覚的質問応答) などのタスクを解決できます。詳細は後述の 3-3.MDETRで解決できるタスク でご紹介します。以下に、COCO dataset 4 を使った各タスクの推論例を示します。
MDETRの何がすごいかと言うと、従来のマルチモーダル推論モデルでは出来なかった画像とテキストの関連性(アライメント)を学習している点です。そのため、「灰色の象」と「ピンク色のボール」などを学習していれば、学習時に経験していない「ピンク色の象」といった新しい組み合わせも検出することができます。
また、本記事ではDETRやRoBERTaの解説は省略しますので、DETRについては弊社記事の 『Transformerを使った初めての物体検出「DETR」』、BERTやRoBERTaについてはmm0824氏の「RoBERTaを理解する」や、弊社記事の 「はじめての自然言語処理 第3回 BERTを用いた自然言語処理における転移学習」 などをご確認下さい。MDETR以前の従来手法についてもいくつか出てきますが、そちらの解説も割愛させていただきます。論文への引用はなるべく追記していますので、興味のある方はそちらをご確認いただければ幸いです。
2.マルチモーダル深層学習
MDETRの解説に入る前に、マルチモーダル深層学習分野ではどのような研究がされてきたのかについて紹介し、2-1. ではどのようなユースケースがあるのかについて簡単にご紹介します。なお、本章は2017年に公開されたマルチモーダル深層学習のサーベイ論文 5 を参考にした内容となります。
サーベイ論文 5 では、マルチモーダル深層学習の問題設定には色々なタスクがあり、これまでに数多くの研究が行われてきた結果、大まかに5つのカテゴリに分類されるとしています。
①Representation (表現)
マルチモーダルのデータをどのように表現したり要約するかについて解決するタスク。
(例) テキスト情報と音声信号データを同一空間で扱えるか、など。
②Translation (変換)
あるモダリティのデータを別のモダリティのデータに変換するタスク。
(例) 画像から説明文を生成する、など。
③Alignment (アライメント)
複数のモダリティ間の直接的な関係を明らかにするタスク。
(例) 料理を作る動画の各シーンを正確に並び替えるためにレシピ情報(テキスト)と画像中の情報を結びつける、など。
④Fusion (融合)
ある予測をするために複数のモダリティの情報を利用するタスク。
(例) スピーチ内容を正確に予測するためにスピーチ音声と話者の口の動き(動画)を使う、など。
⑤Co-learning (共学習)
あるモダリティ内で作られた推論モデルやベクトル表現などを別のモダリティに転移させるタスク。
(例) Zero-shot learningなど。
以上の5つのカテゴリに分けることができます。MDETRを上記カテゴリに当てはめると、「③Alignment (アライメント)」 または 「④Fusion (融合)」 に相当すると思われます。今回は、その中でも 「④Fusion (融合)」 をピックアップして詳細を解説します。他カテゴリの詳細について知りたい場合はサーベイ論文 5 にまとめられていますので、興味のある方はご確認下さい。
2-1.Fusion(融合)の解説
Fusion(融合)は前述の通り、複数のモダリティから推論や予測をするタスクです。マルチモーダル深層学習の中では最も歴史が長いカテゴリで、様々な研究やアプローチが提案されています。例えば、複数モダリティのデータを用いることでよりロバストな推論を行うことができたり、一部分のモダリティのデータが欠けていても推論ができるような研究が行われています。一方で、異なるモダリティ間で互いに冗長性があり、次元や構造等の形式も大きく異なることから、単純に結合するだけで推論はできません。何らかの方法で適切な同一空間に写像する等、工夫が必要になります。
応用分野としては、Audio Visual Speech Recognition(AVSR)、Visual Question Answering(VQA)、感情認識、医用画像解析、マルチメディアイベント検知等があります。以下で、いくつかの応用分野について紹介します。
1つ目は、Audio Visual Speech Recognition(AVSR)についてです。これは、音声と画像(主に唇の動き)の情報を用いて精度の高い音声認識を行うタスクです。マルチモーダル学習の最初期の研究で、1986年頃から研究が始まっています 6 。当初はディープラーニングの技術は使われておらず、パターン認識で唇の画像からテキスト変換を行うLip Readingをしていました。1994年の「Visual Speech Recognition with Stochastic Network」7 という論文では確率的ネットワークが使われる等しましたが、2011年に公開された「Multimodal Deep Learning」8 という論文辺りから、マルチモーダル学習にディープラーニングが使われ始めました。下図がディープラーニングを使ったAVSRの実装例になります。音声にはAutoEncoder、画像にはCNNを使ってマルチモーダル深層学習をしていることが分かるかと思います。また、2021年に公開された最新の研究である「END-TO-END AUDIO-VISUAL SPEECH RECOGNITION WITH CONFORMERS」9 では、ConFormer 10 を使っているようです。
2つ目は、Visual Question Answering(VQA)について紹介します。VQAとは、ある画像とその画像に対する質問文から、正しい答えを導き出すタスクです。マルチモーダルの中でも画像とテキストの情報を使っているので、Vision and Languageと呼ばれる分野になります。VQAタスクは2016年のCVPRでコンペティションが行われたことがきっかけで話題となり、同年に公開された「VQA: Visual Question Answering」11 では下図のように、画像を処理するCNNの出力と、テキストを処理するLSTM 12 ネットワークの出力を掛け合わせてマルチモーダル表現空間を形成しています。自然言語でTransformerを使ったBERTが流行り始めた2018年以降には、ViLBERT 13、LXMERT 14、VL-BERT 15、UNITER 16 等のBERT型のマルチモーダル深層学習が提案されています。ちなみにViLBERT等の手法では、VQAタスク以外にもVisual Commonsense Reasoning(ある画像とその画像に関する質問文から、正しい答えとその理由を導出)、Referring Expression Comprehension(文章に合致しそうな物体の検出)、Caption-Based Image Retrieval(テキストに対して適した画像を検索)といったタスクを解くこともできます。
3.MDETR
マルチモーダル深層学習について説明したところで、改めて MDETR (Modulated Detection for End-to-End Multi-Modal Understanding) 1 について解説します。MDETRは2021年4月に公開されました。ニューヨーク大学とFacebook AIが共同研究しているもので、DETR 2 をベースにしたエンドツーエンドのマルチモーダル推論モデルになります。3-1. ではMDETRの概要や特徴について解説し、3-2. ではアーキテクチャの説明、3-3. ではMDETRで解決することのタスクについて解説していきます。
3-1.MDETRの特徴
MDETRの特徴として、モデルの初期段階で画像とテキストの2つのモダリティをFusion(融合)させて共同で推論していることが挙げられます。前章で紹介した LXMERT 14 等の従来手法では、あらかじめ事前学習させた物体検出モデルを用いて画像から物体を検出し、検出できた物体の特徴量とテキストの特徴量を使ってマルチモーダル推論を行っています。下図は LXMERT 14 のアーキテクチャ図になりますが、画像側はFaster-RCNN 17 を使って検出した物体の特徴量を画像側のエンコーダに通し、テキスト側は埋め込みベクトルに変換した後にテキスト側のエンコーダに通しています。そしてそれぞれの出力を使ってクロスアテンションを繰り返すことによりマルチモーダル推論を行います。
しかしこのようなパイプライン構造の場合は、下流モデル(上図のCross-Modality Encoder)がアクセスできるのはFaster-RCNN 17 で検出できた物体のみで画像全体ではない、という制限があります。テキスト側で自由形式で色々な表現を入力しても、画像側で検出できていない物体に対しては何をやっても認識できないことが課題として挙げられます。(そのため、物体検出タスクでなるべく多くのオブジェクトを検出できるような論文 18 も考案されています。)
MDETRではこれらの問題を解決するために、物体検出側とテキスト側をエンドツーエンド構造にして共同で推論できるような仕組みを考案しています。画像側のDETR 2 の構造を基本にして、Transformerの手前にテキスト側をRoBERTa 3 で処理したものを接ぎ木しているようなアーキテクチャになっています(実際のアーキテクチャ図は3-2.に掲載しています)。このような構造になっているため自由形式のテキストも認識して物体検出できるようになり、下図のように学習していない “A pink elephant” (ピンク色の象)も検出することができます。
3-2.MDETRのアーキテクチャ
本節では、MDETRのアーキテクチャについて解説します。なお、具体的な学習方法や出力構造等は今回は解説せず、別記事で解説予定です。
MDETRのアーキテクチャ図は上図のようになっており、上図の例ではインプットとして画像とテキストをセットで入力し、アウトプットとしてテキストに登場する物体のみを検出しています。前節でも説明した通り、青色の点線部分のDETR 2 を基本にして、テキスト側のRoBERTaを接ぎ木しているような構造になっています。DETRの説明は 『Transformerを使った初めての物体検出「DETR」』 で解説していますので、詳細はそちらをご参照下さい。
DETRと違う点はいくつかあります。まず1点目は、Transformerに画像の特徴マップ(ベクトルのシーケンス)を入力する前にテキスト情報をConcatしている点です。画像とテキストをConcatする時のイメージ図は、以下のようになります。
画像側ではDETR 2 と同様に、画像をCNNで特徴マップに変換してreshapeを行います。上図のように画像をブロックごとに分割してベクトル化するイメージになります。埋め込みベクトルの次元数をdとすると、shapeは (ブロック数, d)
となります。テキスト側では、RoBERTa 3 を使ってテキストをエンコードすることでベクトルのシーケンスを生成します。画像の特徴マップとConcatしたいため、テキスト側も画像と同じ埋め込みベクトルの次元数dになるように変換する必要があります。その場合、shapeは (シーケンス長, d)
となります。最後に画像とテキストの特徴マップをシーケンス次元で連結して、((ブロック数+シーケンス長), d)
となるように単一シーケンスにしています。
2点目は、学習が「事前学習」(Pre-training)と「ファインチューニング」(Fine-Tuning) の二段階で構成されており、「事前学習」では色々なデータセットが必要になる点です。
MDETRでは「ファインチューニング」で様々なタスクに特化したモデルを作ることができるため、「事前学習」にはBERT等と同じく大量データを使って画像とテキストの組み合わせの特性を学習させる必要があります。BERTの場合はMasked Language Modeling(MLM)
というマスクで文章の一部を欠損させる処理を使うことで教師なし学習を行っており、Wikipediaのダンプ等の大量の文章データを利用していました。しかし、MDETRの場合は教師あり学習であるため、画像とアノテーションとテキストが揃った大量データを用意する必要がありますが、そのようなデータセットはありません。また、MDETRは前述の通り自由形式のテキストを認識することができるため、色々な種類のデータセットで幅広く「事前学習」を行った方が性能向上に繋がると考えられます。
以上の理由より、MDETRの「事前学習」では色々な種類のデータセットを組み合わせて、大量データとして学習させています。例えば、画像+アノテーションとそれを説明するテキストをセットにしたデータセット 19 20 や、画像に対して質問と回答を行うデータセット 21 等を「事前学習」で使っています。このようなデータセットを使うことで、下図の左側のように「事前学習」では画像とアノテーションとテキストの関連性を学習しています。そして下図の右側の「ファインチューニング」では、テキストの代表的な物体(下図の例では猫)に対してアノテーションするように学習しています。「ファインチューニング」は代表的な物体を検出するタスク以外にも様々なタスクを解くことができますので、詳細を次節で紹介します。
なお、具体的なデータセット名や事前学習方法については、別記事で解説予定です。
3-3.MDETRで解決できるタスク
本節では、MDETRで解くことのできるタスクについて解説します。MDETRの論文 1 ではこのタスクのことを「Downstream Tasks (下流タスク)」と表現しており、「Referring Expression Comprehension(参照表現理解)」「Referring Expression Segmentation(参照表現セグメンテーション)」「Visual Question Answering(視覚的質問応答)」といった3つの下流タスクがあります。もう1つ「Phrase Grounding(フレーズグラウンディング)」という下流タスクもあるようですが、MDETRのGitHubページ 22 には推論プログラムが無いため本記事では省略します。興味のある方は調べてみて下さい。また、今回の記事では各タスクの説明だけにとどめておき、具体的なファインチューニング方法等は別記事で説明予定です。
①Referring Expression Comprehension(参照表現理解)
「Referring Expression Comprehension(参照表現理解)」は、入力テキストに対応した物体検出を行うタスクです。
本タスクのインプットは上図の通り画像とテキストで、アウトプットはテキスト全体で表している物体をバウンディングボックスとして返します。上図の例では「A green umbrella.」と短いテキストですが、「The woman wearing a blue dress standing next to rose bush.(青いドレスを着た女性がバラの木の横に立っています)」のような長文のテキストでも検出することができます。長文の場合は“the woman"、"a blue dress"、"rose bush"など色々な物体が画像にもテキストにも存在しますが、アウトプットでは主要な物体である "the woman"のバウンディングボックスのみが出力されます。また、従来の手法と比較した結果が以下の図になります。
上図は参照表現理解タスクの精度を従来手法と比較している表であり、「Method」では手法の名称、「Detection backbone」では画像側のCNNで使われているネットワークの種類がResNet-101 か EfficientNet-B3 かを表しています。「Pre-training image data」では事前学習に使用したデータセットを表記しており、重要なのはデータセットのサイズです。MDETRは200kであり、従来手法と比べるとサイズが小さいことが分かります。「RefCOCO」「RefCOCO+」「RefCOCOg」では、それぞれのデータセットで「val」「testA」「testB」を使って推論を行った結果の正答率です。全項目において従来手法よりも性能が向上していることが分かります。
②Referring Expression Segmentation(参照表現セグメンテーション)
「Referring Expression Segmentation(参照表現セグメンテーション)」は、入力テキストに対応した物体のセグメンテーションを行うタスクです。
本タスクのインプットは前節と同じく画像とテキストで、アウトプットはテキスト全体で表している物体のセグメンテーションを行います。本タスクのインプットとなるテキストですが、比較的短いテキストを使う必要があり、前節のような長文を使うことができません。その理由は、本タスクで使ったデータセット 23 のテキストが非常に短い1文構成になっているためです(データセットの詳細については本記事では省略)。また、従来手法と比較した結果を以下に示します。
上図は参照表現セグメンテーションの精度を従来手法と比較している表であり、「Method」と「Backbone」は前節と同じく手法名称とネットワークの種類になります。「PhraseCut」の「M-IoU」では推論結果と正解のセグメンテーションの平均IoU (推論結果と正解がどれくらい重なっているか)を表しています。「Pr@0.5
」「Pr@0.7
」「Pr@0.9
」では、それぞれ閾値0.5、0.7、0.9よりも高いIoUを持つ場合に成功とした場合の結果です。参照表現セグメンテーションでも、前節と同じく全項目で性能が向上しています。
③Visual Question Answering(視覚的質問応答)
「Visual Question Answering(視覚的質問応答)」は前章でも紹介した通り、ある画像とその画像に対する質問文から正しい答えを導き出すタスクです。
本タスクのインプットは画像と質問文で、アウトプットは質問に対する回答をテキストとして出力します。他タスクと異なりアーキテクチャ等にも工夫がされていますが、今回の記事では説明を省略し別記事で説明予定です。また、従来手法と比較した結果は以下の表の通りです。
上図はVQAタスクの精度を従来手法と比較している表であり、「Method」では手法の名称、「Pre-training img data」では事前学習に使用したデータセットを表記しています。「Test-dev」と「Test-std」はGQA dataset 21 の中のテスト用のデータセット群であり、それぞれで正答率を求めています。従来手法と比較すると、同程度の事前学習データセット量を使っているLXMERT 14 やVL-T5 24 の精度を上回っているだけでなく、より多くのデータを使っている OSCAR 25 よりも精度が上回っています。VinVL 18 には精度が負けていますが、MDETRでもデータセット量を増やすことで精度向上が期待できるため、もっとデータセットのサイズを大きくすればVinVL 18 の精度に近づくまたは超えることが考えられます。
4.MDETRの推論
それでは、実際にMDETRでの推論を試してみます。今回は学習は行わず、公式で用意されている学習済みモデルを利用した推論方法について、前章で紹介した「Referring Expression Comprehension(参照表現理解)」「Referring Expression Segmentation(参照表現セグメンテーション)」「Visual Question Answering(視覚的質問応答)」をタスクごとに説明します。なお、ソースコードはMDETRのGitHubページ 22 が公開している「MDETR_demo.ipynb」 を参考にさせていただいております。
ソースコードは、Colab上に順番にコピー&ペーストすれば動作するように記載していきます。4-1. で各タスクで必要となる共通処理を紹介し、4-2.、4-3.、4-4. で各タスクの詳細な推論処理について紹介していきます。また、MDETRでは推論時でもGPUモードにする必要があります。以下の画像のようにColabのメニューの「ランタイム」→「ランタイムのタイプを変更」から「GPU」を選択して、GPUモードに変更して下さい。
4-1.必要なパッケージの準備と前処理・後処理の定義
ここでは、各タスク共通で必要なパッケージのインストールやインポート作業と、前処理・後処理を行うメソッドを事前に定義します。前述の通り、各タスクの推論を実行する前に必要となるため、Colabにコピーして実行する際には必ず 4-1. のソースコードを実行するようにお願いいたします。
それでは最初に、Colabにインストールされていないパッケージのインストールを行います。
!pip install timm transformers
上記パッケージをインストールしないと、モデルのロード時にエラーになってしまいます。timm
パッケージはPyTorch用の画像系モデルや最適化手法が実装されており、MDETRでは画像側のCNNにEfficientNet 26 等を使うために必要となります。transformers
パッケージは米国のHugging Face社 27 が公開しているTransformerを使っている最先端の自然言語処理を行うためのライブラリで、RoBERTa 3 等を利用するため必要となります。
インストール完了後、必要なパッケージやライブラリのインポートを行います。
import torch from PIL import Image import requests import json import torchvision.transforms as T import matplotlib.pyplot as plt import numpy as np import torch.nn.functional as F from skimage.measure import find_contours from matplotlib.patches import Polygon from collections import defaultdict torch.set_grad_enabled(False);
次に、結果表示用のカラーマップ、インプット画像の前処理を行うtransformメソッド、座標変換関数とバウンディングボックスのリスケール関数を定義します。これらは 『Transformerを使った初めての物体検出「DETR」』 で解説したソースコードそのままですので、詳細は前回記事をご確認下さい。
# カラーマップ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]] # インプット画像の前処理 transform = T.Compose([ T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def box_cxcywh_to_xyxy(x): """ (center_x, center_y, width, height)から(xmin, ymin, xmax, ymax)に座標変換 """ x_c, y_c, w, h = x.unbind(1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=1) def rescale_bboxes(out_bbox, size): """ バウンディングボックスのリスケール """ img_w, img_h = size b = box_cxcywh_to_xyxy(out_bbox) b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) return b
次はDETRには無かったメソッドで、セグメンテーションした領域を塗り潰したマスク画像を、元の画像に適用する処理です。引数でマスク画像の透過値も変更することができます。
def apply_mask(image, mask, color, alpha=0.5): """ セグメンテーション用の領域を塗りつぶしたマスクを画像に適用 Parameters ---------- image : numpy.ndarray 適用元の画像 mask : tensor 適用するマスクの領域座標 color : list カラーマップ alpha : float 透過値 """ for c in range(3): image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - alpha) + alpha * color[c] * 255, image[:, :, c]) return image
最後に、Matplotlibを使った結果画像を表示する関数も定義します。各タスクの結果表示関数から呼び出せるように、引数には画像イメージ、予測値のリスト、バウンディングボックス座標、物体のラベルリスト、セグメンテーション用のマスクデータのリストを指定できるようになっています。なお、この処理は 『Transformerを使った初めての物体検出「DETR」』 で解説した plot_results
メソッドにマスク処理を追加しているものです。変更点にはコメントを追記していますので、併せてご確認下さい。
def plot_results(pil_img, scores, boxes, labels, masks=None): """ 結果表示 Parameters ---------- pil_img : PIL.Image 画像 scores : list 検出された物体の予測値のリスト boxes : list 検出された物体のバウンディングボックス座標(center_x, center_y, width, height)のリスト labels : list 検出された物体のラベルのリスト masks : list セグメンテーション用のマスクのリスト """ plt.figure(figsize=(16,10)) # PIL.Imageをnumpy.ndarrayに変換 np_image = np.array(pil_img) ax = plt.gca() colors = COLORS * 100 if masks is None: # マスクが無い場合は、len(scores)のNoneのリストで埋める masks = [None for _ in range(len(scores))] # リストの長さが違う場合、例外をスロー assert len(scores) == len(boxes) == len(labels) == len(masks) for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors): ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3)) text = f'{l}: {s:0.2f}' ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8)) if mask is None: continue # マスク適用 np_image = apply_mask(np_image, mask, c) # マスク部分の輪郭線を描画。maskデータはバイナリなので、中間の0.5の位置で輪郭線を描画する。 padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8) padded_mask[1:-1, 1:-1] = mask contours = find_contours(padded_mask, 0.5) for verts in contours: # (y, x)を(x, y)に反転 verts = np.fliplr(verts) - 1 # 輪郭線の内部を塗り潰し p = Polygon(verts, facecolor="none", edgecolor=c) ax.add_patch(p) plt.imshow(np_image) plt.axis('off') plt.show()
これで各タスクで共通で利用する関数の準備は完了です。次節から、各タスクでの推論方法を紹介いたします。
4-2.Referring Expression Comprehension(参照表現理解)の推論
本節では、入力文章に対応した物体検出を行う「Referring Expression Comprehension(参照表現理解)」について紹介します。まず最初に学習済みモデルをダウンロードします。DETRと同じく、公式の学習済みモデルはModel Zooにアップロードされており、Torch Hubからダウンロードすることができます。画像側のCNNとしてResNet-101、EfficientNet-B3、EfficientNet-B5を使っている学習済みモデルがありますが、今回は最も性能の良かったEfficientNet-B5を使ったモデルを利用しています。
model = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True) model = model.cuda() model.eval();
モデルサイズが2.53GB程度あるため、実行に時間がかかるかもしれません。モデルのロードが終わりましたら、次は前節で定義した plot_results
を使いつつ、本タスク用の結果表示メソッドを定義します。基本的な流れは 『Transformerを使った初めての物体検出「DETR」』 で紹介した推論方法と同じですが、テキスト側の処理が追加されています。変更点にはコメントを追記しています。
def plot_inference(im, caption): """ 参照表現理解用結果表示 Parameters ---------- im : PIL.Image 画像 caption : string 入力文字 """ img = transform(im).unsqueeze(0).cuda() # モデル伝搬 # Encoder側のみ処理 memory_cache = model(img, [caption], encode_and_save=True) # Decoder側のみ処理 outputs = model(img, [caption], encode_and_save=False, memory_cache=memory_cache) probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu() keep = (probas > 0.7).cpu() bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size) # 各ボックスで予測されるテキストスパンを抽出 # 閾値以上の予測値のリストから、対応するトークンの場所を抽出 positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist() # 空の辞書を作成 predicted_spans = defaultdict(str) for tok in positive_tokens: item, pos = tok if pos < 255: # Encoderで処理したテキスト情報から、テキストスパンを取得 span = memory_cache["tokenized"].token_to_chars(0, pos) # 入力文字(caption)から対応する文字を取得 predicted_spans[item] += " " + caption[span.start:span.end] labels = [predicted_spans[k] for k in sorted(list(predicted_spans.keys()))] plot_results(im, probas[keep], bboxes_scaled, labels)
主な変更点として、モデルへの伝搬処理をEncoder、Decoderに分けて2回実行されていることが挙げられます。参照表現理解用のモデルの構造上、テキスト側(RoBERTa)ではDecoderを使わないので、encode_and_save
のフラグによって別々にモデルを処理することができるようになっています。そのため、予測値等の値を取得する際には outputs["pred_logits"]
のように outputs
からアクセスし、テキスト情報を取得する際には memory_cache["tokenized"]
のように memory_cache
からアクセスする必要があります。
結果表示用のメソッドの定義が完了したので、次は推論用の画像を用意します。独自に用意した画像でも良いですが、今回はCOCO dataset 4 からダウンロードした画像を使います。実際に画像も表示してみましょう。5人の女性が傘を持っている画像が表示されるかと思います。
url = "http://images.cocodataset.org/val2017/000000281759.jpg" im = Image.open(requests.get(url, stream=True).raw) # 画像表示 plt.imshow(im)
それでは、いよいよ参照表現理解の推論を行います。テキスト部分には検出したい物体の説明文を英語で入力します。例えば、"A green umbrella.” と入力して結果を確認してみます。
plot_inference(im, "A green umbrella.")
上記の通り、一番左の女性が持っている緑色の傘のみを物体検出して、他の傘は検出されていないことが分かります。このように、MDETRではテキストで指定した物体のみを検出することができます。テキスト部分を “A pink striped umbrella.” とすれば右から2番目の女性の傘が、"A plain white umbrella.“ とすれば一番右の女性の傘が検出できます。また、検出できるのは傘だけではないので、"A car.” と入力すれば左奥に写っている車も物体検出できます。
4-3.Referring Expression Segmentation(参照表現セグメンテーション)の推論
次は参照表現セグメンテーションの推論方法についてです。基本的な流れは前節で同じで、「モデルのダウンロード」->「結果表示メソッドの定義」->「画像ダウンロード」->「推論」の順番で解説します。
# Torch Hubからモデル読込 model_pc = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB3_phrasecut', pretrained=True, return_postprocessor=False) model_pc = model_pc.cuda() model_pc.eval(); def plot_inference_segmentation(im, caption): """ 参照表現セグメンテーション用結果表示 Parameters ---------- im : PIL.Image 画像 caption : string 入力文字 """ img = transform(im).unsqueeze(0).cuda() # モデル伝搬 outputs = model_pc(img, [caption]) probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu() keep = (probas > 0.9).cpu() bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size) # マスク(tensor)を正しいサイズに補間(バイリニア補間) w, h = im.size masks = F.interpolate(outputs["pred_masks"], size=(h, w), mode="bilinear", align_corners=False) masks = masks.cpu()[0, keep].sigmoid() > 0.5 # Encoderからテキスト情報取得 tokenized = model_pc.detr.transformer.tokenizer.batch_encode_plus([caption], padding="longest", return_tensors="pt").to(img.device) # 各ボックスで予測されるテキストスパンを抽出 # 閾値以上の予測値のリストから、対応するトークンの場所を抽出 positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist() predicted_spans = defaultdict(str) for tok in positive_tokens: item, pos = tok if pos < 255: # Encoderで処理したテキスト情報から、テキストスパンを取得 span = tokenized.token_to_chars(0, pos) # 入力文字(caption)から対応する文字を取得 predicted_spans [item] += " " + caption[span.start:span.end] labels = [predicted_spans [k] for k in sorted(list(predicted_spans .keys()))] plot_results(im, probas[keep], bboxes_scaled, labels, masks) return outputs # 画像をCOCO datasetから取得 url = "https://s3.us-east-1.amazonaws.com/images.cocodataset.org/val2017/000000218091.jpg" im2 = Image.open(requests.get(url, stream=True).raw) # 結果表示 outputs = plot_inference_segmentation(im2, "bed")
Torch Hubからのモデル読込方法、結果表示メソッドの中身や引数は、前節と同じです。前節との主な違いとしては2点あり、1点目はセグメンテーションのマスク画像をリスケールする処理が追加されています。リスケールした後、バイリニア補間によって滑らかなマスク画像を表現しています。2点目はモデル伝搬処理をEncoderとDecoderに分けておらず、1回で実行している点です。参照表現セグメンテーションタスクではPhraseCutデータセット 23 を使っており、他タスクと違い非常に短い1文構成となっているので、独自のモデル構造になっているためです。それに伴い、Encoderからテキスト情報を取得する際も memory_cache
モデルから取得するのではなく、model_pc.detr.transformer.tokenizer
にアクセスしてテキスト情報を取得しています。ちなみに前節のモデルでも同様にテキスト情報を取得することができ、model.transformer.tokenizer
にアクセスすれば memory_cache
を使わなくてもテキスト情報を取得できます。
画像のダウンロードでは、前節と同じくCOCO dataset 4 からホテルの一室を撮影した画像を使います。
今回は上記画像からベッドのみをセグメンテーションしてみたいので、plot_inference_segmentation
の引数に “bed” を設定して実行します。下記画像のように、ベッド部分のみがセグメンテーションされた画像を取得することができます。
4-4. Visual Question Answering(視覚的質問応答)の推論
最後のタスクとして、VQAの推論方法について解説します。こちらも基本的な流れは 4-2. と同じで、モデルの伝搬処理もEncoderとDecoderの2つに分かれています。 4-2. との違いは、モデルの出力として与えられる回答IDを単語に変換するマッピングファイルをダウンロードする必要があることです。以下のソースコードを実行することで、マッピングファイルをダウンロードできます。
answer2id_by_type = json.load(requests.get("https://nyu.box.com/shared/static/j4rnpo8ixn6v0iznno2pim6ffj3jyaj8.json", stream=True).raw) id2answerbytype = {} for ans_type in answer2id_by_type.keys(): curr_reversed_dict = {v: k for k, v in answer2id_by_type[ans_type].items()} id2answerbytype[ans_type] = curr_reversed_dict
ついでに中身も確認してみましょう。
id2answerbytype # ⇒ {'answer_attr': {0: 'no', # 1: 'large', # 2: 'yes', # 3: 'young', # ... # 402: 'unknown'}, # 'answer_cat': {0: 'remote control', # 1: 'television', # 2: 'horse', # 3: 'shirt', # ... # 677: 'unknown'}, # 'answer_global': {0: 'sidewalk', # 1: 'field', # 2: 'no', # 3: 'yes', # ... # 110: 'unknown'}, # 'answer_obj': {0: 'no', 1: 'yes', 2: 'unknown'}, # 'answer_rel': {0: 'no', # 1: 'yes', # 2: 'towel', # 3: 'man', # ... # 999: 'loaf', # ...}}
上記のように、辞書型でいくつかの要素を持っていることが分かります。要素は質問文に対してのタイプを表していて、学習で使われたGQAデータセット 21 には「obj」「attr」「rel」「global」「cat」の5つがあり、それぞれのタイプに対するマッピングデータになります。それぞれのタイプについて例を交えながら紹介します。
①obj
物体の存在に対しての質問文。
(例) Are there fences that are made of wood? (木製のフェンスはありますか?)
Is there a soccer ball in this scene? (このシーンの中にサッカーボールはありますか?)
②attr
物体の属性(attribute)に対しての質問文。
(例) What color is the hat? (帽子は何色ですか?)
What color is the fence in the top of the image? (画像上部のフェンスは何色ですか?)
③rel
物体の関係性(relation)に対しての質問文。
(例) What is the vegetable inside the bowl? (ボウルの中の野菜は何ですか?)
Are the drawers to the left or to the right of the table? (引き出しはテーブルの左側ですか、右側ですか?)
④global
画像全体に対しての質問文。
(例) Is it outdoors or indoors? (屋外ですか?屋内ですか?)
Which place is it? (どの場所ですか?)
⑤cat
物体が何のカテゴリ(category)なのかに対しての質問文。
(例) What is the name of the item of furniture that is not white? (白くない家具の名前を教えてください。)
What piece of furniture is blue? (青い家具は何ですか?)
回答のマッピングファイルが準備できたので、次は「モデルのダウンロード」->「結果表示メソッドの定義」->「画像ダウンロード」->「推論」の処理を行います。
# Torch Hubからモデル読込 model_qa = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5_gqa', pretrained=True, return_postprocessor=False) model_qa = model_qa.cuda() model_qa.eval(); def plot_inference_qa(im, caption): img = transform(im).unsqueeze(0).cuda() # モデル伝搬 memory_cache = model_qa(img, [caption], encode_and_save=True) outputs = model_qa(img, [caption], encode_and_save=False, memory_cache=memory_cache) probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu() keep = (probas > 0.7).cpu() bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size) # 各ボックスで予測されるテキストスパンを抽出 positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist() predicted_spans = defaultdict(str) for tok in positive_tokens: item, pos = tok if pos < 255: # Encoderで処理したテキスト情報から、テキストスパンを取得 span = memory_cache["tokenized"].token_to_chars(0, pos) # 入力文字(caption)から対応する文字を取得 predicted_spans [item] += " " + caption[span.start:span.end] labels = [predicted_spans [k] for k in sorted(list(predicted_spans .keys()))] plot_results(im, probas[keep], bboxes_scaled, labels) # 質問タイプの分類 type_conf, type_pred = outputs["pred_answer_type"].softmax(-1).max(-1) ans_type = type_pred.item() types = ["obj", "attr", "rel", "global", "cat"] # 質問タイプから一番マッチする回答をIDからマッピングして取得 ans_conf, ans = outputs[f"pred_answer_{types[ans_type]}"][0].softmax(-1).max(-1) answer = id2answerbytype[f"answer_{types[ans_type]}"][ans.item()] print(f"Predicted answer: {answer}\t confidence={round(100 * type_conf.item() * ans_conf.item(), 2)}")
末尾の質問タイプ分類と回答取得処理以外は、4-2. の処理と全く同じです。VQAのoutputsには pred_logits
と pred_boxes
以外に、pred_answer_obj
、pred_answer_attr
、pred_answer_rel
、pred_answer_global
、pred_answer_cat
という質問タイプごとの要素と、pred_answer_type
という入力文がどの質問タイプかを判定する要素を持っています。質問タイプ分類処理では、pred_answer_type
から質問タイプを取得して対応する 「pred_answer_●●●
」 を取得します。そこから回答IDを取得し、マッピングデータと紐付けることでIDから回答文に変換しています。
それではCOCO dataset 4 から画像をダウンロードしてVQAの推論を行います。画像は、下記のような駅構内のテーブルでノートPCを弄っている男性が写っている画像を使います。
url = "https://s3.us-east-1.amazonaws.com/images.cocodataset.org/val2017/000000076547.jpg" im3 = Image.open(requests.get(url, stream=True).raw) plot_inference_qa(im3, "What is on the table?")
質問文としては、"What is on the table?(テーブルの上には何がありますか?)“ を入力としています。質問タイプとしては "rel” にあたるので、outputs["pred_answer_rel"]
から回答IDを導出し、マッピングデータと紐付けて “laptop(ノートPC)” という回答が出力されます。なお、画像に表示されるバウンディングボックスは物体の関係性を求めるのに使われるだけで、出力として見るべき場所は画像下部に表示される文字列です。
5.おわりに
今回の記事では、マルチモーダル深層学習について簡単なご紹介とMDETRの基礎知識、各タスクの推論方法についてソースコードを交えながら解説しました。"umbrella" しか学習していないのに “green umbrella” を物体検出できる等、今まで出来なかった物体検出ができるようになったことをご確認いただけたかと思います。マルチモーダル界隈はまだ発展途上な分野ですので、これからも様々な手法が提案されていき様々なタスクが解決できるようになっていくことでしょう。
次回は未定ですが、MDETRで説明できていない箇所(アーキテクチャの詳細解説やソフトトークン予測、対比アライメント等)やFine-Tuning方法についての解説または別のマルチモーダル深層学習についてご紹介する予定です。