オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第19回 文章ベクトル化モデルと ResNet50 で CLIP 風のモデルを作る
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2022年2月21日

今回は前回の文章ベクトル化モデルを使って CLIP 風のモデルを作ります。CLIP は画像とテキストを同じ多次元ベクトル空間にエンコードするモデルで、テキストによる画像検索や Zero shot での画像分類が可能です。簡素化された(非公式)実装が公開されているので、日本語で動かして見ましょう。

1. はじめに

今回は前回の文章ベクトル化モデルを使って CLIP 風のモデルを作ります。CLIP1 は OpenAI が発表した画像とテキストを同一多次元ベクトル空間にエンコードするモデルで、テキストによる画像検索や Zero shot での画像分類が可能です。

オリジナルの CLIP は (画像,テキスト) の 4 億ペアを使い、バッチサイズ 32,768 でスクラッチから学習したようなので、とても Colab では動かせません。また公式実装2も公開されていますが、こちらは事前学習済みのモデルを使って推論をする為のコードのようですね。

「日本語で動かしてみたいんだがー」っと探してみると Hugging Face Transformers をベースにした非公式な簡易実装3を見つけました。公式とは細かいところで差異があるでしょうが、今回はこちらで雰囲気を味わってみましょう。

2. CLIP

CLIP に関してはもう日本語の詳しい解説記事45が公開されているので、詳細はそちらを参照してもらうことにして、要点だけ簡単に紹介することにします。

CLIP の学習を端的に表しているのが以下の図になります。

clip

(画像, テキスト) のペアをバッチサイズ N で CLIP に投入し、テキストを TextEncoder で、画像を Image Encoder でそれぞれ埋め込み表現にします。 T1~TN がテキストの、 I1~IN が画像の埋め込み表現です。

これらを図のように突き合わせて、テキスト<=>画像間のコサイン類似度を計算します。対角線の青い網掛け部分が正しいペアになりますね。 この青い部分の類似度を最大化、その他の部分の類似度を最小化するように学習します。 「似ているものは似た表現に、異なるものは異なる表現にする」よう学習する Contrastive Representation Learning と呼ばれる手法ですね6

学習済みの CLIP を Zero-shot の画像分類に応用すると以下の図のようなイメージになります。

zero-shot

図中の (2) Create dataset classifier from label text のところで、"dog" やら “plane” やら分類する対象の単語を “a photo of a {object}” というテンプレートで分類対象の埋め込み表現(T1,T2…,TN)を作っておきます。

続いて図中の (3) Use for Zero-shot prediction で、分類対象の画像を埋め込み表現(I1)にして、(2) で作った分類対象の埋め込み表現(T1,T2…,TN)との類似度を計算します。そして最も類似した埋め込み表現(上図では “a photo of a dog”)を選ぶという訳です。

普通に多クラスの分類問題としてモデルを作ってしまうと、分類するクラスを変更する度に教師データを用意して再学習ですが、この方式なら自由自在ですね。

さて、ちょっと早いですが動かしてみましょうか。

3. CLIP 風のモデルを作る

いつものように、記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定にしています。ノートブックを開き、アクセラレータは GPU を選んで下さい。

コードは https://github.com/moein-shariatnia/OpenAI-CLIP をベースに少々修正を加えたものになります。

以下の手順は Colab で実行可能なように調整してます。筆者が実行した時点では、ギリギリのギリだったので、データを消す手順は飛ばさずに実行して下さい。

セットアップ

まずは必要なものをインストールしていきます。

!pip install sentence-transformers==2.0.0
!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev
!pip install mecab-python3 fugashi ipadic
!pip install timm

次は画像と日本語のキャプションデータです。

データセットの準備

画像は有名な物体検出/セグメンテーションデータセットである MS-COCO 7 を使います。

!wget http://images.cocodataset.org/zips/train2014.zip
!unzip -q train2014.zip
!ls ./train2014 | wc -l
# 82783
!rm train2014.zip

