オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第20回 Fusion-In-Decoder でクイズに答えるモデルを作る
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2022年4月21日

今回は Fusion-In-Decoder を使ってクイズに答えるモデルを作ります。以前から Wikipedia 等の外部情報を参照できるテキスト生成モデルを試してみたいと思っていました。Fusion-In-Decoder の発表は 2020 年なので少し前のモデルですが、T5 ベースで手軽に試せるサイズ感ですので、日本語で試してみましょう。

1. はじめに

今回紹介する Fusion-In-Decoder(以下、FiD )1 は Meta AI (当時は Facebook AI Research) が発表した Open Domain question Answering タスクを解くテキスト生成モデルです。

じつは、以前から外部情報を参照できるテキスト生成モデルを試してみたくて2、 Google の RETRO3 の論文を読んでたんです。 なのですが、外部情報のサイズ感が 1000 B トークンくらい欲しい感じなんですよね(ちなみに日本語 Wikipedia 全体でも 1 B トークンを超えるくらいだと思います)。「うーむ、どうしたものか。。。」と思っていたところ、参考文件で FiD を見つけました4

Meta AI 発のモデルは「日本語で動かそうと思うと事前学習モデルがなぁ。。。」となるのですが、 FiD は、なんと Google の T5 がベースになっていて、コード5を見ると Hugging Face の transformers で実装してあります。 transformers で T5 とくれば Megagon Labs さんが日本語の事前学習済みモデル6を公開して下さっていますので、「これならチョチョイのチョイで動くんじゃ?」とすごく軽い気持ちで今回のネタが決定したのでありました。。。

タスクとしては Open Domain Question Answering(以下、OpenQA) になります。第13回の ORQA でも試したので二回目ですね。 軽く復習しておくと、モデルに対する入力としては“質問"のみ が与えられ、(多くの場合は Wikipedia 等の外部知識を参照して)質問に解答するタスクになります。 ORQA の時は正答率 30.5 % でしたが、今回はどの程度までスコアを伸ばせるでしょうか?

まずは、 FiD がどんなモデルなのか見てみましょう。

2. Fusion-In-Decoder

FiD も ORQA と同じく Retriever + Reader の構成になっています。OpenQA タスクのモデルではよく見る構成ですね。 以下はそのイメージです。

retriever-reader

質問を投入された Retriever は質問をキーとして外部知識(主にWikipediaの記事等)を検索し、回答が含まれる可能性が高い文書群を抽出します。 次に質問と抽出された文書群を Reader に投入し、Reader はそこから回答を導きだします。

FiD の論文で対象となっているのは Reader の部分であり、Retriever の学習については別の論文7で学習済みの Reader を用いる手法が提案されているので、あとで紹介しますね。

さて、 Reader の学習は以下のように質問と質問に関連する文書群がそろったところからのスタートになります。

retriever-reader

ところが、FiD は非常にシンプルで Reader の学習についてはあまり説明することがありません。以下の図が処理の概要になります。

fusion-in-decoder ここで encoder, decoder はそれぞれ、 T5 の encoder と decoder です。質問と回答が含まれる可能性が高い文書(図中では Passage 1~N)を連結し、 個別に encoder で埋め込み表現のシーケンスにします。あとは個々の埋め込み表現を連結して単一の長いシーケンスにしたら、 decoder に投入し、自己再帰でテキスト生成するだけです。