!wget http://images.cocodataset.org/zips/val2014.zip
!unzip -q val2014.zip
!ls ./val2014 | wc -l
# 40504
!rm val2014.zip

キャプションはこの連載でも何度か使用させていただいている STAIR Captions8 を使いました。 STAIR Captions MS-COCO の画像に対して、日本語のキャプションを付与したデータセットになります。

!git clone https://github.com/STAIR-Lab-CIT/STAIR-captions
!tar zxvf STAIR-captions/stair_captions_v1.2.tar.gz

import json
with open("stair_captions_v1.2_val.json", "r") as f:
  json_data_val = json.load(f)
with open("stair_captions_v1.2_train.json", "r") as f:
  json_data_train = json.load(f)

キャプションの学習データの件数は、

len(json_data_train["images"])
# 82783

検証データの件数は

len(json_data_val["images"])
# 40504

となっています。データの構造はこんな感じですね。各画像には一意の id が割り当てられています。

json_data_train["images"][:2]
# [{'coco_url': 'http://mscoco.org/images/57870',
#   'date_captured': '2013-11-14 16:28:13',
#   'file_name': 'COCO_train2014_000000057870.jpg',
#   'flickr_url': 'http://farm4.staticflickr.com/3153/2970773875_164f0c0b83_z.jpg',
#   'height': 480,
#   'id': 57870,
#   'license': 5,
#   'width': 640},
#  {'coco_url': 'http://mscoco.org/images/384029',
#   'date_captured': '2013-11-14 16:29:45',
#   'file_name': 'COCO_train2014_000000384029.jpg',
#   'flickr_url': 'http://farm3.staticflickr.com/2422/3577229611_3a3235458a_z.jpg',
#   'height': 429,
#   'id': 384029,
#   'license': 5,
#   'width': 640}]

次に STAIR Captions の学習データから画像ファイル名とキャプションのリストを生成します。

id2image_train = {example['id']: example['file_name'] for example in json_data_train["images"]}
image_filenames_train = []
captions_train = []

for anno in json_data_train["annotations"]:
  image_filenames_train.append(id2image_train[anno['image_id']])
  captions_train.append(anno['caption'])

検証データも画像ファイル名とキャプションのリストにします。 STAIR Captions は 1 画像につき 5 つのキャプションが付与されているのですが、 同じ画像に対するキャプションはフィルタし 1 画像 = 1 キャプションとして、そのうちの 10000 件を使用しました9

id2image_val = {example['id']: example['file_name'] for example in json_data_val["images"]}
image_filenames_val = []
captions_val = []
in_use = set()

for anno in json_data_val["annotations"]:
  if anno['image_id'] in in_use:
    continue
  in_use.add(anno['image_id'])
  image_filenames_val.append(id2image_val[anno['image_id']])
  captions_val.append(anno['caption'])
del in_use

image_filenames_val = image_filenames_val[:10000]
captions_val = captions_val[:10000]

件数を確認しておきましょう。学習データの画像ファイルの件数です。1 画像につきキャプションが 5 つずつあるので、このリストは重複を含みます。

len(image_filenames_train) 
# 413915

学習データのキャプションの件数です。画像と一致してますね。

len(captions_train)
# 413915

不要データはディスクから消してしまって、

!rm stair_captions_*.json

元にしたソースコードを動かす都合上、画像データは 1 フォルダにまとめます。

!mv ./val2014/*.jpg ./train2014
!ls ./train2014 | wc -l
# 123287

学習の実行

まずは、必要なライブラリをあれこれとインポートします。

import os
import glob
import cv2
import gc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A

import torch
from torch import nn
import torch.nn.functional as F
import timm
from sentence_transformers import SentenceTransformer 
from sentence_transformers.util import batch_to_device
import tensorflow as tf
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

文章のベクトル化モデルには前回で作成した 4 層に蒸留済みの Sentence Transformer を使うので GCS の認証を通して、ダウンロードします10

from google.colab import auth
auth.authenticate_user()

!gsutil cp gs://somewhere/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz .
!tar zxvf ./strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz
!rm strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz

text_model = SentenceTransformer("./strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking")
text_model.to(device)

試しに推論してみましょう。文章のベクトル化モデルへの入力はこんな感じですね。

features = text_model.tokenize(["吾輩は猫である", "本日は晴天なり"])
features = batch_to_device(features, device)
features
# {'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
#          [1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'),
#  'input_ids': tensor([[    2,  7184, 30046,     9,  6040,    12,    31,     3],
#          [    2,   108, 28486,     9,  4798, 28849,   297,     3]],
#         device='cuda:0'),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')}

これが 768 次元の埋め込み表現になります。

output = text_model.forward(features)
output['sentence_embedding'].shape
# torch.Size([2, 768])

ハイパーパラメータはこんな感じです。# FIXED が入っている行が元から変更したところですね。

class CFG:
    debug = False
    image_path = "./train2014"                                     # FIXED
    captions_path = "C:/Moein/AI/Datasets/Flicker-8k"
    batch_size = 32
    num_workers = 2
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 4                                                     # FIXED
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048
    text_embedding = text_model.get_sentence_embedding_dimension() # FIXED
    max_length = text_model.max_seq_length                         # FIXED

    pretrained = True 
    trainable = True 
    temperature = 1.0

    # image size
    size = 224

    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1

こちらは学習のメトリクスを計算するクラスとかですね。

class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

続いてデータセットです。オーグメンテーションの類は入れてないですが、そこはお好みで。

class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }
        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        return item

    def __len__(self):
        return len(self.captions)

def get_transforms(mode="train"):
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )

画像のエンコーダです。ハイパーパラメータで指定された名称のモデルをロードして、アベレージプーリングした後の 固定長ベクトルを返すだけですね。前述の設定ですと、"ResNet50" を使って 2048 次元の埋め込み表現を返すことになります。

class ImageEncoder(nn.Module):
    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

こんどはテキストエンコーダです。先ほどロードした文章ベクトル化モデルをここで組み込みます。

class TextEncoder(nn.Module):
    def __init__(self, text_model, trainable=CFG.trainable):
        super().__init__()
        self.model = text_model 

        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, input_ids, attention_mask):
        output =  text_model.forward({'input_ids': input_ids, 'attention_mask': attention_mask})
        return output['sentence_embedding']

次に画像エンコーダとテキストエンコーダから得られる埋め込み表現を同じ次元数にするためのプロジェクションヘッドです。

class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

最後に今までの部品を組み立てた CLIP の全体像です。

オリジナルの CLIP は temperature も学習パラメータだった気がしますが、このサンプルでは固定値になってますね。

あとは targets でしょうか。(1) Contrastive pre-training の図では正解ラベルは単純な対角線の突き合わせで表現されていますが、 このサンプルではバッチの中に同じ画像が混入することを想定して、バッチ中の画像同士、テキスト同士の類似度の平均を使ったソフトラベルを作り、 それを画像<=>テキストの類似度と突き合わせる形になっています。

class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder(text_model)  # FIXED HERE
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

さて、ここから PyTorch の学習を回すための部品を定義していきます。

まずは DataLoader の構築です。テキストのトークナイザは文章ベクトル化モデルから拾い出します (実体は BertJapaneseTokenizer になっているはず)。

def build_loaders(image_filenames, captions, tokenizer, mode):     # FIXED
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        image_filenames,                                           # FIXED
        captions,                                                  # FIXED
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader

tokenizer = text_model.tokenizer                                   # FIXED
train_loader = build_loaders(image_filenames_train, captions_train, tokenizer, mode="train")
valid_loader = build_loaders(image_filenames_val, captions_val, tokenizer, mode="valid")    

CLIP モデルのインスタンス、オプティマイザ、学習レートのスケジューラです。

model = CLIPModel().to(CFG.device)
params = [
    {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
    {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
    {"params": itertools.chain(
        model.image_projection.parameters(), model.text_projection.parameters()
    ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
]

optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
)

ここから元のコードを少々変更しています。 学習途中のチェックポイントを GCS にコピーするロジックを入れて、中断地点から再開できるようにしました。

ckpt_dir = './ckpt'
best_loss = float('inf')
current_step = 0
resumed_step = -1
batches_per_epoch = len(train_loader)
evaluation_steps=int(batches_per_epoch * 0.2)

def save_ckpt(ckpt_dir, model, optimizer, scheduler, current_step, valid_loss):
    ckpt = {
        'current_step': current_step,
        'valid_loss': valid_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    path = os.path.join(ckpt_dir, "ckpt-{}-{:.3f}.pt".format(current_step, valid_loss))
    print("Saving checkpoint at step {}".format(current_step))
    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir)
    torch.save(ckpt, path)
    tf.io.gfile.copy(path, os.path.join("gs://somewhere/clip","ckpt-{}-{:.3f}.pt".format(current_step, valid_loss)))

def load_ckpt(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("Resuming checkpoint from step {}".format(checkpoint['current_step']))
    return model, optimizer, scheduler, checkpoint['current_step']    

ローカルディスクにチェックポイントの出力フォルダがあれば、ロードして最新の状態に復帰します。

中断した時点から再開する場合は、ここまでの手順を再実行した上で、 ckpt_dir に設定したディレクトリを手動で作成し、 GCS にコピーされたチェックポイント("gs://somewhere/clip/ckpt-*")をコピーしておいて下さい11

if os.path.exists(ckpt_dir):
  latest_ckpt_path = sorted([(ckpt.split("-")[1], ckpt) for ckpt in glob.glob("./ckpt/ckpt-*")], 
                            key=lambda x:int(x[0]), reverse=True)[0][1]
  model, optimizer, lr_scheduler, resumed_step = load_ckpt(latest_ckpt_path, model, optimizer, lr_scheduler)

ようやくメインの学習ループの実行です。

start_epoch = resumed_step // batches_per_epoch if resumed_step >= 0 else 0
current_step = start_epoch * batches_per_epoch

for epoch in range(start_epoch, CFG.epochs):
    print(f"Epoch: {epoch + 1}")

    model.train()
    train_loss = AvgMeter()
    tqdm_object = tqdm(train_loader, total=batches_per_epoch)

    for batch in tqdm_object:
        if current_step <= resumed_step: 
            if current_step == resumed_step:
                print("Resuming to step {} completed.".format(resumed_step))
            current_step += 1
            continue
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        count = batch["image"].size(0)
        train_loss.update(loss.item(), count)

        if current_step > 0 and current_step % evaluation_steps == 0:
            model.eval()
            with torch.no_grad():
                valid_loss = AvgMeter()
                tqdm_object = tqdm(valid_loader, total=len(valid_loader))
                for batch in tqdm_object:
                    batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
                    loss = model(batch)
                    count = batch["image"].size(0)
                    valid_loss.update(loss.item(), count)
                    tqdm_object.set_postfix(valid_loss=valid_loss.avg)

            save_ckpt(ckpt_dir, model, optimizer, lr_scheduler, current_step, valid_loss.avg)
            if valid_loss.avg < best_loss:
                best_loss = valid_loss.avg
            lr_scheduler.step(valid_loss.avg)
            model.train()

        current_step += 1
        tqdm_object.set_postfix(train_loss=train_loss.avg, lr=get_lr(optimizer))

学習ループを抜けたら、最後にもう一度検証をして保存しておきます。

model.eval()
with torch.no_grad():
    valid_loss = AvgMeter()
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        count = batch["image"].size(0)
        valid_loss.update(loss.item(), count)
        tqdm_object.set_postfix(valid_loss=valid_loss.avg)

save_ckpt(ckpt_dir, model, optimizer, lr_scheduler, current_step, valid_loss.avg)

学習曲線はこんな感じになりました。とりあえず 4 エポック回してみましたが、ちょうど良いくらいでしたね。

learning_curve

4. 推論の実行

それでは学習済みのモデル使って(ベタですが)日本語で画像検索をしてみましょう。

Colab のランタイムが初期化されている場合には前章の手順をこのセルまで再実行しておいて下さい。

学習曲線で最良だったステップ 43979 のチェックポイントを使います。

!gsutil cp gs://somewhere/clip/ckpt-43979-0.434.pt .
ckpt_path = "./ckpt-43979-0.434.pt"

トークナイザを文章ベクトル化モデルから取り出しておきます。

tokenizer = SentenceTransformer("./strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking").tokenizer

推論には検証セットのデータを使います。

valid_loader = build_loaders(image_filenames_val, captions_val, tokenizer, mode="valid")

CLIP のインスタンスを作って、学習済みチェックポイントを読み込み、評価モードにしておきましょう。

model = CLIPModel().to(CFG.device)
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

あらかじめ、ロードした CLIP モデルを使って検証セットの画像を埋め込み表現にしておきます。

valid_image_embeddings = []
with torch.no_grad():
    for batch in tqdm(valid_loader):
        image_features = model.image_encoder(batch["image"].to(CFG.device))
        image_embeddings = model.image_projection(image_features)
        valid_image_embeddings.append(image_embeddings)
image_embeddings = torch.cat(valid_image_embeddings)

最後にクエリー文字列に似た画像を検索して表示する関数です。

表示する画像数を n で指定しますが、類似度 top-n を表示する訳ではなくて、top_(n * 5) を抽出して、 5 件おきに表示してますね。このほうが表示される画像が適度にバラついていい感じに見えるということでしょう。

def find_matches(model, image_embeddings, query, image_filenames, tokenizer, n=9):
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{CFG.image_path}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    plt.show()

検証セットの画像をランダムに抽出するとこんな感じです。

_, axes = plt.subplots(3, 5, figsize=(20, 10))
for match, ax in zip(np.random.choice(image_filenames_val, 15), axes.flatten()):
    image = cv2.imread(f"{CFG.image_path}/{match}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    ax.imshow(image)
    ax.axis("off")

plt.show()

coco_photos

それではやってみましょう。「駅のホームに列車が入ってきました」で検索してみます。

find_matches(model, 
             image_embeddings,
             query="駅のホームに列車が入ってきました",
             image_filenames=image_filenames_val,
             tokenizer=tokenizer,
             n=9)

trains

こんどは「野球をしている様子」です。

find_matches(model, 
             image_embeddings,
             query="野球をしている様子",
             image_filenames=image_filenames_val,
             tokenizer=tokenizer,
             n=9)

baseballs

今回、かなりモデルを小さくしたので、ちょっと心配でしたが、それっぽい結果がでましたね。

最後に Zero-shot 画像分類もやってみましょう。

日本語をプロットするのでフォントを追加します。

!apt install fonts-takao-pgothic

以下の動物画像を分類してみましょう。

sampled_images = ["COCO_val2014_000000318908.jpg", "COCO_val2014_000000347236.jpg", "COCO_val2014_000000318814.jpg"]
_, axes = plt.subplots(1, 3, figsize=(20, 10))
for match, ax in zip(sampled_images, axes.flatten()):
    image = cv2.imread(f"{CFG.image_path}/{match}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    ax.imshow(image)
    ax.axis("off")

plt.show()

animals

画像検索をする時に作った埋め込み表現を拾い出します。

sampled_image_indices = [image_filenames_val.index(image) for image in sampled_images]
sampled_image_embeddings = [image_embeddings[index] for index in sampled_image_indices]
sampled_image_embeddings = torch.vstack(sampled_image_embeddings)
sampled_image_embeddings.shape
# torch.Size([3, 256])

分類対象となるクラスはテキストで定義するんでしたね。

text_classes = ["犬の写真", "猫の写真", "馬の写真"]

これを埋め込み表現にして、

tokenized_classes = tokenizer(text_classes)
batch = { key: torch.tensor(values).to(CFG.device) for key, values in tokenized_classes.items()}
with torch.no_grad():
  text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
  text_embeddings = model.text_projection(text_features)
text_embeddings.shape

画像とテキストの埋め込み表現同士を突き合わせて類似度を計算し、最も高い類似度だったものを分類結果とします。

sampled_image_embeddings_n = F.normalize(sampled_image_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = text_embeddings_n @ sampled_image_embeddings_n.T
predictions = torch.argmax(dot_similarity, axis=1)
predictions = [int(p) for p in predictions]

分類結果と画像をセットで表示してみると。。。

from matplotlib.font_manager import FontProperties
fpath = '/usr/share/fonts/truetype/fonts-japanese-gothic.ttf'
fp = FontProperties(fname=fpath, size=96)
_, axes = plt.subplots(1, 3, figsize=(20, 10))
for prediction, match, ax in zip(predictions, sampled_images, axes.flatten()):
    image = cv2.imread(f"{CFG.image_path}/{match}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width, _ = image.shape
    ax.imshow(image)
    ax.text(width/2-64, height/2+64, text_classes[prediction].replace("の写真", ""), color="white", fontproperties=fp)
    ax.axis("off")
plt.show()

prediction

ちゃんと分類できました。

計算した類似度の中身をみると結構危うい感じですが、まぁ全部動物の画像ですから、こんなものかもしれませんね。

dot_similarity
# tensor([[0.6361, 0.5378, 0.4134],
#         [0.5227, 0.6051, 0.4356],
#         [0.4659, 0.4405, 0.5953]], device='cuda:0')

5. おわりに

今回は、前回の文章ベクトル化モデルを再利用して画像と絡めてみました。モデル規模とバッチサイズを大きくすれば精度は向上すると思います。 さて、次回はどうしましょうか。 ExT512 や RETRO13 も興味あるんですが、さすがに Colab では動かせませんね。。。 なんてことを考えていたら、 RETRO の参考文献で Fusion-in-Decoder14 というのを見つけました。ソースを見ると、 transformers の T5 がベースになっているので、少し頑張ったら動かせそうです。これだったら、(完全ではなくとも) Colab で動かせなくはないかと。 以前から Wikipedia のような外部文書を参照して推論するテキスト生成モデルを動かしてみたいと思っていたので試してみようかと思っています。


  1. https://arxiv.org/abs/2103.00020 

  2. https://github.com/openai/CLIP 

  3. https://github.com/moein-shariatnia/OpenAI-CLIP 

  4. https://deepsquare.jp/2021/01/clip-openai/ 

  5. https://data-analytics.fun/2021/03/24/understanding-openai-clip/ 

  6. そういえば、前回の Multiple Negatives Ranking Loss も (hard negative が入ってますが) Contrastive Representation Learning ですね。 

  7. https://cocodataset.org/ 

  8. http://captions.stair.center/ 

  9. すいません。理由をちゃんと覚えていません。たしか「Colab の GPU ランタイムはすぐに時間切れになっちゃうので、検証データは小さくってもイイでしょー。」くらいのノリだったかと思うんですが。。。 

  10. 消しちゃった方(普通はそうでしょうが。。。)は前回を参照してもう一度作ってください。 

  11. すいません。この辺は手抜きです。 

  12. https://arxiv.org/abs/2111.10952 

  13. https://arxiv.org/abs/2112.04426 

  14. https://arxiv.org/abs/2007.01282