ORQA の時は ある入力シーケンス("Question + Passage *”)から推論した回答の中で最もそれらしいものを選んでいました。 つまり最終的な回答の元ネタとなった passage は 1 ~ N のうちのどれか一つです。

それに対して、 FiD は回答生成時に全ての入力シーケンスの埋め込み表現を見ることができるので、「Passage A のコレと Passage B のソレから回答を生成する」あるいは「Passage X にこう書いてあるから、Passage Y のここを回答にしよう」という動きができるようになっています。

少し細かい話

上図では各 passage (青、黄、緑)が不等長で描かれていますが実装上は少し違います。 バッチサイズを B 、passage の数を N 、各入力シーケンス("Question + Passage *“) の最大長を L とすると、 encoder には各入力シーケンスを padding あるいは truncate して長さ L でそろえ、[B*N, L] のテンソルとして投入します。 そして、その出力を [B, N*L] に reshape して decoder に投入します。FiD のソースコード4で言うと以下の部分ですね。

# https://github.com/facebookresearch/FiD/blob/25ed1ff0fe0288b80fb5e9e5de8d6346b94b8d48/src/model.py#L139-L147
139|    def forward(self, input_ids=None, attention_mask=None, **kwargs,):
140|        # total_length = n_passages * passage_length
141|        bsz, total_length = input_ids.shape
142|        passage_length = total_length // self.n_passages
143|        input_ids = input_ids.view(bsz*self.n_passages, passage_length)
144|        attention_mask = attention_mask.view(bsz*self.n_passages, passage_length)
145|        outputs = self.encoder(input_ids, attention_mask, **kwargs)
146|        outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:]
147|        return outputs

では、話をもどして Retriever の方を見ていきましょう。

Retriever の学習

ここから別の論文7の話になりますが、ややこしいので本記事では参考文献11, 77をまとめて FiD として説明することにします。

Retriever + Reader の構成をとる最近のモデルは End-to-End で Retriever と Reader を同時に学習するようなものが多い8のですが、 FiD はこの二つを個別に学習します。まず、Reader は前述の方法で一旦学習したとしましょう。

次に Retriever を学習させるわけですが FiD では学習済みの Reader から蒸留するという手法をとります。

具体的には次のような処理になります。

学習済みの Reader は回答を生成するとき、N 個の入力シーケンス("Question + Passage *”) のどこかに注目します。 このとき、より多く注目された入力シーケンスは回答を生成するのに有益なものだと言えそうなので、 この注目具合を Cross-Attention の重みから抽出して質問に対する passage のスコアとします。

質問のシーケンスを q、質問に関連する passage を pn (1 ≤ n ≤ N )とした q に対する pn のスコア Gq, pn

  • K : decoder のレイヤ数
  • H : decoder のヘッド数
  • N : decoder に入力する入力シーケンス("Question + Passage *“) の数
  • L : パディング済み入力シーケンス長
  • 連結済みシーケンス : N 個のパディング済み入力シーケンスを連結したもの
  • j : 連結済みシーケンスにおけるインデックス
  • αi,j,k,h: decoder の k 番目のレイヤ、 h 番目のヘッドにおける出力シーケンス i 番目トークンから見た連結済みシーケンス j 番目トークンに対するアテンションの重み(ここで αi,:,k,h は長さ N * L のシーケンス)
  • Μ: 連結済みシーケンスに対するマスク(有効トークンなら 1、パディングなら 0 を値とする長さ N * L のシーケンス)

とすると以下のようになります。

Gqp

この数式、論文に書いてあるものではなくて筆者がソースコード見ながら書き起こしましたが9、分かりにくいですね。 平たく言うと decoder で最初のトークンを推論するときの Cross Attention の重みを全部のレイヤ、全部のヘッドで pn に対応する部分(α0,1+(n-1)*L:n*L,:,:)についてパディングを除外しつつ足し合わせ、最後にレイヤ数×ヘッド数× pn の有効トークン数で割ってトークン辺りの平均を出すってだけです。

あとは Retriever を学習するときに質問と passage の埋め込み表現の内積が上記のスコアを反映するように学習する訳です。

  • E : 文章をベクトルにエンコードする関数
  • Q : 質問の集合
  • Dq : 質問 q に関連する passage の集合
  • Sθ(q, p) : エンコードした質問 q と passage p のスコア。E(p)TE(q)

とすると、まずは Gq, p と Sθ(q, p) を以下のように正規化します。

normalized_G normalized_S

そして、この二つの KL ダイバージェンスを損失とします。

lossKL

KL ダイバージェンスはこの連載で出てきたかどうか覚えていないのですが、二つの分布の距離のようなものです。 上記の損失、つまり G~q,pとS~θ(q, p) の距離を最小化することで、関数 E を使って q, p の内積を計算すれば Gq,p に応じたスコアが得られるようになります。

関数 E は文章を埋め込み表現にできれば何でもよいのでしょうが、 FiD では BERT の ”[CLS]“ トークンに対する埋め込みを 使っています。

このようにして

  • Retriever を学習できたら、
  • 学習済み Retriever を使ってデータセットの質問に対する passage を Wikipedia 等から検索して、Reader の学習データを再生成します。
  • この再生成したデータで Reader を再学習し、
  • 再学習済み Reader の Cross-Attention の重みを使って Retriever を…

と Reader と Retriever を交互に学習します。論文によると 4 周目ぐらいまでスコアが向上したようです。

では、実際に日本語のデータを使って動かしてみましょう。

注記

以後の手順にどうにも Colab では動かせない箇所ができてしまいました。。。予めご了承ください。

3. Wiki40b を passage に分割する

まず FiD を動かすには検索対象となる外部知識が必要です。今回は Wiki40b を使うことにしました。

Wiki40b は一言でいうと Wikipedia のコンテンツに対して、曖昧さ回避ページの除外やページ内のマークアップや参考文献等を除去する処理が適用されたものになります。クイズの回答を探す対象文書としては、生の Wikipedia よりもこちらのほうが良さそうですね。Wiki40b については Hironsan さんのページ10で詳しく解説されているのでそちらを参考にして下さい。

今回も記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定にしています。 この章の内容は GPU は不要だと思います。

GCS を使うので認証を通しておきます。

from google.colab import auth
auth.authenticate_user()

FiD の論文によると passage は 100 単語程度の長さで重なりがないように分割したようです。 今回は分割の切れ目を文の境界として、単語と文の分割には GiNZA を使いました。

!pip install ginza ja_ginza==5.1.0

passage に含まれる単語数を 100 として

MAX_WORDS_PER_PASSAGE = 100

Wiki40b をロードします。

import tensorflow_datasets as tfds
import pickle
import spacy
ds = tfds.load('wiki40b/ja')

Wiki40b は _START_ARTICLE__START_PARAGRAPH__NEWLINE_ のような文章の構造を示すシンボルが入っているので、 それらを処理する関数です。

def parse_example(example):
  article = example["text"].numpy().decode("utf-8")
  previous = ""
  title = ""
  lines = []
  for element in article.split("\n"):
    if previous == "_START_ARTICLE_":
      title = element
    if previous == "_START_PARAGRAPH_":  
      lines.extend(element.split("_NEWLINE_"))
    previous=element
  return title, lines

複数行の文章を単語数 100 に収まるように文の境界で passage に分割する関数です。

def lines2passages(lines, nlp):
  docs = nlp.pipe(lines, disable=['ner', 'bunsetu_recognizer'])
  sents = []
  for doc in docs:
    for sent in doc.sents:
      sents.append(sent)

  passages =[]
  num_words_of_passages = []
  passage = ""
  num_words_of_passage = 0

  for sent in sents:
    num_words = len(sent)
    if num_words_of_passage + num_words > MAX_WORDS_PER_PASSAGE:
      passages.append(passage)
      num_words_of_passages.append(num_words_of_passage)
      passage = ""
      num_words_of_passage = 0
    passage += str(sent)
    num_words_of_passage += num_words
  passages.append(passage)
  num_words_of_passages.append(num_words_of_passage) 
  return passages, num_words_of_passages

Wiki40b を読み込んで passage のリストを返す関数です。 全体を一度に処理するのは時間的に厳しいと思うので、head, tail で処理範囲を指定します。 また、ループを回していると nlp がどんどんメモリを食いつぶすので、 10,000 記事毎にリロードしてます。

def build_contexts(head=0, tail=None):
  nlp = spacy.load('ja_ginza')
  contexts =[]
  i = 0
  for split in ds.keys():
    if tail is not None and i >= tail:
      break
    for example in ds[split]:
      if i < head:
        i += 1
        continue
      if head > 0 and i == head:
        print("Resume to head[{}] is completed.".format(head), flush=True)
      if i % 10000 == 0 or (i < 2000 and i % 10 == 0):
        print("processing example[{}]".format(i), flush=True)
      if i > 0 and i % 10000 == 0:
        print("Reaload nlp...")
        nlp = spacy.load('ja_ginza')
      title, lines = parse_example(example)
      passages, num_words_of_passages = lines2passages(lines, nlp)
      for passage in passages:
        contexts.append((title, passage))
      i += 1
      if tail is not None and i >= tail:
        print("The number of processed examples is reached to the specified tail[{}].".format(tail), flush=True)
        break
  return contexts, i

Wiki40b 全体の記事数としては 828,236 件ありました。 64,000 件ぐらいで区切って実行するのが無難でしょうか。 あまり欲張ると pickle のサイズが 4GB を越えて悲しい思いをしたりします。

とりあえず先頭の 64,000 件を処理してみましょう。

head=0
tail=64000
contexts, tail = build_contexts(head=head, tail=tail)
fname = "contexts-{}-{}.pkl".format(head, tail)
print("Dumping contexts as pickle : {}...".format(fname), flush=True)
with open(fname, "wb") as f:
  pickle.dump(contexts, f)

この段階でlen(contexts) しても 64,000 にならないので注意してください(64,000件の記事がそれぞれ単語数100以下の passage に分割されているので)。

処理が終わった分は GCS に退避しておきます。

!gsutil cp contexts-{head}-{tail}.pkl gs://somewhere/FiD

上記の要領で、

head=64001
tail=128000

というふうに、あと 12 回繰り返します。。。うーん、GCE でインスタンス立てた方がよさそうですね。

4. JAQKET データセットの取得と加工、ロード関数

今回はクイズに答えるモデルを作るので JAQKET データセットを使います。JAQKET データセット については第13回で説明しているので、 そちらを参照してください。

ここからは、新たに Colab のノートブックを開いて、アクセラレータに GPU を選んでください。

JAQKET データセットを取得します。

!wget https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/train_questions.json
!wget https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev1_questions.json
!wget https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev2_questions.json

JAQKET データセットを FiD のソースコード向けに変換する関数です。

import json
def convert_jaqket(filename, split, dataset_name="jaqket"):
  with open(filename, "r") as f:
    lines = f.readlines()
    lines = [line.strip() for line in lines]
    examples = [json.loads(line) for line in lines]
    examples = [{"id": example["qid"], 
                 "question": example["question"], 
                 "answers": [example["answer_entity"]] } for example in examples]
  with open("%s.resplit.%s.jsonl" % (dataset_name, split), "w") as f:
    lines = []
    for example in examples:
      line = json.dumps(example, ensure_ascii=False)
      lines.append(line)
    f.write("\n".join(lines))
  #return examples

各スプリットを変換しましょう。

convert_jaqket("train_questions.json", "train")
convert_jaqket("dev1_questions.json", "dev")
convert_jaqket("dev2_questions.json", "test")

この時点で変換後のファイルはこんな感じです。

!head -5 jaqket.resplit.train.jsonl
{"id": "ABC01-01-0003", "question": "格闘家ボブ・サップの出身国はどこでしょう?", "answers": ["アメリカ合衆国"]}
{"id": "ABC01-01-0004", "question": "ロシア語で「城」という意味がある、ロシアの大統領府の別名は何でしょう?", "answers": ["クレムリン"]}
{"id": "ABC01-01-0005", "question": "織田信長、豊臣秀吉、徳川家康という3人の戦国武将の性格を表現するのに用いられる鳥は何でしょう?", "answers": ["ホトトギス"]}
{"id": "ABC01-01-0006", "question": "人気タレント・タモリの本名は何でしょう?", "answers": ["タモリ"]}
{"id": "ABC01-01-0008", "question": "「国際連合」の旗に描かれている植物といったら何でしょう?", "answers": ["オリーブ"]}

タモリさんの本名を直しておきます。

!sed -i '/^{"qid": "ABC01-01-0006"/s/"タモリ"/"森田一義"/' jaqket.resplit.train.jsonl

変換したファイルをロードする関数です。

def load_jsonl(fname):
  with open(fname, "r") as f:
    lines = f.readlines()
    lines = [line.strip() for line in lines]
    examples = [json.loads(line) for line in lines]
    return examples

再び GCS の認証を通してもらって、

from google.colab import auth
auth.authenticate_user()

加工したファイルを保存しておきます。

!gsutil cp jaqket.resplit.*.jsonl gs://somewhere/FiD/

このまま続いて question と passage を埋め込み表現に変換しましょう。

開いているノートブックはそのままで次章に進んでください。

5. JAQKET データセットの question と passage の埋め込み

分割した passage と JAQKET データセットの question を埋め込み表現に変換していきます。

文章の埋め込みには sentence-transformers を使うのでインストールします。

!pip install sentence-transformers==2.0.0
!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev
!pip install mecab-python3 fugashi ipadic

モデルは前々回で作った JSNLI で学習したものを使います11

!gsutil cp gs://somewhere/strf_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz .
!tar zxvf ./strf_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz
!rm strf_cl-tohoku_bert-base-japanese-whole-word-masking.tar.gz

トークナイザをロードします。

from transformers import BertJapaneseTokenizer
bert_japanese_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

文字列のバッチをコードする関数です。

def encode_batch(batch):
  bert_japanese_features = bert_japanese_tokenizer([s.lower() for s in batch], padding=True,
        truncation='longest_first', return_tensors="pt", max_length=256)
  features = { k:torch.tensor(v).to(device) for k, v in bert_japanese_features.items()}
  output = model.forward(features)
  return output['sentence_embedding'].cpu().detach()

続いて passage のリストをエンコードする関数です。

def encode_contexts(model, contexts, batch_size):

    embeddings = []
    head = 0
    tail = batch_size
    num_batch =len(contexts) // batch_size

    def build_batch(contexts, head, tail):
      sentences = contexts[head:tail]
      batch = []
      for title, text in sentences:
        batch.append(title + "。" + text)
      return batch

    for i in range(num_batch):
      head = i * batch_size
      tail = head + batch_size
      if i < 20 or i % 10000 == 0 :
        print("head:{}, tail={}".format(head, tail), flush=True)
      batch = build_batch(contexts, head, tail)
      embeddings.append(encode_batch(batch))
    if tail < len(contexts):
      batch = build_batch(contexts, tail, None)
      embeddings.append(encode_batch(batch))

    return torch.cat(embeddings, axis=0)

passage の埋め込みも分割して地道に実行することにしましょう。

head=0
tail=64000

GCS から処理済みの passage を取得します。

!gsutil cp gs://somewhere/FiD/contexts-{head}-{tail}.pkl .

passage をロードして埋め込み表現に変換します。

def convert_contexts_to_embeddings(head, tail):
  with open("contexts-{}-{}.pkl".format(head, tail), "rb") as f:
    contexts = pickle.load(f)
  embeddings = encode_contexts(model, contexts, 8)
  embeddings = embeddings.to('cpu').detach().numpy()
  np.save("embeddings-{}-{}.npy".format(head, tail), embeddings)

convert_contexts_to_embeddings(head, tail)

変換結果を GCS に保存します。

!gsutil cp embeddings-{head}-{tail}.npy gs://somewhere/FiD/

この処理も

head=64001
tail=128000

として、あと 12 回ですね。。。

こんどは question をエンコードする関数です12

def encode_questions(model, texts, batch_size):
  embeddings = []
  head = 0
  tail = batch_size
  num_batch =len(texts) // batch_size

  for i in range(num_batch):
    head = i * batch_size
    tail = head + batch_size
    if i < 20 or i % 10000 == 0 :
      print("head:{}, tail={}".format(head, tail), flush=True)
    batch = texts[head:tail]
    embeddings.append(encode_batch(batch))
  if tail < len(texts):
    batch = texts[tail:]
    embeddings.append(encode_batch(batch))
  return torch.cat(embeddings, axis=0)     

JAQKET データセットの question を埋め込み表現に変換します。

splits = ["train", "dev", "test"]

for split in splits:
  print("Start encoding {} split...".format(split), flush=True)
  fname = "jaqket.resplit.{}.jsonl".format(split)
  examples = load_jsonl(fname)
  queries = [example["question"] for example in examples]
  embeddings = encode_questions(model, queries, 8)
  embeddings = embeddings.to('cpu').detach().numpy()

  print("Dumping context embeddings as numpy...", flush=True)
  np.save("question-embeddings-{}.npy".format(split), embeddings)

GCS に保存しておきましょう。

!gsutil cp question-embeddings-*.npy gs://somewhere/FiD/

ここまでで、 question と passage を埋め込み表現に変換できたので、次は近傍検索を使って question の回答を含んでいそうな passage を抽出していきましょう。

6. SCaNN による近傍検索

すみません。Colab で動かした風に書いていますが、本章だけ Colab ではどうにも動きません。メモリが 64 GB 程度ある環境で実行して下さい。

今度は前章で作った question の埋め込みを条件にして passage の埋め込みを検索することで、問題の回答を含んでいそうな passage を抽出します。

「回答を含んでいそうな」とはいうものの、やっていることは類似文書検索ですね。「 question と似たような内容の passage には回答が高い確率で含まれているだろう。」という仮定を置いているわけです。

埋め込み表現の検索には scann を使いました。

!pip install scann==1.2.4

GCS の認証を通して、

from google.colab import auth
auth.authenticate_user()

passage と question の埋め込み表現を取得します。

!gsutil cp gs://somewhere/FiD/embeddings-*.npy .
!gsutil cp gs://somewhere/FiD/question-embeddings-*.npy . 

必要なライブラリをインポートして

import numpy as np
import scann

64,000 件で分割するとファイル名はこんな感じになってるはずです。

fnames = []
head_and_tail = list(range(0, 828236, 64000)) + [828236]
for i in range(len(head_and_tail[:-1])):
  head = head_and_tail[i]
  tail =  head_and_tail[i+1]
  fnames.append("embeddings-{}-{}.npy".format(head, tail))
fnames
# ['embeddings-0-64000.npy',
#  'embeddings-64000-128000.npy',
#  'embeddings-128000-192000.npy',
#  'embeddings-192000-256000.npy',
#  'embeddings-256000-320000.npy',
#  'embeddings-320000-384000.npy',
#  'embeddings-384000-448000.npy',
#  'embeddings-448000-512000.npy',
#  'embeddings-512000-576000.npy',
#  'embeddings-576000-640000.npy',
#  'embeddings-640000-704000.npy',
#  'embeddings-704000-768000.npy',
#  'embeddings-768000-828236.npy']

passage の埋め込み表現をロードして、正規化します。

embeddings = []
for fname in fnames:
  embeddings.append(np.load(fname))
embeddings = np.concatenate(embeddings)

normalized_embeddings = embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
del embeddings

SCaNN の Searcher を作ります。パラメータの設定はここ13を参考にしました。

k = 100
num_leaves = 2000
num_leaves_to_search=100
#training_sample_size = int(num_leaves / 16) # 6.25 percent of embeddings
training_sample_size = 250000
reordering_num_neighbors = 10 * k

searcher = scann.scann_ops_pybind.builder(
    db=normalized_embeddings, num_neighbors=k, distance_measure="dot_product").tree(
    num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size).score_ah(
    dimensions_per_block=2, anisotropic_quantization_threshold=0.2).reorder(reordering_num_neighbors).build()

各スプリットの question の埋め込みを条件に近傍検索して保存します。

splits = ["train", "dev", "test"]

for split in splits:
  print("Start MIPS search of {} split...".format(split), flush=True)
  fname = "question-embeddings-{}.npy".format(split)
  queries = np.load(fname)
  normalized_queries = queries / np.linalg.norm(queries, axis=1)[:, np.newaxis]
  neighbors, distances = searcher.search_batched(normalized_queries)
  print("Dumping neighbors of query as numpy...", flush=True)
  np.save("q-neighbors-{}.npy".format(split), neighbors)
# Start MIPS search of train split...
# Dumping neighbors of query as numpy...
# Start MIPS search of dev split...
# Dumping neighbors of query as numpy...
# Start MIPS search of test split...
# Dumping neighbors of query as numpy...  

GCS に処理結果を保存します。

!gsutil cp q-neighbors-*.npy gs://somewhere/FiD/

これで question に近い内容の passage のインデックスが分かったので、ここまで作ってきたデータを組み合わせて、 FiD の学習データを形成していきます。

7. 学習データの形成

新しくノートブックを開きます。アクセラレータは None で構いません。

GCS の認証を通して、

from google.colab import auth
auth.authenticate_user()

JAQKET データセットを加工したファイル、処理済みの passage、各 question の近傍 passage のインデックスを取得します。

!gsutil cp jaqket.resplit.*.jsonl .
!gsutil cp gs://somewhere/FiD/contexts-*.pkl . 
!gsutil cp gs://somewhere/FiD/q-neighbors-*.npy .

全ての passage のメモリ展開と.jsonl のロード関数です。

import pickle
import numpy as np
import json

contexts = []
for fname in fnames:
  with open(fname, "rb") as f:
    contexts.extend(pickle.load(f)) 

def load_jsonl(fname):
  with open(fname, "r") as f:
    lines = f.readlines()
    lines = [line.strip() for line in lines]
    examples = [json.loads(line) for line in lines]
    return examples

各スプリットのサンプルに "ctxs” として “question” の近傍にあたる passage を追加します。

splits = ["train", "dev", "test"]

for split in splits:
  print("Start converting {} split...".format(split), flush=True)
  fname = "jaqket.resplit.{}.jsonl".format(split)
  examples = load_jsonl(fname)
  neighbors = np.load("q-neighbors-{}.npy".format(split))
  converted = [{"id" : example["id"], 
                "question" : example["question"], 
                "answers" : example["answers"],  
                "ctxs": [{"title" : contexts[j][0], "text" : contexts[j][1]} for j in neighbors[i]]} 
                 for i, example in enumerate(examples)]
  with open("fid_%s.jsonl" % (split), "w") as f:
    lines = []
    for row in converted:
      line = json.dumps(row, ensure_ascii=False)
      lines.append(line)
    f.write("\n".join(lines))
# Start converting train split...
# Start converting dev split...
# Start converting test split...

それでは形成されたサンプルを見てみましょう。学習データの 2 件目です14

import json
import pprint
with open("fid_train.jsonl", "r") as f:
  line = f.readlines()[1]
  example = json.loads(line)
example["ctxs"] = example["ctxs"][:5]
pprint.pprint(example)

# {'answers': ['クレムリン'],
#  'ctxs': [{'text': 'クレムリンは1156年にユーリー・ドルゴルーキーが砦を築いて以来一貫してモスクワの中心であり、モスクワ大公国時代からロシア帝国初期を通じて王宮が置かれていた。ソビエト連邦成立後はここに政府が置かれ、現在もロシア連邦の大統領府があるロシア政治の中枢である。クレムリンの正面には赤の広場が広がり、広場周辺にはグム百貨店や聖ワシリイ大聖堂、レーニン廟がある。',
#            'title': 'モスクワ'},
#           {'text': 'モスクワのクレムリはモスクワ川に面した河岸段丘上に建てられ、歴代のモスクワ大公やロシア皇帝により宮殿や聖堂などが増築されていった。かつてモスクワのクレムリにはソビエト連邦共産党の中枢があったため、「クレムリン」はソ連首脳部の代名詞でもあった。現在ロシア大統領府はモスクワのクレムリにあるが、ロシア連邦政府はクレムリ外の建物を使っている。',
#            'title': 'クレムリ'},
#           {'text': 'ロシア連邦大統領特殊プログラム総局(ロシアれんぽうだいとうりょうとくしゅプログラムそうきょく、ロシア語: '
#                    'Главное управлениеспециальных программ Президента、略称:ロシア語: '
#                    'ГУСП)は、ロシア連邦の地下シェルター等の特殊施設の建設・運営・維持を担当する機関である。',
#            'title': 'ロシア連邦大統領特殊プログラム総局'},
#           {'text': 'ミハイロフスキー城(Михайловский замок、Mikhailovskii '
#                    'zamok、Mikhailovsky '
#                    'Castle)は、ロシア、サンクトペテルブルクにある城塞。別名をインジェネールヌィ城(Инженерный '
#                    'замок、Inzhenerny zamok、「技師の城」の意)という。英語版では聖ミハイル城、St. '
#                    "Michael's "
#                    'Castleの名称である。ミハイロフ城ともいう。ミハイロフスキー宮殿は、ロシア皇帝パーヴェル1世の宮殿として、1797年から1801年にかけて建設された。',
#            'title': 'ミハイロフスキー城'},
#           {'text': 'ロシア大統領府(ロシア語: Администрация президента '
#                    'России)は、ロシア大統領直属の国家行政機関。1991年7月19日にボリス・エリツィン・ロシア・ソビエト社会主義連邦共和国(現在のロシア連邦)大統領の大統領令によって、エリツィン大統領とアレクサンドル・ルツコイ副大統領(1993年副大統領職は廃止)に関する行政事務を取りあつかうため、ロシア連邦安全保障会議を含む諮問機関と同様に設立された。',
#            'title': 'ロシア大統領府'}],
#  'id': 'ABC01-01-0004',
#  'question': 'ロシア語で「城」という意味がある、ロシアの大統領府の別名は何でしょう?'}

ようやく FiD の学習データが用意できました。GCS に保存してきましょう。

!gsutil cp fid_*.jsonl gs://somewhere/FiD/

ORQA のときもそうでしたが、「Wikipedia を検索して…」 の類は手間がかかりますね。。。 それでは Reader の学習をやってみましょう。

8. Reader の学習

学習データの準備ができたので、ここから Reader の学習にうつります。

ここからは、新たに Colab のノートブックを開いて、アクセラレータに GPU を選んでください。

セットアップ

FiD のコードは https://github.com/facebookresearch/FiD で公開されているので取得します15

!git clone https://github.com/facebookresearch/FiD
!cd FiD && git checkout baf533c3f7a26c1cac624ee9252ce5ccf344a935
# ...
# HEAD is now at baf533c Update requirements.txt

FiD は transformers の T5 をベースに実装されているので、必要なライブラリをインストールします。

!pip install transformers==4.14.1
!pip install sentencepiece

ちなみに FiD のリポジトリの README.md には

  • “Transformers (version 3.0.2, unlikely to work with a different version)”

とがっつり書いてあるのですが、無視して 4.14.1 を使いました。 というのも T5 の事前学習モデルとしては Megagon Labs さんのもの6を使いたかったのですが、3.0.2 でロードしようとするとエラーになってしまって。 transformers が古いバージョンにピン止めされるのも嫌だったので、今回は FiD のコードの方を適宜修正して動かすことにしました。

コードの修正

それでは、ここからコードを編集していきます。

モデル名の修正

まず、ロードするモデルの名称が “t5-base” など英語モデルが前提となっているようなので、

!cat FiD/train_reader.py | grep "model_name =" 
#    model_name = 't5-' + opt.model_size

書き換えて、"megagonlabs/t5-base-japanese-web" にします。

!sed -i "/^ *model_name =/s/'t5-' + opt.model_size/'megagonlabs\/t5-base-japanese-web'/"  FiD/train_reader.py
!cat FiD/train_reader.py | grep "model_name =" 
#    model_name = 'megagonlabs/t5-base-japanese-web'

dict が返る想定のコードに修正

FiD のコードは T5 から tuple が返る想定のコードになっているので(146行目)、

!cat FiD/src/model.py | awk 'NR>=143 && NR<=147{print NR"|"$0}'
#143|        input_ids = input_ids.view(bsz*self.n_passages, passage_length)
#144|        attention_mask = attention_mask.view(bsz*self.n_passages, passage_length)
#145|        outputs = self.encoder(input_ids, attention_mask, **kwargs)
#146|        outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:]
#147|        return outputs

dict が返るように修正しました。

!sed -i "146s/outputs =/#outputs =/" FiD/src/model.py
!sed -i "146a \ \ \ \ \ \ \ \ outputs.last_hidden_state = outputs.last_hidden_state.view(bsz, self.n_passages*passage_length, -1)" FiD/src/model.py 
!cat FiD/src/model.py | awk 'NR>=143 && NR<=148{print NR"|"$0}'
#143|        input_ids = input_ids.view(bsz*self.n_passages, passage_length)
#144|        attention_mask = attention_mask.view(bsz*self.n_passages, passage_length)
#145|        outputs = self.encoder(input_ids, attention_mask, **kwargs)
#146|        #outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:]
#147|        outputs.last_hidden_state = outputs.last_hidden_state.view(bsz, self.n_passages*passage_length, -1)
#148|        return outputs

チェックポイントを GCS にコピーする

ついでにチェックポイントが生成されたときに GCS にコピーするロジックを足しました。

!cat FiD/src/util.py | awk 'NR>=68 && NR<=70{print NR"|"$0}'
#68|    torch.save(checkpoint, fp)
#69|    symlink_force(epoch_path, cp)
#70|

69行目の後に追加します。

!sed -i '69a\ \ \ \ basename = os.path.basename(dir_path)\n\ \ \ \ os.system("gsutil cp -r {} gs://somewhere/FiD/{}/checkpoint/".format(epoch_path, basename))' FiD/src/util.py 
!cat FiD/src/util.py | awk 'NR>=68 && NR<=72{print NR"|"$0}'
#68|    torch.save(checkpoint, fp)
#69|    symlink_force(epoch_path, cp)
#70|    basename = os.path.basename(dir_path)
#71|    os.system("gsutil cp -r {} gs://somewhere/FiD/{}/checkpoint/".format(epoch_path, basename))
#72|

プレフィックスの修正

passage のデータは title と text で構成されているので、 2 章のアーキテクチャの図の左端の列は、 以下のように question, title, text にプレフィックスを付けて連結した形式となります。

"{question_prefix} {question} {title_prefix} {title} {passage_prefix} {text}"

元々のコードは当然のことながら以下のように英語のプレフィックスになっています。

!cat FiD/src/data.py | awk 'NR>=16 && NR<=18{print NR"|"$0}'
!echo "..."
!cat FiD/src/data.py | awk 'NR>=185 && NR<=186{print NR"|"$0}'
#16|                 question_prefix='question:',
#17|                 title_prefix='title:',
#18|                 passage_prefix='context:'):
#...
#185|                 title_prefix='title:',
#186|                 passage_prefix='context:'):

これを日本語の T5 向けにトークナイズするとトークン数がかさむので、ここも日本語にしてしまいましょう。

!sed -i "s/question_prefix='question:'/question_prefix='問題:'/" FiD/src/data.py
!sed -i "s/title_prefix='title:'/title_prefix='タイトル:'/" FiD/src/data.py
!sed -i "s/passage_prefix='context:'/passage_prefix='文脈:'/" FiD/src/data.py
!cat FiD/src/data.py | awk 'NR>=16 && NR<=18{print NR"|"$0}'
!echo "..."
!cat FiD/src/data.py | awk 'NR>=185 && NR<=186{print NR"|"$0}'
#16|                 question_prefix='問題:',
#17|                 title_prefix='タイトル:',
#18|                 passage_prefix='文脈:'):
#...
#185|                 title_prefix='タイトル:',
#186|                 passage_prefix='文脈:'):

ついに全ての準備が整いました。学習を実行してみましょう。

学習の実行

まずは GCS の認証を通して、

from google.colab import auth
auth.authenticate_user()

作成済みの学習データを GCS から取得します。

!gsutil cp gs://somewhere/FiD/fid_*.jsonl .

以下のようにして学習を起動します。

!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  TRAIN_DATA='./fid_train.jsonl' && \
  EVAL_DATA='./fid_dev.jsonl' && \
  PER_GPU_BATCH_SIZE=1 && \
  N_CONTEXT=100 && \
  CKPT_DIR='/content/checkpoint' && \
  OPTIM='adam' && \
  WEIGHT_DECAY=0.1 && \
  LEARNIG_RATE=0.0001 && \
  SCHEDULER='fixed' && \
  MAX_LEN=200 && \
  TOTAL_STEPS=91420 && \
  ACCUMULATION_STEPS=64 && \
  SAVE_FREQ=1280 && \
  EVAL_FREQ=1280 && \
  EVAL_PRINT_FREQ=1280 && \
  \
  date && \
  echo "TRAIN_DATA=$TRAIN_DATA" &&\
  echo "EVAL_DATA=$EVAL_DATA" &&\
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" &&\
  echo "N_CONTEXT=$N_CONTEXT" &&\
  echo "CKPT_DIR=$CKPT_DIR" && \
  echo "OPTIM=$OPTIM" && \
  echo "WEIGHT_DECAY=$WEIGHT_DECAY" && \
  echo "LEARNIG_RATE=$LEARNIG_RATE" && \
  echo "SCHEDULER=$SCHEDULER" && \
  echo "MAX_LEN=$MAX_LEN" && \
  echo "TOTAL_STEPS=$TOTAL_STEPS" && \
  echo "ACCUMULATION_STEPS=$ACCUMULATION_STEPS" && \
  echo "SAVE_FREQ=$SAVE_FREQ" && \
  echo "EVAL_FREQ=$EVAL_FREQ" && \
  echo "EVAL_PRINT_FREQ=$EVAL_PRINT_FREQ" && \
  \
  python ./FiD/train_reader.py \
        --name fid_jaqket_wiki40b \
        --use_checkpoint \
        --train_data $TRAIN_DATA \
        --eval_data $EVAL_DATA \
        --model_size base \
        --per_gpu_batch_size $PER_GPU_BATCH_SIZE \
        --n_context $N_CONTEXT \
        --checkpoint_dir $CKPT_DIR \
        --lr $LEARNIG_RATE \
        --optim $OPTIM \
        --scheduler $SCHEDULER \
        --weight_decay $WEIGHT_DECAY \
        --text_maxlength $MAX_LEN \
        --total_step $TOTAL_STEPS \
        --accumulation_steps $ACCUMULATION_STEPS \
        --save_freq $SAVE_FREQ \
        --eval_freq $EVAL_FREQ \
        --eval_print_freq $EVAL_PRINT_FREQ

簡単に train_reader.py の引数を説明しておきます。

  • --use_checkpoint :
    学習時に PyTorch の Checkpointing 機能16を使うかどうかです。計算時間と引き換えに必要メモリ量を節約できます。
  • --model_size :
    今回は “megagonlabs/t5-base-japanese-web” 固定にしてしまったので使用されません。
  • --per_gpu_batch_size :
    GPU 毎のバッチサイズです。
  • --n_context :
    FiD に投入する passage の数です。
  • --checkpoint_dir :
    フルパスで書いておかないと動作がおかしくなったような覚えがあります。
  • --lr :
    FiD のコードのデフォルト値です。FiD の論文1の記述や Retriever を蒸留する論文7の Table 6. を見たところ、 base サイズのモデルならデフォルト値で問題なさそうです。
  • --optim :
    同上。
  • --scheduler :
    同上。
  • --weight_decay :
    同上。
  • --text_maxlength :
    passage の最大長です。question, title, text をプレフィックスを組み合わせて結合した文字列がこの長さに padding あるいは truncate されます。
  • --total_step :
    バッチを実行する回数です。今回は --per_gpu_batch_size を 1 に設定したので、学習サンプル数になりますね。91420 ステップ(=7 epoch) くらいで十分でした。
  • --accumulation_steps :
    モデルのパラメータ更新をする際に累積するバッチの回数です。論文は --per_gpu_batch_size = 1 で 64 GPU を使って学習しているのですが、 Colab では GPU が 1 つしかないので、64 バッチ分を累積することで論理バッチ数を 64 にしています。
  • --save_freq :
    checkpoint を保存する間隔です。パラメータ更新 20 回毎に保存する感じですね。もう少し粗くても良いでしょう。
  • --eval_freq :
    こちらは検証を実行する間隔です。--save_freq に合わせました。
  • --eval_print_freq :
    ちゃんと調べてません(すみません)。たぶん検証結果をログに出力する間隔かなと。

学習の状況は run.log ファイルに出力されます。途中で総ステップ数が変わっているのはご愛敬で試行錯誤の後ですね。

!cat /content/checkpoint/fid_jaqket_wiki40b/run.log
# [02/04/2022 09:14:16] {train_reader.py:201} INFO - Start training
# [02/04/2022 09:14:16] {train_reader.py:30} WARNING - Tensorboard is not available.
# [02/04/2022 10:05:13] {train_reader.py:81} INFO - 1280 / 65300 |train: 9.243 |evaluation: 0.00EM |lr: 0.00010
# [02/04/2022 10:56:11] {train_reader.py:81} INFO - 2560 / 65300 |train: 6.715 |evaluation: 0.00EM |lr: 0.00010
...
#[02/07/2022 09:17:46] {train_reader.py:81} INFO - 102400 / 104480 |train: 0.290 |evaluation: 62.21EM |lr: 0.00010
#[02/07/2022 10:09:15] {train_reader.py:81} INFO - 103680 / 104480 |train: 0.285 |evaluation: 62.81EM |lr: 0.00010

ステップ数の目安が良くわからなかったのでエイやと 104480 ステップ回した時の学習曲線は以下のようになりました。

reader_learning_curve

  • Colab のランタイムの寿命が切れた場合、GCS には検証の結果は残ってないので、画面に表示されているログから拾っておいて下さい。

学習途中から再開する場合

GCS にコピーしたチェックポイントを使って学習を再開する場合は以下のようにして下さい。 FiD のコードは latest という名前で最新のチェックポイントへのリンクを作るので、それを手動で復元します。

# step-1920 から再開するイメージです
!mkdir -p ./checkpoint/fid_jaqket_wiki40b/checkpoint
!gsutil cp -r gs://somewhere/FiD/fid_jaqket_wiki40b/checkpoint/step-1920 ./checkpoint/fid_jaqket_wiki40b/checkpoint/
!ln -sf /content/checkpoint/fid_jaqket_wiki40b/checkpoint/step-1920 ./checkpoint/fid_jaqket_wiki40b/checkpoint/latest

それでは、テストデータを使って精度を確認しましょう。

テストデータでの評価

検証用のコードも少し修正する必要があります。トークナイザのモデル名が英語のモデルになっているので、

!cat FiD/test_reader.py | grep "t5-base" 
#    tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)

書き換えてしまいましょう。

!sed -i "s/'t5-base'/'megagonlabs\/t5-base-japanese-web'/"  FiD/test_reader.py
!cat FiD/test_reader.py | grep "t5-base" 
#    tokenizer = transformers.T5Tokenizer.from_pretrained('megagonlabs/t5-base-japanese-web', return_dict=False)

以下のようにして動かします。 FiD のコードは最良のチェックポイントへのリンクを best_dev で保持しているので、それを指定します。

  • Colab のランタイムが切れて GCS から復元すると、当然 best_dev のリンクも消えてしまいます。。。ただ学習曲線からすると latest を拾っておけばだいたい大丈夫かと思います。
!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  MODEL_PATH='/content/checkpoint/fid_jaqket_wiki40b/checkpoint/best_dev' && \
  EVAL_DATA='./fid_test.jsonl' && \
  PER_GPU_BATCH_SIZE=1 && \
  N_CONTEXT=100 && \
  CKPT_DIR='/content/checkpoint' && \
  \
  date && \
  echo "MODEL_PATH=$MODEL_PATH" && \
  echo "EVAL_DATA=$EVAL_DATA" && \
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" && \
  echo "N_CONTEXT=$N_CONTEXT" && \
  echo "CKPT_DIR=$CKPT_DIR" && \
  \
  python ./FiD/test_reader.py \
        --model_path $MODEL_PATH \
        --name fid_jaqket_wiki40b \
        --use_checkpoint \
        --eval_data $EVAL_DATA \
        --per_gpu_batch_size $PER_GPU_BATCH_SIZE \
        --n_context $N_CONTEXT \
        --checkpoint_dir $CKPT_DIR 

以下のような結果になりました。

!tail -3 /content/checkpoint/fid_jaqket_wiki40b/run.log
# [02/07/2022 12:14:01] {test_reader.py:128} INFO - Start eval
# [02/07/2022 12:23:08] {test_reader.py:73} WARNING - Process rank:0, total 997 | average = 0.667
# [02/07/2022 12:23:08] {test_reader.py:131} INFO - EM 66.70, Total number of example 997

正答率(Exact Match) で 66.70% が出ました。ちょうど 2/3 を当てたことになります。 ORQA の時は同じデータで 30 % 程度だったので倍以上の改善ですね。

それでは、ここから Reader の Cross-Attention の重みで蒸留し Retriever の学習にチャレンジしてみましょう。

9. Cross-Attention の重み取得

Colab の環境は “8. Reader の学習” のセットアップとコード修正が適用済み、学習ループを回す直前まで実行した状態とします。

コードの修正

Cross-Attention の重みを取得するために、T5 の Attention クラスの処理を差し替える挙動になっていますが、transformers のバージョン を安易に上げたのでいろいろと修正が必要になりました17

T5Attention.forward() のパラメータ名のリファクタリングに対応

どうやら T5Attention.forward() の変数名が変わっちゃったみたいです。

!cat FiD/src/model.py | awk 'NR>=195 && NR<=206{print NR"|"$0}'
#195|def cross_attention_forward(
#196|        self,
#197|        input,
#198|        mask=None,
#199|        kv=None,
#200|        position_bias=None,
#201|        past_key_value_state=None,
#202|        head_mask=None,
#203|        query_length=None,
#204|        use_cache=False,
#205|        output_attentions=False,
#206|    ):

関数の宣言部分のパラメータ名を transormers の 4.14.1 に合わせて、続きのコードと整合性を保つために関数の頭で元の変数名に戻してます。

!sed -i \
     -e '199s/kv/key_value_states/' \
     -e '202s/head_mask/layer_head_mask/' \
     -e '201s/past_key_value_state/past_key_value/' \
     -e '215s/self.d_kv/self.key_value_proj_dim/' \
     -e '209a\ \ \ \ kv=key_value_states\n\ \ \ \ head_mask=layer_head_mask\n\ \ \ \ past_key_value_state=past_key_value' FiD/src/model.py
!cat FiD/src/model.py | awk 'NR>=195 && NR<=206{print NR"|"$0}'
#195|def cross_attention_forward(
#196|        self,
#197|        input,
#198|        mask=None,
#199|        key_value_states=None,
#200|        position_bias=None,
#201|        past_key_value=None,
#202|        layer_head_mask=None,
#203|        query_length=None,
#204|        use_cache=False,
#205|        output_attentions=False,
#206|    ):

ところが動かしてみると、エラーになります。どうも(いろいろ修正した後の)215 行目の assert に引っかかったようで。。。

!cat ./FiD/src/model.py | awk 'NR>=214 && NR<=216{print NR"|"$0}'
#214|    assert(head_mask == None)
#215|    assert(position_bias != None or self.has_relative_attention_bias)
#216|

FiD が前提としている transformers 3.0.2 の後に、この不具合が原因18でこのコミット19で挙動が修正されたようです。 この修正で T5 のコードには (position_bias is None and self.has_relative_attention_bias == False のケースのコード(↓)が追加されているので、 これを移植して対応することにします。

!cat /usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py | awk 'NR>=488 && NR<=496{print NR"|"$0}'
#488|        if position_bias is None:
#489|            if not self.has_relative_attention_bias:
#490|                position_bias = torch.zeros(
#491|                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
#492|                )
#493|                if self.gradient_checkpointing and self.training:
#494|                    position_bias.requires_grad = True
#495|            else:
#496|                position_bias = self.compute_bias(real_seq_length, key_length)

さっきの assert はコメントアウトしてしまって、

!sed -i '215s/assert/#assert/' ./FiD/src/model.py  

以下の 233 行目の直後に上記の対応コードを移植します。

!cat ./FiD/src/model.py  | awk 'NR>=233 && NR<=235{print NR"|"$0}'
#233|    if position_bias is None:
#234|        position_bias = self.compute_bias(qlen, klen)
#235|    scores += position_bias

不要部分を削除して、部品を作って、はめ込みます。

!sed -i '233,234d' ./FiD/src/model.py 
!cat /usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py | sed 's/^....//' | sed -e 's/real_seq_length/qlen/' -e 's/key_length/klen/' \
 | awk 'NR>=488 && NR<=497{print $0}' > patch
!cat patch
#    if position_bias is None:
#        if not self.has_relative_attention_bias:
#            position_bias = torch.zeros(
#                (1, self.n_heads, qlen, klen), device=scores.device, dtype=scores.dtype
#            )
#            if self.gradient_checkpointing and self.training:
#                position_bias.requires_grad = True
#        else:
#            position_bias = self.compute_bias(qlen, klen)
!cat ./FiD/src/model.py | awk 'NR<=232{print $0}' > head
!cat ./FiD/src/model.py | awk 'NR>=233{print $0}' > tail
!cat head patch tail > ./FiD/src/model.py
!cat ./FiD/src/model.py | awk 'NR>=233 && NR<=243{print NR"|"$0}'
#233|    if position_bias is None:
#234|        if not self.has_relative_attention_bias:
#235|            position_bias = torch.zeros(
#236|                (1, self.n_heads, qlen, klen), device=scores.device, dtype=scores.dtype
#237|            )
#238|            if self.gradient_checkpointing and self.training:
#239|                position_bias.requires_grad = True
#240|        else:
#241|            position_bias = self.compute_bias(qlen, klen)
#242|
#243|    scores += position_bias

動かしてみると、tuple index out of range でエラーになりました。安易にバージョンを上げた前章の自分を呪います。 最後の tuple を組み立てるところが怪しそうですね。

!cat ./FiD/src/model.py | awk 'NR>=253 && NR<=266{print NR"|"$0}'
#253|    output = self.o(output)
#254|
#255|    if use_cache:
#256|        output = (output,) + ((k, v),)
#257|    else:
#258|        output = (output,) + (None,)
#259|
#260|    if output_attentions:
#261|        output = output + (attn,)
#262|
#263|    if self.has_relative_attention_bias:
#264|        output = output + (position_bias,)
#265|
#266|    return output

このコードを transformers 4.14.1 の T5Attentionforward() と見比べてみると、transformers 4.14.1 のコードは

  • 返却する tuple 中で possition_biasattn_weights(attn) の位置が入れ替わっています。
  • tuple の (k, v) の項に self.is_decoder の条件が追加されています。
  • tuple の position_bias の項から self.has_relative_attention_bias の条件が削除されています。

といった違いがありました。この辺りの条件を合わせてみます。

!sed -i '255s/if /if self.is_decoder and /' ./FiD/src/model.py
!sed -i '259a\ \ \ \ output = output + (position_bias,)\n' ./FiD/src/model.py 
!sed -i '265,267d' ./FiD/src/model.py
!cat ./FiD/src/model.py | awk 'NR>=253 && NR<=265{print NR"|"$0}'
#253|    output = self.o(output)
#254|
#255|    if self.is_decoder and use_cache:
#256|        output = (output,) + ((k, v),)
#257|    else:
#258|        output = (output,) + (None,)
#259|
#260|    output = output + (position_bias,)
#261|
#262|    if output_attentions:
#263|        output = output + (attn,)
#264|
#265|    return output

以下のように --write_crossattention_scores を付けて動かします。

!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  MODEL_PATH='/content/checkpoint/fid_jaqket_wiki40b/checkpoint/best_dev' && \
  EVAL_DATA='./fid_train.jsonl' && \
  PER_GPU_BATCH_SIZE=1 && \
  N_CONTEXT=100 && \
  CKPT_DIR='/content/checkpoint' && \
  \
  date && \
  echo "MODEL_PATH=$MODEL_PATH" && \
  echo "EVAL_DATA=$EVAL_DATA" && \
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" && \
  echo "N_CONTEXT=$N_CONTEXT" && \
  echo "CKPT_DIR=$CKPT_DIR" && \
  \
  python ./FiD/test_reader.py \
        --model_path $MODEL_PATH \
        --name fid_jaqket_wiki40b \
        --use_checkpoint \
        --eval_data $EVAL_DATA \
        --per_gpu_batch_size $PER_GPU_BATCH_SIZE \
        --n_context $N_CONTEXT \
        --checkpoint_dir $CKPT_DIR \
        --write_crossattention_scores
!gsutil cp ./checkpoint/fid_jaqket_wiki40b/dataset_wscores.json gs://somewhere/FiD/dataset_train_wscores.json

結果は、dataset_wscores.json に書き出されます。

!cat ./checkpoint/fid_jaqket_wiki40b/dataset_wscores.json | jq 'limit(1; .[])' | head -12
#{
#  "id": "ABC01-01-0003",
#  "question": "格闘家ボブ・サップの出身国はどこでしょう?",
#  "answers": [
#    "アメリカ合衆国"
#  ],
#  "ctxs": [
#    {
#      "title": "チャド・バノン",
#      "text": "ボブ・サップに続くモンスター系ファイターとしてK-1 BEAST 2003にて富平辰文と対戦。序盤には怒涛のラッシュを見せたものの慣れないローキックに翻弄され次第に失速、0-3の判定負けを喫した。その後のプロ格闘技大会への出場は確認されていない。端正な顔立ちにトレードマークのモヒカンアー、均整の取れたマッチョボディと格闘家の中でも有数の容姿を持つ。",
#      "score": -2.685264825820923
#    },

“score” 属性が追加されていることが確認できました。この値を教師として Retriever を学習する訳ですね。

同様に EVAL_DATA の値をfid_dev.jsonl に、最後の gsutil のコピー先ファイル名を dataset_dev_wscores.json に変更して実行し、 検証データのスコアも取得しておきまます。

さて、かなり適当にコードを修正しましたが大丈夫でしょうか?元データ( Sentence BERT の類似度で評価)と本章で生成したデータ( Cross-Attention の重みによる評価)で最初に回答が出現する平均順位を比べてみました。

1st_pos_of_hasanswer_true

回答を含む passage の出現位置が改善しているので、意図したとおりに動いているようですね。

では、 Retriever の学習に移ります。

10. Retriever の学習

Colab の環境は “8. Reader の学習” 以降のセットアップとコード修正が適用済みとします。アクセラレータは GPU です。

セットアップ

Retriever には BERT を使うので、MeCab 関係の追加インストールが必要です。

!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev
!pip install mecab-python3 fugashi ipadic

コードの修正

Retriever 関係のコードも英語前提なので書き換えが必要です。

トークナイザの変更

トークナイザが英語モデルなので、

!cat ./FiD/train_retriever.py | grep "tokenizer ="
#    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

日本語のものに書き換えます。

!sed -i "s/BertTokenizerFast.from_pretrained('bert-base-uncased')/BertJapaneseTokenizer.from_pretrained('cl-tohoku\/bert-base-japanese-whole-word-masking')/" ./FiD/train_retriever.py
!cat ./FiD/train_retriever.py | grep "tokenizer ="
#    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

事前学習済みモデルの変更

BERT の事前学習済みモデルも英語なので、

!cat ./FiD/src/model.py | grep "BertModel.from_pretrained"
#            self.model = transformers.BertModel.from_pretrained('bert-base-uncased')

東北大さんのBERTに差し替え。

!sed -i "s/transformers.BertModel.from_pretrained('bert-base-uncased')/transformers.BertModel.from_pretrained('cl-tohoku\/bert-base-japanese-whole-word-masking')/" ./FiD/src/model.py
!cat ./FiD/src/model.py | grep "BertModel.from_pretrained"
#            self.model = transformers.BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

このままだと、Retriever の語彙数が英語 BERT の 30522 になってしまうようなので、RetrieverConfig を書き換えて vocab_size の設定を追記します。

!sed -i '276a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ vocab_size=32000,' ./FiD/src/model.py 
!sed -i '286a\ \ \ \ \ \ \ \ self.vocab_size = vocab_size' ./FiD/src/model.py
!cat ./FiD/src/model.py |  awk 'NR>=267 && NR<=288{print NR"|"$0}'
#267|class RetrieverConfig(transformers.BertConfig):
#268|
#269|    def __init__(self,
#270|                 indexing_dimension=768,
#271|                 apply_question_mask=False,
#272|                 apply_passage_mask=False,
#273|                 extract_cls=False,
#274|                 passage_maxlength=200,
#275|                 question_maxlength=40,
#276|                 projection=True,
#277|                 vocab_size=32000,
#278|                 **kwargs):
#279|        super().__init__(**kwargs)
#280|        self.indexing_dimension = indexing_dimension
#281|        self.apply_question_mask = apply_question_mask
#282|        self.apply_passage_mask = apply_passage_mask
#283|        self.extract_cls=extract_cls
#284|        self.passage_maxlength = passage_maxlength
#285|        self.question_maxlength = question_maxlength
#286|        self.projection = projection
#287|        self.vocab_size = vocab_size

passage のサンプリング処理を追加

データ長を確認して question の最大長を 54(99.7 %ile), passage の最大長を 137(99 %ile)としましたが、K80 では n_context = 100 で実行すると Memory Error になってしまいました。 passage の数を減らすしかないですが、せっかくなので epoch の周回毎に異なる passage の集合になるようにサンプリング処理を入れることにしました。

サンプリングするのは学習時だけにするべきなので、そのフラグを足して(以下19, 20行目)、

!sed -i '18s/):/,/' ./FiD/src/data.py 
!sed -i '18a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ training=True):\n\ \ \ \ \ \ \ \ self.training=training' ./FiD/src/data.py
!cat ./FiD/src/data.py  | awk 'NR>=12 && NR<=20{print NR"|"$0}'
#12|class Dataset(torch.utils.data.Dataset):
#13|    def __init__(self,
#14|                 data,
#15|                 n_context=None,
#16|                 question_prefix='問題:',
#17|                 title_prefix='タイトル:',
#18|                 passage_prefix='文脈:',
#19|                 training=True):
#20|        self.training=training

example が保持する passage の件数が n_context より多い場合にサンプリングするようにします(以下47~49行目)。

!sed -i '47s/contexts/    contexts/' ./FiD/src/data.py
!sed -i "46a\ \ \ \ \ \ \ \ \ \ \ \ if self.training and len(example['ctxs']) > self.n_context:\n"\
"\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ contexts = random.sample(example['ctxs'], self.n_context)\n"\
"\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ contexts.sort(key=lambda x: float(x['score']), reverse=True)\n"\
'\ \ \ \ \ \ \ \ \ \ \ \ else:' ./FiD/src/data.py 
!cat ./FiD/src/data.py  | awk 'NR>=45 && NR<=52{print NR"|"$0}'
#45|        if 'ctxs' in example and self.n_context is not None:
#46|            f = self.title_prefix + " {} " + self.passage_prefix + " {}"
#47|            if self.training and len(example['ctxs']) > self.n_context:
#48|                contexts = random.sample(example['ctxs'], self.n_context)
#49|                contexts.sort(key=lambda x: float(x['score']), reverse=True)
#50|            else:
#51|                contexts = example['ctxs'][:self.n_context]
#52|            passages = [f.format(c['title'], c['text']) for c in contexts]

呼び出し元の方で先ほど追加したフラグを設定するよう書き換えます(以下179行目)。

!sed -i '179s/)/, training=False)/' ./FiD/train_retriever.py
!cat ./FiD/train_retriever.py | awk 'NR>=178 && NR<=181{print NR"|"$0}'
#178|    )
#179|    eval_dataset = src.data.Dataset(eval_examples, opt.n_context, training=False)
#180|
#181|    global_step = 0

ついでに Reader の学習スクリプトも修正しておきました(8章での Reader の学習は n_context == 100 で行ったので動作に影響ないはずですが)。

!sed -i '172s/)/, training=False)/' ./FiD/train_reader.py
!cat ./FiD/train_reader.py | awk 'NR>=171 && NR<=174{print NR"|"$0}'
#171|    )
#172|    eval_dataset = src.data.Dataset(eval_examples, opt.n_context, training=False)
#173|
#174|    if not checkpoint_exists and opt.model_path == "none":

GCS からスコア付きの学習データを取得して、

!gsutil cp gs://somewhere/FiD/fid_*_with_score.json .

以下のように実行します。ハイパーパラメータは論文7の Table 6. に合わせました。

!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  TRAIN_DATA='./fid_train_with_score.jsonl' && \
  EVAL_DATA='./fid_dev_with_score.jsonl' && \
  PER_GPU_BATCH_SIZE=1 && \
  N_CONTEXT=64 && \
  Q_MAXLEN=54 && \
  P_MAXLEN=137 && \
  CKPT_DIR='/content/retriever_checkpoint' && \
  OPTIM='adam' && \
  WEIGHT_DECAY=0.1 && \
  LEARNIG_RATE=0.00005 && \
  SCHEDULER='fixed' && \
  TOTAL_STEPS=134400 && \
  ACCUMULATION_STEPS=64 && \
  SAVE_FREQ=6400 && \
  EVAL_FREQ=6400 && \
  EVAL_PRINT_FREQ=6400 && \
  \
  date && \
  echo "TRAIN_DATA=$TRAIN_DATA" &&\
  echo "EVAL_DATA=$EVAL_DATA" &&\
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" &&\
  echo "N_CONTEXT=$N_CONTEXT" && \
  echo "Q_MAXLEN=$Q_MAXLEN" && \
  echo "P_MAXLEN=$P_MAXLEN" && \
  echo "CKPT_DIR=$CKPT_DIR" && \
  echo "OPTIM=$OPTIM" && \
  echo "WEIGHT_DECAY=$WEIGHT_DECAY" && \
  echo "LEARNIG_RATE=$LEARNIG_RATE" && \
  echo "SCHEDULER=$SCHEDULER" && \
  echo "TOTAL_STEPS=$TOTAL_STEPS" && \
  echo "ACCUMULATION_STEPS=$ACCUMULATION_STEPS" && \
  echo "SAVE_FREQ=$SAVE_FREQ" && \
  echo "EVAL_FREQ=$EVAL_FREQ" && \
  echo "EVAL_PRINT_FREQ=$EVAL_PRINT_FREQ" && \
  \
  python ./FiD/train_retriever.py \
        --name fid_jaqket_wiki40b_ret \
        --train_data $TRAIN_DATA \
        --eval_data $EVAL_DATA \
        --per_gpu_batch_size $PER_GPU_BATCH_SIZE \
        --n_context $N_CONTEXT \
        --question_maxlength $Q_MAXLEN \
        --passage_maxlength $P_MAXLEN \
        --checkpoint_dir $CKPT_DIR \
        --lr $LEARNIG_RATE \
        --optim $OPTIM \
        --scheduler $SCHEDULER \
        --weight_decay $WEIGHT_DECAY \
        --total_step $TOTAL_STEPS \
        --accumulation_steps $ACCUMULATION_STEPS \
        --save_freq $SAVE_FREQ \
        --eval_freq $EVAL_FREQ \
        --eval_print_freq $EVAL_PRINT_FREQ

学習の状況は run.log ファイルに出力されます。途中で総ステップ数が変わっているのはご愛敬で試行錯誤の後ですね。

!cat /content/retriever_checkpoint/fid_jaqket_wiki40b_ret/run.log
#[02/07/2022 16:26:40] {train_retriever.py:91} INFO - 6400 / 134400 -- train: 0.000886, eval: 0.001246, inv: 643.9, lr: 0.000050 | avg top1: 30.6 | avg top2: 35.3 | avg top5: 41.7 | idx top1: 7.0 | idx top2: 13.6 | idx top5: 28.6
#[02/07/2022 17:26:11] {train_retriever.py:91} INFO - 12800 / 134400 -- train: 0.000707, eval: 0.001314, inv: 621.4, lr: 0.000050 | avg top1: 36.3 | avg top2: 38.9 | avg top5: 45.1 | idx top1: 6.1 | idx top2: 12.4 | idx top5: 26.3
#...
#[02/08/2022 11:17:47] {train_retriever.py:91} INFO - 128000 / 134400 -- train: 0.000180, eval: 0.001449, inv: 544.3, lr: 0.000050 | avg top1: 44.7 | avg top2: 46.7 | avg top5: 53.7 | idx top1: 5.1 | idx top2: 10.2 | idx top5: 21.8
#[02/08/2022 12:17:17] {train_retriever.py:91} INFO - 134400 / 134400 -- train: 0.000173, eval: 0.001373, inv: 542.7, lr: 0.000050 | avg top1: 44.3 | avg top2: 47.2 | avg top5: 53.8 | idx top1: 5.2 | idx top2: 10.2 | idx top5: 22.0

検証のメトリクスについて補足します。

  • inv :
    推論したスコアを昇順に並べ「推論スコアの順位 k より下位で Gold サンプル上の順位が k より上位であったサンプル数」を全ての順位について合計した値。ようは Gold サンプルの順位付けに反した数ですね。
  • avg topk :
    予測スコアの Top-K が Gold スコアの Top-K に含まれた割合(%)
  • idx topk :
    Gold スコアの Top-K を全て拾うために必要なサンプル数の平均.

学習曲線はこんな感じです。縦軸が 3 本になると面倒なので avg topk と idx topk は 10 倍してプロットしてます。

learning_curve_of_retriever

検証ロスで見ると学習開始直後のステップ 19200 が最良ですが、その後も他のメトリクスは徐々に改善してます。 ラベルのスコアとの誤差が一番小さいのは 19200 なのですが、サンプル間のスコアの上下関係を把握する能力は学習終盤の方が良いということなんでしょう。 そういう訳なので best_dev は無視して最終のチェックポイントを使うのが良さそうです。

ここからは学習した Retriever を使って Reader の学習データを作り直していきます。

11. 学習済み Retriever を用いた passage の埋め込み

前章で学習した Retriever を使って passage の埋め込み表現を再生成します。

本章も Colab の環境は “8. Reader の学習” 以降のセットアップとコード修正が適用済みとし、アクセラレータは GPU です。

FiD のコードが想定している passage の形式を公開されている英語のデータで確認してみましょう。

!wget -c https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
!gzip -d psgs_w100.tsv.gz
!head -3 psgs_w100.tsv
#id text    title
#1  "Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from"  Aaron
#2  "God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' ""prophet"" (Exodus 4:10-17; 7:1). At the command of Moses, he let"   Aaron

ヘッダ付きの TSV で列は id, text, title の順で用意すれば良さそうです。

Wiki40b から作った日本語の passage をこのフォーマットに変換しましょう。

GCS の認証を通して、生成済みの passage を取得します。

from google.colab import auth
auth.authenticate_user()
!gsutil cp gs://somewhere/FiD/contexts-*-*.pkl .

本記事の手順通りであれば GCS から取り出したファイル名は以下のようになっているはずです。

import pickle
contexts = []

fnames = []
head_and_tail = list(range(0, 828236, 64000)) + [828236]
for i in range(len(head_and_tail[:-1])):
  head = head_and_tail[i]
  tail =  head_and_tail[i+1]
  fnames.append("contexts-{}-{}.pkl".format(head, tail))
fnames
#['contexts-0-64000.pkl',
# 'contexts-64000-128000.pkl',
# 'contexts-128000-192000.pkl',
# 'contexts-192000-256000.pkl',
# 'contexts-256000-320000.pkl',
# 'contexts-320000-384000.pkl',
# 'contexts-384000-448000.pkl',
# 'contexts-448000-512000.pkl',
# 'contexts-512000-576000.pkl',
# 'contexts-576000-640000.pkl',
# 'contexts-640000-704000.pkl',
# 'contexts-704000-768000.pkl',
# 'contexts-768000-828236.pkl']

これをメモリにロードして結合します。

for fname in fnames:
  with open(fname, "rb") as f:
    subset = pickle.load(f)
    contexts.extend(subset)
len(contexts)
# 5561547

先程確認した TSV のフォーマットに変換して出力します。

  • じつは contexts_*_*.pkl の生成時に text が “” になるデータが少数出来てしまっていました。このまま後続の学習データの再編成をすると、やたらと text == “” の passage に食いつくようになってしまったので、この段階で text == “” の passage を除去することにしました。
lines = [["id", "text", "title"]]
for i, context in enumerate(contexts) :
  lines.append([str(i), context[1], context[0]])

with open("wiki40b_psgs_w100.tsv", "w") as f:
  f.write("\n".join(["\t".join(line) for line in lines if len(line[1]) > 0]))
!wc -l wiki40b_psgs_w100.tsv
# 5559174 wiki40b_psgs_w100.tsv

GCS に退避しておきましょう。

!gsutil cp wiki40b_psgs_w100.tsv gs://somewhere/FiD/

コード修正

Retriever を使った埋め込みでも一部コードを修正する必要がありました。

プレフィックスの書き換え

8 章で言及したプレフィックスの設定がここにもありました。

!cat ./FiD/generate_passage_embeddings.py | grep -e "title:" -e "context:"
#    dataset = src.data.TextDataset(passages, title_prefix='title:', passage_prefix='context:')

8 章の修正と整合性を持たせる形で書き換えます。

!sed -i -e 's/title:/タイトル:/' -e 's/context:/文脈:/' ./FiD/generate_passage_embeddings.py 
!cat ./FiD/generate_passage_embeddings.py | grep -e "dataset ="
#    dataset = src.data.TextDataset(passages, title_prefix='タイトル:', passage_prefix='文脈:')

トークナイザの変更

トークナイザも英語 BERT になっているので、

!cat ./FiD/generate_passage_embeddings.py | grep -e "tokenizer ="
#    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

これも書き換えます。

!sed -i "s/BertTokenizerFast.from_pretrained('bert-base-uncased')/BertJapaneseTokenizer.from_pretrained('cl-tohoku\/bert-base-japanese-whole-word-masking')/" ./FiD/generate_passage_embeddings.py
!cat ./FiD/generate_passage_embeddings.py | grep -e "tokenizer ="
#    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

TSV のクォーテーションの設定変更

passage の両端に “ が入っていると誤動作するので、設定を変更します。

!cat ./FiD/src/util.py | awk 'NR>=220 && NR<=220{print NR"|"$0}'
#220|        reader = csv.reader(fin, delimiter='\t')

とりあえず、クォーテーションを考慮せずにタブ文字だけ見て切り分けるようにしました。レアケースなのでこれで十分かと思います。

!sed -i '220s/)/, quoting=csv.QUOTE_NONE)/' ./FiD/src/util.py
!cat ./FiD/src/util.py | awk 'NR>=220 && NR<=220{print NR"|"$0}'
#220|        reader = csv.reader(fin, delimiter='\t', quoting=csv.QUOTE_NONE)

学習済みモデルの展開

Retriever の最終チェックポイントを GCS から持ってきて、latest と best_dev を復元しておきます。

!mkdir -p ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint
!gsutil cp -r gs://somewhere/FiD/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/
!ln -sf /content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/latest
!cp -a /content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/best_dev

学習済み Retriever を用いた passage の埋め込み

以下のようにして実行します。

!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  PER_GPU_BATCH_SIZE=128 && \
  P_MAXLEN=137 && \
  MODEL_DIR='/content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/best_dev' && \
  PASSAGES_TSV='wiki40b_psgs_w100_small.tsv' && \
  OUTPUT_PATH="wikipedia_embeddings" && \
  \
  date && \
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" &&\
  echo "P_MAXLEN=$P_MAXLEN" && \
  echo "MODEL_DIR=$MODEL_DIR" && \
  echo "PASSAGES_TSV=$PASSAGES_TSV" && \
  echo "OUTPUT_PATH=$OUTPUT_PATH" && \
  \
  python ./FiD/generate_passage_embeddings.py \
        --model_path $MODEL_DIR \
        --passages $PASSAGES_TSV \
        --output_path $OUTPUT_PATH \
        --passage_maxlength $P_MAXLEN \
        --shard_id 0 \
        --num_shards 1 \
        --per_gpu_batch_size $PER_GPU_BATCH_SIZE

wkipedia_embedding_00 という名前でファイルが生成されます。

!ls -lh wikipedia_embeddings*
-rw-r--r-- 1 root root 2.8G Jan 18 04:24 wikipedia_embeddings_00

GCS に退避しておきましょう。

!gsutil cp wikipedia_embeddings* gs://somewhere/FiD/

だいぶ長くなってきましたね。。。あとは作成した passage の埋め込みと Retriever を使って Reader の学習データを再編成するだけです。

12. 学習済み Retriever を使った Reader の学習データの再編成

ここからは学習済み Retriever を使った Reader の学習データの再編成していきます。

  1. Reader の学習データの question を Retriever で埋め込み表現にし、
  2. それを条件に前章で生成した passage 埋め込み表現を検索して、
  3. 類似度の高いものを 「 question の回答が含まれている可能性が高い passage 」として Reader の学習データを生成します。

本章も Colab の環境は “8. Reader の学習” 以降のセットアップとコード修正が適用済みとし、アクセラレータは GPU です。

セットアップ

FiD のコードは埋め込み表現の検索に FAISS を使うので必要なパッケージをインストールします。

!pip install faiss
!apt-get install libomp-dev

コードの修正

ここでも日本語がらみのコード修正が必要です。

トークナイザの変更

英語のモデルになっているので、

!cat ./FiD/passage_retrieval.py | grep "tokenizer ="
#     tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

日本語にします。

!sed -i "s/BertTokenizerFast.from_pretrained('bert-base-uncased')/BertJapaneseTokenizer.from_pretrained('cl-tohoku\/bert-base-japanese-whole-word-masking')/" ./FiD/passage_retrieval.py
!cat ./FiD/passage_retrieval.py | grep "tokenizer ="
#    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

正誤判定ロジックの変更

正誤判定というのは「 passage に answer が含まれているかどうか?」という判定です。

オリジナルの関数( has_answer() )は正規化したりトークナイズしたりと何やらいろいろやっていますが、 単純化して「 answer が title か text に含まれていれば OK」という処理( has_answer2() )にします。

!sed -i '113adef has_answer2(answers, title, text, tokenizer):\n\ \ \ \ return sum([answer in title or answer in text for answer in answers]) > 0\n' ./FiD/src/evaluation.py 
!cat ./FiD/src/evaluation.py | awk 'NR>=101 && NR<=115{print NR"|"$0}'
#101|def has_answer(answers, text, tokenizer) -> bool:
#102|    """Check if a document contains an answer string."""
#103|    text = _normalize(text)
#104|    text = tokenizer.tokenize(text, uncased=True)
#105|
#106|    for answer in answers:
#107|        answer = _normalize(answer)
#108|        answer = tokenizer.tokenize(answer, uncased=True)
#109|        for i in range(0, len(text) - len(answer) + 1):
#110|            if answer == text[i: i + len(answer)]:
#111|                return True
#112|    return False
#113|
#114|def has_answer2(answers, title, text, tokenizer):
#115|    return sum([answer in title or answer in text for answer in answers]) > 0

この関数の呼び出しもとを

!cat ./FiD/src/evaluation.py | awk 'NR>=97 && NR<=97{print NR"|"$0}'
#97|        hits.append(has_answer(answers, text, tokenizer))

書き換えて、

!sed -i '97s/has_answer(answers, text/has_answer2(answers, title, text/' ./FiD/src/evaluation.py
!cat ./FiD/src/evaluation.py | awk 'NR>=96 && NR<=98{print NR"|"$0}'
#97|        hits.append(has_answer2(answers, title, text, tokenizer))

よく考えたら、 has_answer2() は引数( title )が増えているので、そこも直さないと。。。

!cat ./FiD/src/evaluation.py | awk 'NR>=89 && NR<=97{print NR"|"$0}'
#89|    for i, doc in enumerate(ctxs):
#90|        text = doc['text']
#91|
#92|        if text is None:  # cannot find the document for some reason
#93|            logger.warning("no doc in db")
#94|            hits.append(False)
#95|            continue
#96|
#97|        hits.append(has_answer2(answers, title, text, tokenizer))

こんな感じにしました。

!sed -i '92s/text is/title is None or text is/' ./FiD/src/evaluation.py 
!sed -i "89a\ \ \ \ \ \ \ \ title = doc['title']" ./FiD/src/evaluation.py
!cat ./FiD/src/evaluation.py | awk 'NR>=89 && NR<=98{print NR"|"$0}'
#89|    for i, doc in enumerate(ctxs):
#90|        title = doc['title']
#91|        text = doc['text']
#92|
#93|        if title is None or text is None:  # cannot find the document for some reason
#94|            logger.warning("no doc in db")
#95|            hits.append(False)
#96|            continue
#97|
#98|        hits.append(has_answer2(answers, title, text, tokenizer))

GCS から必要なデータを取得します。

from google.colab import auth
auth.authenticate_user()

!gsutil cp gs://somewhere/FiD/fid_*.jsonl .
!gsutil cp gs://somewhere/FiD/wiki40b_psgs_w100.tsv .
!gsutil cp gs://somewhere/FiD/wikipedia_embeddings_* .

!mkdir -p ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint
!gsutil cp -r gs://somewhere/FiD/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/
!ln -sf /content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/latest
!cp -a /content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400 ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/best_dev

それでは、以下のようにして実行します。 DATA_PATH, OUTPUT_PATH を修正して、検証データ、テストデータにも同様の処理をします。

!export PYTHONPATH=${PYTHONPATH}:.:./FiD && \
  PER_GPU_BATCH_SIZE=64 && \
  Q_MAXLEN=54 && \
  MODEL_DIR='/content/retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/best_dev' && \
  PASSAGES_TSV='wiki40b_psgs_w100.tsv' && \
  DATA_PATH='./fid_train.jsonl' && \
  EMBEDDINGS="wikipedia_embeddings_*" && \
  OUTPUT_PATH="retrieved_train_data.json" && \
  N_DOCS=100 && \
  \
  date && \
  echo "PER_GPU_BATCH_SIZE=$PER_GPU_BATCH_SIZE" &&\
  echo "Q_MAXLEN=$Q_MAXLEN" && \
  echo "MODEL_DIR=$MODEL_DIR" && \
  echo "PASSAGES_TSV=$PASSAGES_TSV" && \
  echo "DATA_PATH=$DATA_PATH" && \
  echo "EMBEDDINGS=$EMBEDDINGS" && \
  echo "OUTPUT_PATH=$OUTPUT_PATH" && \
  echo "N_DOCS=$N_DOCS" && \
  \
  python ./FiD/passage_retrieval.py \
    --model_path $MODEL_DIR \
    --passages $PASSAGES_TSV \
    --data $DATA_PATH \
    --passages_embeddings $EMBEDDINGS \
    --output_path $OUTPUT_PATH \
    --question_maxlength $Q_MAXLEN \
    --n-docs $N_DOCS 
# ...
#[02/04/2022 10:08:45] {evaluation.py:63} INFO - Matching answers in top docs...
#[02/04/2022 10:08:46] {evaluation.py:71} INFO - Per question validation results len=100
#[02/09/2022 05:11:58] {evaluation.py:71} INFO - Per question validation results len=13061
#[02/09/2022 05:11:58] {passage_retrieval.py:88} INFO - Validation results: top k documents hits [7457, 8741, 9290, 9644, ...]
#[02/09/2022 05:11:58] {passage_retrieval.py:90} INFO - Validation results: top k documents hits accuracy [0.5709363754689534, 0.6692443151366664, 0.7112778500880483, 0.738381440931016, ...]

実行ログの最後の2行は 100 の数字が出てきます。最後から 2 行目の "top k documents hits” の 7457 は検索結果の top-1 に answer が含まれていたサンプル数、最後の行の 0.5709363754689534 はその割合ですね。「検索結果 1 位の 57 % が回答を含んでいる」という意味になります。

GCS に生成したデータを退避します。

!gsutil cp retrieved_*_data.json gs://somewhere/FiD/

Cross-Attention の重みを蒸留した学習データができたので、 Reader の性能が上がるか試してみました。

13. Cross-Attention の重みを蒸留した学習データでの再学習

前章で作成したデータで再度 Reader を学習しました。

手順的には 8. Reader の学習 と同じなので割愛します。 先程の続きからではなく、Megagon Labs さんの事前学習済みモデルを起点に学習するので学習データのところだけ、以下のように書き替えます。

...
  TRAIN_DATA='./retrieved_train_data.json' && \
  EVAL_DATA='./retrieved_test_data.json' && \
...

学習後にテストデータ( retrieved_test_data.json )で評価してみました。

[02/11/2022 16:04:46] {test_reader.py:73} WARNING - Process rank:0, total 997 | average = 0.625
[02/11/2022 16:04:46] {test_reader.py:131} INFO - EM 62.49, Total number of example 997

初回の結果( EM 66.70 )に届きませんね。。。再編成前のテストデータ( ‘fid_test.jsonl’ ) でも試してみました。

[02/11/2022 15:39:32] {test_reader.py:73} WARNING - Process rank:0, total 997 | average = 0.630
[02/11/2022 15:39:32] {test_reader.py:131} INFO - EM 62.99, Total number of example 997

プロットするとこんな感じです。色で学習データ、縞で評価データを示してます。

em_comparison

残念ながら、 Cross-Attention の重みを蒸留した学習データからは期待した効果が得られませんでした(橙の棒)。

最初の学習データと Cross-Attention の重みで再編成したデータを比べてみましょう。 前章の passage_retrieval.py が精度計算に使っている関数でスコアを出してプロットしてみました。

top_k_acc

青が最初の学習データ、橙 が Cross-Attention の重みで再編成したデータです。 左に寄ると正解が出現する順位が高く、上に寄ると正解が含まれる確率が高くなます。

retrieved(train) が突出して良いですが、このデータで学習したのですから、これは当然ですね。

validation, test も top-1~3 あたりは 橙 が若干左に来ています。これは Cross-Attention の重みを蒸留した効果だと思います。 ですが、top-100 での上下を確認すると 橙 は 青 に及ばない感じです。

train だけ比べれば橙は青より正解を含んだサンプルが多いわけですから Reader の精度は上がっても良いような気がするんですけどね。 ただ多いと言っても 1.5 % 程度なので、何度か実験して平均しないと何とも言えないところかもしれません。

Retriever について考えると top-100 における validation, test の上下関係は明確に青に劣っています。 Reader は、 top-100 に回答文字列が入っていないと苦しい訳ですから、Retriever としては Sentence BERT (青) が、 Cross-Attention の重みで蒸留したモデル(橙)より優れていると言えるでしょう。

ただ、前者は次元数が 768 、後者は 256 なのでその辺りの設定の差が出たかもしれませんね。

せっかく 256 次元の埋め込み表現を作ったので、推論してみましょう。ギリギリ Colab で動きました。

14. 推論の実行

本章も Colab の環境は “8. Reader の学習” 以降のセットアップとコード修正が適用済みとし、アクセラレータは GPU です。

まずは、ここまで作ってきたものを GCS から拾ってきて以下の状態にして下さい。

Reader と Retriever のチェックポイントは以下のような感じ。

!find . -name step-* 
# ./checkpoint/fid_jaqket_wiki40b/checkpoint/step-88320
# ./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400

Cross-Attention の重みで作った passage のテキストを埋め込み表現です。

!ls wiki*
# wiki40b_psgs_w100.tsv  wikipedia_embeddings_00

どうにもメモリが足らないので再起動しながらでないと動かせませんでした。。

問題を埋め込み表現に変換

クイズの問題はこんな感じです。最初のはテストセットの一行目、二つ目は適当に考えました。

questions = [
  '和名をハダカカメガイといい、実は巻き貝の一種とされている、その姿から「流氷の天使」と呼ばれる動物は何でしょう?',
  '鳥山明の漫画「ドラゴンボール」に登場する、ベジータとブルマの息子の超サイヤ人は誰でしょう?'           
]

以下のようにして Retriever で埋め込み表現にして保存します。

import transformers
import sys
sys.path.append("./FiD")
from src.model import Retriever
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import numpy as np

Q_MAXLEN=54

bert_tokenizer = transformers.BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
question = bert_tokenizer.batch_encode_plus(questions, pad_to_max_length=True, 
             return_tensors="pt", max_length=Q_MAXLEN,
             truncation=True)
question_ids = question['input_ids']
question_mask = question['attention_mask'].bool()

retriever = Retriever.from_pretrained("./retriever_checkpoint/fid_jaqket_wiki40b_ret/checkpoint/step-134400")
retriever.cuda()
retriever.eval()

questions_embedding = retriever.embed_text(
                text_ids=question_ids.to(device).view(-1, question_ids.size(-1)),
                text_mask=question_mask.to(device).view(-1, question_ids.size(-1)),
                apply_mask=retriever.config.apply_question_mask,
                extract_cls=retriever.config.extract_cls,)

np.save("questions.npy", questions)
np.save("questions_embedding.npy", questions_embedding.to('cpu').detach().numpy().copy())

ここで再起動です

問題の埋め込み表現で passage を検索

先程の埋め込み表現を使って、対応する passage のインデックスを FAISS で検索します。

import pickle
import sys
sys.path.append("./FiD")
from src.index import Indexer
from passage_retrieval import add_embeddings
import numpy as np

N_DOCS = 100

index = Indexer(256, 0, 8)

with open("./wikipedia_embeddings_00", "rb") as f:
  ids, embeddings = pickle.load(f)

indexing_batch_size = 50000
while embeddings.shape[0] > indexing_batch_size:
  embeddings,ids = add_embeddings(index, embeddings, ids, indexing_batch_size)
add_embeddings(index, embeddings, ids, indexing_batch_size)

questions_embedding = np.load("questions_embedding.npy")

top_ids_and_scores = index.search_knn(questions_embedding, N_DOCS)

with open("top_ids_and_scores.pkl", "wb") as f:
  pickle.dump(top_ids_and_scores, f)

またまた、再起動です。

検索したインデックスから passage を取得し、Reader で推論する

先程検索したインデックスで passage を拾って Reader で推論します。

import transformers
import sys
sys.path.append("./FiD")
from src.model import FiDT5
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import pickle
import json
import numpy as np

with open("top_ids_and_scores.pkl", "rb") as f:
  top_ids_and_scores = pickle.load(f)

with open("./wiki40b_psgs_w100.tsv") as f:
  lines = f.readlines()
  lines = [line.strip() for line in lines]
  lines = [line.split("\t") for line in lines]
  passage_map = {line[0]: (line[2], line[1]) for line in lines}

def get_passages(passage_map, top_ids):
  passages = []
  for top_id in top_ids:
    passages.append(passage_map[top_id])
  return passages

all_passages = []
for i in range(len(top_ids_and_scores)):
  all_passages.append(get_passages(passage_map, top_ids_and_scores[i][0]))

t5_tokenizer = transformers.T5Tokenizer.from_pretrained('megagonlabs/t5-base-japanese-web', return_dict=False)
reader = FiDT5.from_pretrained("./checkpoint/fid_jaqket_wiki40b/checkpoint/step-88320")
reader.cuda()
reader.eval()

questions = np.load("questions.npy")

MAX_SEQ_LEN=200
QUESTION_PREFIX = "問題:"
TITLE_PREFIX = "タイトル:"
TEXT_PREFIX = "文脈:"

def tokenize_function(questions, all_passages, tokenizer):
  questions = [QUESTION_PREFIX + " " + question for question in questions]
  passages_of_examples =  [[questions[i] + " " + TITLE_PREFIX + " " + p[1] + " " + TEXT_PREFIX + " " + p[0] for p in passages] for i, passages in enumerate(all_passages)]
  input_ids = []
  attention_mask = []
  for passages in passages_of_examples:
    inputs = tokenizer.batch_encode_plus(passages, max_length=MAX_SEQ_LEN, pad_to_max_length=True, truncation=True)
    input_ids.append(inputs["input_ids"])
    attention_mask.append(inputs["attention_mask"])
  input_ids = np.array(input_ids)
  attention_mask = np.array(attention_mask)
  return {"input_ids": input_ids, "attention_mask": attention_mask}

inputs = tokenize_function(questions, all_passages, t5_tokenizer)

outputs = reader.generate(input_ids=torch.tensor(inputs["input_ids"]).cuda(),
                attention_mask=torch.tensor(inputs["attention_mask"]).cuda(), max_length=50)
answers = [t5_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

for q, a in zip(questions, answers):
  print("Q:{} \n  => A:{}".format(q, a))
# Q:和名をハダカカメガイといい、実は巻き貝の一種とされている、その姿から「流氷の天使」と呼ばれる動物は何でしょう? 
#   => A:クリオネ
# Q:鳥山明の漫画「ドラゴンボール」に登場する、ベジータとブルマの息子の超サイヤ人は誰でしょう? 
#   => A:トランクス

おぉ、両方とも正解しました!

15. おわりに

今回は、 FiD を使ってクイズに答えるモデルに再挑戦して正答率は 2 倍以上になりました。テキスト生成モデルですから、 大量の文書を読んで判定したり、対話したり色々と応用できるかもしれませんね。 モデルとしては T5 に reshape の処理が入っただけですから自前で再実装するのも難しくなさそうです。

次回は Prompt Tuning の話をしようかと思いはじめました。 モデル丸ごとをファインチューン出来ない規模のモデルを扱う人のための話かと思ってたんですが、 モデルを固定してプロンプトで多様なタスクに対応できれば推論環境のランニングコストを低減できるかなと思ったり。 あと、T5X20 を触ってみたいというのもありますね。 でも base のサイズ感だと精度的には少し落ちる感じになるみたいなので、どうしようかな。。。


  1. https://arxiv.org/abs/2007.01282 

  2. 外部情報を差し替えて最新の状況に対応したり、間接的に出力テキストを制御したりできたら便利なんじゃないかと。 

  3. https://arxiv.org/abs/2112.04426 

  4. 他にも RAG (https://arxiv.org/abs/2005.11401) とか EMDR (https://arxiv.org/abs/2106.05346) とかあるみたいですが、一番手軽に試せそうだったんでコレにしました。。 

  5. https://github.com/facebookresearch/FiD 

  6. https://huggingface.co/megagonlabs/t5-base-japanese-web 

  7. https://arxiv.org/abs/2012.04584 

  8. REALM, RAG, EMDR といったところですね。理にかなってはいるのですが、計算に必要なメモリとかパワーとかがキツくなるんですよね。。。 

  9. 多分あってると思います。違ってたらゴメンナサイ。 

  10. https://hironsan.hatenablog.com/entry/how-to-use-wiki40b 

  11. 作ってない or 消しちゃった人(ほとんどの方でしょうが)は第18回の記事を参考に作って下さい。スミマセン。。。ちなみに私は社内のサーバを使って 2 GPU でバッチサイズ 96、F1 = 24.35, AP = 13.58 のモデルを使いました。 

  12. 「encode_contexts()と共通化してバッチ組むとこだけ引数に関数を差し込んだら?」と言われてしまいそう。。。うーん、そのとおりですね。 

  13. https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md#rules-of-thumb 

  14. 1 件目でない理由は 2 件目の方が抽出された passage がそれっぽかったからです。また ctxs は先頭 5 件だけ抜き出してます。 

  15. ライセンスは “Attribution-NonCommercial 4.0 International” なので気を付けてください。 

  16. https://pytorch.org/docs/stable/checkpoint.html ちなみに FiD のコードは自前で checkpointing の実装をしていますが、transformers の 4.14.1 の T5 には同様の機能が組み込まれています。 https://github.com/huggingface/transformers/issues/6564 

  17. transformers の 4.14.1 だと generate()output_attentions=Truereturn_dict_in_generate=True を渡せば推論結果の dict から cross_attentions をキーとして取得することができます。 3.0.2 当時はこの機能なかったんでしょうか。。。すいません。そこまで調べてません。正直、 encoder 前後の reshape 以外は 4.14.1 の標準機能で対応できそうなので全部自前で作ったほうが早かったかもしれません。。。 

  18. https://github.com/huggingface/transformers/pull/8518 

  19. https://github.com/huggingface/transformers/commit/42e2d02e44e1d10d7863eb18c5db581f049f7cb2 

  20. https://github.com/google-research/t5x