オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第13回 ORQA でクイズに答えるモデルを作る
オージス総研 技術部 アドバンストテクノロジセンター
鵜野 和也
2021年2月18日

今回は BERT の応用的な使い方として、Google の ORQA を用いてクイズに答えるモデルを作ります。 ORQA は内部的には BERT を3つ使うので、前回 ELECTRA を用いて生成した事前学習済みモデルを起点にします。さて、どの程度の正答率が出せるでしょうか?

1. はじめに

BERT に関しては 前回 までで、文章分類したり、固有表現抽出をしたりと単体で利用するパターンはある程度こなしてきたので、今回は BERT の応用的な使い方として複数の BERT を組み合わせて利用し、クイズに答えるモデルを作ってみましょう。手法としては Google の ORQA 1を用います。ORQA は Open Domain Question Answering タスクを解くモデルです。

ここで Open Domain Question Answering について説明しておきます。自然言語処理のタスクで Question Answering と言えば、学習データとして“質問"と"文献"、"解答"が与えられ、"文献"の中から"質問"に対応する"解答"を抽出するタスクになります。 SQuAD 2 が有名ですね。

これに対し Open Domain Question Answering(以下、OpenQA) は学習データとしては"質問"と"解答” のみ が与えられ、Question Answering の“文献"にあたる情報を Wikipedia 等の外部知識を参照して自力で拾い出し解答するタスクになります。

それではまず、 ORQA について簡単に説明します。

2. ORQA

ORQA(Open Retrieval Question Answering) は 2019 年に Google が発表した OpenQA タスクを解くモデルです。ORQA も前述のように外部知識から適切な文献を検索する部分("retriever”)と、検索した文献と質問から解答を抽出する部分(“reader”)で構成されます。 ORQA の特徴は以下の二つになります。

  • 質問 - 解答ペアのみで retriever と reader の両方を学習する。
  • 上記を可能とする為、retriever に対して Inverse Cloze Task と呼ばれる事前学習を導入した。

順番に少し詳しく見ていきましょう。

ORQA の概要

ORQA の全体像は以下の通りです。左側(BERTQ、BERTB)が retriever 、右側(BERTR)が reader に相当します。

orqa

BERT が Q、B、R の3つ登場していますが、それぞれ以下のような役目です。

  • BERTQ : 質問文を埋め込み表現にする。
  • BERTB : 検索対象となる文書を埋め込み表現にする。Inverse Cloze Task の後はパラメータ固定。
  • BERTR : 検索された文書と質問文から解答を抽出する。

先ほど、「質問 - 解答ペアのみで retriever と reader の両方を学習する」と説明しましたが、Inverse Cloze Task による事前学習で retriever がある程度の性能を出せるようになった後は BERTB のパラメータは固定され、BERTQ、BERTR のパラメータのみ更新されます。

BERTB のパラメータが更新されてしまうと検索対象の文書全件を埋め込み直さないといけないので、これは当然というか仕方ないというか。その代わり BERTQ 側を更新して、質問 - 解答ペアに合わせて寄せていく感じなのでしょう。学習は以下のような処理になります。

事前処理

まず全ての検索対象文書を BERTB で埋め込み表現にしておきます。

retriever での処理

  1. 質問文を BERTQ で埋め込み表現にする。
  2. 1. の埋め込み表現を検索対象文書の埋め込み表現と突き合わせて類似度の高い Top-K を検索する。

数式で書くと以下のとおりです、BERTQ、BERTR の埋め込み表現として “[CLS]” トークンに対応する出力を取ります。それに重みを掛け両者の内積を取って、retriever のスコア Sretr(b,q) とします。

orqa

reader での処理

  1. 質問文と retriever が検索した Top-K の各文書を “[CLS]質問文[SEP]検索された文書[SEP]"の形にして BERTR に投入。
  2. 解答の候補とするスパン s のスコア Sread(b,s,q) は、s の先頭と末尾の位置(START(s), END(s))の出力を連結し、MLPに通してスカラー値にします。

orqa

質問が "What does the zip in zip code stand for ?” で、検索された候補文書が “… the term ‘ZIP’ is an acronym for Zone Improvement Plan …” であり、解答候補のスパンが “Zone Improvement Plan” の場合を図にするとこんな感じになります。

orqa

最終的なスコア S(b,s,q) は Sretr(b,q) と Sread(b,s,q) の和であり、S(b,s,q) が最大となる b, s に対応するテキストが解答となります。

orqa

目的関数

q に対する b, s の確率は前述の S(b,s,q) を用いて以下のようになります。論文を読んでいて分母にある “s∈b” つまり「 TOP(k) の文書に含まれる全ての候補スパン」のところで、「どこから出てきたんだ?」と混乱しましたが、公開されたソースコードを読んでみると候補スパンの最大長を決めておいて総当たりで候補スパンの集合を生成していました3

orqa

損失は TOP(k) の文書に含まれる解答に合致する候補スパン(s∈b, a=TEXT(s))の負の対数尤度になります。

orqa

また retriever の検索結果に対し、より広範な範囲に計算負荷の軽い損失を適用します。Lfullは実験では上位5件(k=5)を用いていましたが、こちらは上位5000件(c=5000) での計算になります。Lfull は retriever の検索結果上位5件の文書に正解が含まれていないと学習しようがありません。そこで上位5000件に対して、とりあえず文書に正解が含まれていれば、retriever が更新されるロスを導入しています。

orqa

最終的な目的関数は上記二つの和になります。

orqa

BERT による計算は k = 5 だとすると、 BERTQ による質問文の埋め込みで1回、 BERTR による質問文と検索された文書のトークン単位の埋め込みで5回と計6回必要になります(最初の BERTQ の1回は系列長が短い分軽いですが)。

続いて事前学習である Inverse Cloze Task を見ていきましょう。

Inverse Cloze Task

さて、学習開始直後は Wq, Wb は初期化したばかり、 BERTQ, BERTB も BERT の事前学習済みモデルのママなので、質問文による検索結果はほぼランダムになります。

ですが、前述した Lfull にしろ Learly にしろ、検索結果として解答を含む文書がある程度上位にこないと、まったく学習が進みません。この状況を打開するために Inverse Cloze Task(以下、ICT) という教師なし事前学習が導入されています。

以下は ICT のイメージです。ICT は文書から1文を抽出し、これを仮想の質問文(q)とします(下図青部分)。1文を抜いた残りを連結し仮想のコンテキスト(b)とします(下図オレンジ部分)。この仮想のコンテキストを文書集合からサンプリングした他の文書と混ぜ、仮想の質問文 q で検索します。これで Sretr(b,q) が高くなるように学習する訳です。

orqa

数式で書くと以下のようになります。

orqa

ICT では文書から q を抜くことで単語の一致に頼らない検索ができるようにしています。とはいえ、単語が一致していればそれは強力な手がかりでもあるので、ICT では仮想コンテキスト b を生成する際に q を抜く確率を 90% とし、残り 10% は原文ママとすることで、単語の合致にも反応できるような工夫がされています。

それではここからは実際にコードを動かしてみましょう。まずは ICT です。

※ いつものように Colab で動かす体で記述してますが、ORQA の学習は GPU メモリの都合で Colab では動かせませんでした。 ICT のほうは Colab の TPU で動かせたんですけどね。ですので、試す場合は GCP なり AWS なりで GPU インスタンスを立てて動かして下さい。

3. ICT による事前学習

それでは ICT による事前学習からはじめていきましょう。いつものように、記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定にしています。

事前学習データの準備

まずは、事前学習データの準備です。日本語Wikipedia のダンプを加工していくだけですので、とりあえずは Colab のアクセラレータは None のままでかまいません。

加工したデータは GCS に書き込むので認証を通しておきましょう。

from google.colab import auth
auth.authenticate_user()

ORQA のソースコードですが、専用のリポジトリがあるわけではなく Google AI Language チームの共有リポジトリの 1 ディレクトリになります(なのでしばらく存在に気付いてませんでした。。。)。また ORQA のコードは Tensorflow の特定バージョンへの依存がシビアなようなので、今回は珍しく commit を固定しています。

!git clone https://github.com/google-research/language
!cd language && git checkout e3a0875f0bedb6e6
# ...
# HEAD is now at e3a0875 Merge of PR #71

BERT への依存があるのでそちらも取得しておきます。

!git clone https://github.com/google-research/bert
!cd bert && git checkout eedf5716ce12
# ...
# HEAD is now at eedf571 Merge pull request #1027 from iuliaturc-google/master

それと Tensorflow Text です。

!pip install tensorflow_text

ICT では文章から1文を切り出して仮想の質問とする処理があるので、日本語の文章を文に区切るセパレータが必要です。今回は konoha を使いました。

!pip install 'konoha[sentence]'

前回の ELECTRA 由来の事前学習済みモデルを使うので、 transformers と MeCab 関係も必要です。

!pip install transformers
!apt-get install mecab mecab-ipadic-utf8
!pip install mecab-python3==0.996.5
!pip install fugashi ipadic

ORQA のコードは Tensorflow 2.x を前提とするので、以下のマジックコマンドを実行しておきます。

%tensorflow_version 2.x 

前処理

日本語 Wikipedia のダンプを取得して、 wikiextractor でテキストを抽出します。 wikiestractor は少し前のバージョンを使ってますね。。。たしか作業したときはバージョンアップの直後か何かで必要なスクリプトが見つからなかったとか、そんな理由だったような気がします。

あと、この検証をしたときは 2020/8/20 のダンプを使ったのですが、今見たら、もうなくなっていたので 2021/1/20 の URL を記載しておきます。 (ちゃんと動作確認してませんがたぶん動くんではないかと。。。)

!wget https://dumps.wikimedia.org/jawiki/20210120/jawiki-20210120-pages-articles.xml.bz2
#!wget https://dumps.wikimedia.org/jawiki/20200820/jawiki-20200820-pages-articles.xml.bz2
!git clone https://github.com/attardi/wikiextractor
!cd wikiextractor/ && git checkout e4abb4cbd019b0257824 && git log 
# ...
# HEAD is now at e4abb4c Fix typo

以下のようにして抽出します。

! INPUT_PATH=$(pwd)/jawiki-20210120-pages-articles.xml.bz2 && \
  OUTPUT_PATH=$(pwd)/jawiki-20210120 && \
  python -m wikiextractor.WikiExtractor \
    -o $OUTPUT_PATH \
    --json \
    --filter_disambig_pages \
    --quiet \
    --processes 2 \
    $INPUT_PATH

language と bert をモジュール検索パスに追加します。

import sys
sys.path.append("./language")
sys.path.append("./bert")

ORQA のソースでは曖昧さ回避ページや一覧ページなどは除外対象になっているのですが、当然ながら英語向けなので差し替えておきます。

import re
def remove_doc(title):
  return re.match(r"(.*一覧$)|"
                  r"(.*\(曖昧さ回避\).*)", title)

from language.orqa.preprocessing import wiki_preprocessor
wiki_preprocessor.remove_doc = remove_doc

文のスプリッタも差し替えです。

from konoha import SentenceTokenizer
sentence_splitter = SentenceTokenizer()

トークナイザは transformers の日本語 BERT から借用しまして、

from transformers import BertJapaneseTokenizer
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

パラメータを格納する FLAGS 変数をでっちあげます。

from collections import namedtuple
FLAGS = namedtuple("FLAGS", ["bert_hub_module_path", "bert_hub_module_handle", "max_block_length", "input_pattern" , "output_dir", "num_threads"])

FLAGS.bert_hub_module_path = "DUMMY"
FLAGS.bert_hub_module_handle = "DUMMY"
FLAGS.max_block_length = 288
FLAGS.input_pattern = "./jawiki-20200820/**/wiki*"
FLAGS.output_dir    = "./orqa_db"
FLAGS.num_threads = 1

ようやく準備が整ったので処理を実行します。

import functools
import random
import os
import tensorflow.compat.v1 as tf

preprocessor = wiki_preprocessor.Preprocessor(sentence_splitter,
                                                FLAGS.max_block_length,
                                                tokenizer)

from language.orqa.preprocessing.preprocess_wiki_extractor import create_block_info
block_count = 0
input_paths = tf.io.gfile.glob(FLAGS.input_pattern)
random.shuffle(input_paths)

tf.logging.info("Processing %d input files.", len(input_paths))

tf.io.gfile.makedirs(FLAGS.output_dir)
blocks_path = os.path.join(FLAGS.output_dir, "blocks.tfr")
examples_path = os.path.join(FLAGS.output_dir, "examples.tfr")
titles_path = os.path.join(FLAGS.output_dir, "titles.tfr")

with tf.python_io.TFRecordWriter(blocks_path) as blocks_writer:
  with tf.python_io.TFRecordWriter(examples_path) as examples_writer:
    with tf.python_io.TFRecordWriter(titles_path) as titles_writer:
      for input_path in input_paths:
        for block_info in create_block_info(input_path, preprocessor):
          title = block_info[0]
          block = block_info[1] 
          examples = block_info[2] 
          blocks_writer.write(block.encode("utf-8"))
          examples_writer.write(examples)
          titles_writer.write(title.encode("utf-8"))
          block_count += 1
          if block_count % 10000 == 0:
            tf.logging.info("Wrote %d blocks.", block_count)
tf.logging.info("Wrote %d blocks in total.", block_count)
# INFO:tensorflow:Max block length 288
# INFO:tensorflow:Processing 1000 input files.
# INFO:tensorflow:Wrote 10000 blocks.
# INFO:tensorflow:Wrote 20000 blocks.
...

完了したら出来上がったものを GCS に保存しておきましょう。

!gsutil cp orqa_db/* gs://somewhere/orqa/wikipedia/

生成されたファイルのサイズ感はこんな感じです。

!ls -lh orqa_db
# total 4.3G
# -rw-r--r-- 1 root root 2.9G Aug 28 09:55 blocks.tfr
# -rw-r--r-- 1 root root 1.4G Aug 28 09:55 examples.tfr
# -rw-r--r-- 1 root root 117M Aug 28 09:55 titles.tfr

titles はともかく blocks と examples が謎ですね。 block は Wikipedia の記事を文に分割して “ ” で連結したものです。ただし記事が長い場合、タイトルのトークン数、ブロックを構成する各文のトークン数 + 3(“[CLS]”, “[SEP]”, “[SEP]”) が max_block_length を超えないように記事が分割されます。この辺りのソースを見てもらえるとわかりやすいでしょう4

最後に生成した block の数を数えておきましょう。ランタイムを再起動して、以下のように実行して下さい。

%tensorflow_version 2.x 

import tensorflow.compat.v1 as tf
tf.enable_eager_execution()

dataset = tf.data.TFRecordDataset("./orqa_db/examples.tfr")

num_examples = 0
for example in dataset:
    num_examples += 1
print(num_examples)
# 3185632

3185632 件になっていますが、これは筆者が 2020/8/20 版のダンプで実行した際のレコード数す。なので、上述の手順どおりに 2021/1/20 版を使うと値は違ってくる筈です。この件数を後々使うので控えておいてください。

それでは ICT を実行していきます。

ICT の実行

セットアップ

新しいノートブックを開き、"ランタイム“ -> “ランタイムタイプの変更” で “ハードウェアアクセラレータ” を “TPU” にしてください。 再びセットアップしていきます。まずは GCS の認証を通しておきましょう。

  • 途中で何度か "You must restart the runtime in order to use newly installed versions.” のメッセージがでますが後でまとめて再起動するので、無視して進めて下さい。
from google.colab import auth
auth.authenticate_user()

ORQA のコードを取得して、

!git clone https://github.com/google-research/language
!cd language && git checkout e3a0875f0bedb6e6
# ...
# HEAD is now at e3a0875 Merge of PR #71

次に BERT です。

!git clone https://github.com/google-research/bert
!cd bert && git checkout eedf5716ce12
# ...
# HEAD is now at eedf571 Merge pull request #1027 from iuliaturc-google/master

前回に ELECTRA で作成した事前学習済みモデルを使うので、学習時のオプティマイザも合わせた方がいいかな?と思ったのですが、インポートしてみると以下のようなエラーになってしまいました。

  • AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'Optimizer'
  • AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'get_or_create_global_step'

なので、安直ですが少々書き換えてしまいます。

!cp ./bert/optimization.py optimization.py.org
!cat optimization.py.org | sed -e 's/tf.train.Optimizer/tf.compat.v1.train.Optimizer/' \
                               -e 's/tf.train.get_or_create_global_step/tf.compat.v1.train.get_or_create_global_step/' > optimization.py
!diff optimization.py.org optimization.py
# 27c27
# <   global_step = tf.train.get_or_create_global_step()
# ---
# >   global_step = tf.compat.v1.train.get_or_create_global_step()
# 87c87
# < class AdamWeightDecayOptimizer(tf.train.Optimizer):
# ---
# > class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
!mv optimization.py ./bert/optimization.py 

続いて MIPS ライブラリの ScaNN をインストールします。

MIPS は Maximum Inner Product Search の略で検索対象のベクトル群からクエリベクトルとの内積が高いベクトルを検索する処理です。ScaNN は Google の開発した MIPS アルゴリズム5/ライブラリになります。長くなるので ScaNN については説明しませんが、また機会があれば紹介したいと思います。

実はもう Github に独立したリポジトリ6が出来ていて pip install scann でインストールできるようなってるらしいんですが、私が検証した時は以下のような手順で行いました。

まず、コンパイル済みのバイナリをインストールします。このバイナリが Tensorflow の 2.1.3 に依存していて、この時点で Tensorflow が 2.1.3 に差し変わります。

!wget https://storage.googleapis.com/scann/releases/1.0.0/scann-1.0.0-cp36-cp36m-linux_x86_64.whl
!pip install scann-1.0.0-cp36-cp36m-linux_x86_64.whl

さらに以下をインストールし、

!add-apt-repository -y ppa:ubuntu-toolchain-r/test
!apt-get install -y g++-9-multilib

ScaNN のソースコードも取ってきます。

%%bash
git init
git config core.sparsecheckout true
git remote add origin https://github.com/google-research/google-research
echo scann > .git/info/sparse-checkout
git pull origin master

git checkout 0681f5d5c6e5fad5
# ...
# HEAD is now at 0681f5d5 streaming residual connection

念のためテストを実行してみます。とりあえず大丈夫そうです。

!python scann/scann/scann_ops/py/scann_ops_pybind_test.py
# ...
# Ran 11 tests in 5.252s
#
# OK

Tensorflow Text は 2.1.1 を使いました。

!pip install tensorflow-text==2.1.1

次にトークナイザを流用した関係で Transformers と MeCab を入れていきます。

!pip install transformers==3.0.2
!apt-get install -y mecab mecab-ipadic-utf8
!pip install mecab-python3==0.996.5

最後に electra です。

!git clone https://github.com/google-research/electra
!cd electra && git checkout 79111328070e
# ...
# HEAD is now at 7911132 Merge pull request #47 from stefan-it/add-keep-checkpoint-max-parameter

ただし、 electra のコードは Tensorflow 1.x が前提なのでそのままでは 2.x で動かせません。少し修正します。 修正済みのファイルをどこからともなく取ってきて、オリジナルと差し替えます。

!gsutil cp gs://somewhere/electra/tf_2.0/*.py .
!cp electra/model/modeling.py modeling.py.org
!cp electra/model/optimization.py optimization.py.org
!cp modeling.py electra/model/
!cp optimization.py electra/model/

修正内容は以下の通りです。ようは Tensorflow 2.x で tf.contrib が消滅してしまって Layer Normalization はどうしてくれるの?という話です。仕方ないので、tf.keras から持ってきました。ただ良くわからない副作用があって、スコープの添え字に “_1” がついてしまうので、 その対応も入れてます。

!diff electra/model/modeling.py modeling.py.org

33c33
< #from tensorflow.contrib import layers as contrib_layers
---
> from tensorflow.contrib import layers as contrib_layers
351d350
< 
353,366c352,357
<     (ckpt_name, ckpt_var) = (x[0], x[1])
<     if prefix + ckpt_name in name_to_variable:
<       # Workaround for init_from_checkpoint in 2.0
<       assignment_map[ckpt_name] = name_to_variable[prefix+ ckpt_name]
<       initialized_variable_names[ckpt_name] = 1
<       initialized_variable_names[ckpt_name + ":0"] = 1
<     # Workaround for extra scope generation("embeddings_1") 
<     #   caused by tf.keras.layers.LayerNormalization (?)
<     ckpt_name_emb1 = ckpt_name.replace("embeddings", "embeddings_1")
<     if prefix + ckpt_name_emb1 in name_to_variable and ckpt_name not in initialized_variable_names:
<       # Workaround for init_from_checkpoint in 2.0
<       assignment_map[ckpt_name] = name_to_variable[prefix+ ckpt_name_emb1]
<       initialized_variable_names[ckpt_name] = 1
<       initialized_variable_names[ckpt_name + ":0"] = 1
---
>     (name, var) = (x[0], x[1])
>     if prefix + name not in name_to_variable:
>       continue
>     assignment_map[name] = prefix + name
>     initialized_variable_names[name] = 1
>     initialized_variable_names[name + ":0"] = 1
389c380
< def layer_norm(input_tensor, name="LayerNorm"):
---
> def layer_norm(input_tensor, name=None):
391,393c382,383
<   #return contrib_layers.layer_norm(
<   #    inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
<   return tf.keras.layers.LayerNormalization(name=name,axis=-1,epsilon=1e-12,dtype=tf.float32)(input_tensor)
---
>   return contrib_layers.layer_norm(
>       inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
396c386
< def layer_norm_and_dropout(input_tensor, dropout_prob, name="LayerNorm"):
---
> def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):

こっちも “_1” 対応を一応入れてます。

!diff electra/model/optimization.py optimization.py.org
# 185d184
# <       "/embeddings_1/": 0,

ようやくセットアップが終わったのでランタイムを再起動してください。

学習の実行

まずは各種ライブラリをインポートします。

%tensorflow_version 2.x 

from google.colab import auth
auth.authenticate_user()

import os
import sys
sys.path.append("./language")
sys.path.append("./electra")
sys.path.append("./bert")

import tensorflow.compat.v1 as tf

import functools
from absl import app
from absl import flags
from language.common.utils import experiment_utils
from language.orqa.models import ict_model

tf.disable_v2_behavior()

パラメータの定義です。

flags.DEFINE_string("bert_hub_module_path",
                    "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
                    "Path to the BERT TF-Hub module.")
flags.DEFINE_integer("query_seq_len", 64, "Query sequence length.")
flags.DEFINE_integer("block_seq_len", 288, "Document sequence length.")
flags.DEFINE_integer("projection_size", 128, "Projection size.")
flags.DEFINE_float("learning_rate", 1e-4, "Learning rate.")
flags.DEFINE_integer("num_block_records", 13353718, "Number of block records.")
flags.DEFINE_string("examples_path", None, "Input examples path")
flags.DEFINE_integer("num_input_threads", 12, "Num threads for input reading.")
flags.DEFINE_float("mask_rate", 0.9, "Mask rate.")

ELECTRA の事前学習済みモデルや設定ファイルのパラメータです。iterations_per_loop はチェックポイントの間隔と別に設定したかったので追加しています。

flags.DEFINE_string("init_check_point", None, "")
flags.DEFINE_string("bert_config", None, "")
flags.DEFINE_integer("iterations_per_loop", 1000, "num iterations per loop.")

パラメータに値を設定します。Colab の Free TPU ではバッチサイズ 4096 はムリなので、仕方なく 256 にしています。num_block_records は数えたらこの値でした。num_train_steps は「日本語Wikipediaのサイズからしてこんなもんかな」くらいの話です。

num_block_records の値は前処理で確認した値に書き換えて下さい。

FLAGS = flags.FLAGS
FLAGS.mark_as_parsed()

FLAGS.model_dir="gs://somewhere/orqa/ict_model"
FLAGS.examples_path="gs://somewhere/orqa/wikipedia/examples.tfr"
FLAGS.num_block_records=3185632
FLAGS.save_checkpoints_steps=1000
FLAGS.batch_size=256 # NOTE: modified from 4096. 512 doesn't work.
FLAGS.num_train_steps=100000
FLAGS.use_tpu=True
FLAGS.bert_hub_module_path="DUMMY:NOT_USED"
FLAGS.tpu_name="DUMMY:Colab Free TPU"
FLAGS.init_check_point="gs://somewhere/electra/max_seq_length_512/models/electra_base_wiki_ja"
FLAGS.bert_config="gs://somewhere/electra/max_seq_length_512/config.json"
FLAGS.iterations_per_loop=200

Colab Free TPU のアドレスを確認して、

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

TPUClusterResolver に仕込みます。

cluster = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS, zone=None, project=None)
def ColabTPUClusterResolver(tpu_name, zone, project):
  print("ColabTPUClusterResolver returns %s" % (cluster))
  return cluster
tf.distribute.cluster_resolver.TPUClusterResolver = ColabTPUClusterResolver

ここからは model_fn の差し替えです。オリジナルのコードは Tensorflow Hub から華麗に学習済みモデルをダウンロードして使えていますが、前回に自前で作った checkpoint を使うので強引に書き換えていきます。

まずは、必要な関数をインポートして、

from electra.model import optimization as electra_optimization
from language.common.utils import tensor_utils
from language.common.utils import tpu_utils
from electra.model import modeling
import collections
import re
import json

次に transformer を構築する関数です。これはほぼ electra.run_pretraining.PretrainedingModel._build_transformer() をそのまま持ってきただけだったかと。

def build_transformer(inputs, is_training, use_tpu, bert_config, name="electra", reuse=False, **kwargs):
  """Build a transformer encoder network."""
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    return modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=inputs["input_ids"],
        input_mask=inputs["input_mask"],
        token_type_ids=inputs["segment_ids"],
        use_one_hot_embeddings=use_tpu,
        scope=name,
        **kwargs)

BERT の出力を埋め込み表現にする関数です。これは language.orqa.models.ict_model.module_fn() から切り出しました。

def get_projected_emb(bert, scope, params, is_training):
  with tf.variable_scope(scope):
    reprs = bert.get_pooled_output()
    projected_emb = tf.layers.dense(reprs, params["projection_size"], name="weight")
    projected_emb = tf.keras.layers.LayerNormalization(axis=-1)(projected_emb)
    if is_training:
      projected_emb = tf.nn.dropout(projected_emb, rate=0.1)
    return projected_emb

ELECTRA で作った checkpoint を読み込むためのアサイメントマップ生成関数です。electra.model.modeling.get_assignment_map_from_checkpoint() をベースに Tensorflow 2.x 向けの改造とスコープの “_1” への対応を入れてます。

def get_assignment_map_from_checkpoint(tvars, init_checkpoint, target, prefix="", pattern="^electra"):
  """Compute the union of the current variables and checkpoint variables."""
  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  initialized_variable_names = {}
  assignment_map = collections.OrderedDict()

  for x in tf.train.list_variables(init_checkpoint):
    (ckpt_name, ckpt_var) = (x[0], x[1])

    m = re.match(pattern, ckpt_name)
    if m is not None:
      # "electra/weight": "bert_q/weight"
      target_variable_name = re.sub(pattern, prefix + target, ckpt_name)
      if target_variable_name in name_to_variable:
        # Workaround for init_from_checkpoint in 2.0
        assignment_map[ckpt_name] = name_to_variable[target_variable_name]
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ":0"] = 1
      # Workaround for extra scope generation("embeddings_1")
      #   caused by tf.keras.layers.LayerNormalization (?)
      ckpt_name_emb1 = ckpt_name.replace("embeddings", "embeddings_1")
      target_variable_name = re.sub(pattern, prefix + target, ckpt_name_emb1)
      if target_variable_name in name_to_variable and ckpt_name not in initialized_variable_names:
        # Workaround for init_from_checkpoint in 2.0
        assignment_map[ckpt_name] = name_to_variable[target_variable_name]
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ":0"] = 1

  return assignment_map, initialized_variable_names

事前学習済み ELECTRA のパラメータを BERTQ, BERTB にロードする関数です。

def init_from_checkpoint(params):
  # Load pre-trained weights from checkpoint
  init_checkpoint = params["init_check_point"]
  tf.logging.info("Using checkpoint: %s", init_checkpoint)
  tvars = tf.trainable_variables()

  scaffold_fn = None
  if init_checkpoint:
    assignment_map_q, _ = get_assignment_map_from_checkpoint(
        tvars, init_checkpoint, target="bert_q")
    assignment_map_b, _ = get_assignment_map_from_checkpoint(
        tvars, init_checkpoint, target="bert_b")

    if params["use_tpu"]:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map_q)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map_b)
        return tf.train.Scaffold()
      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map_q)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map_b)
  return scaffold_fn

これで部品がそろったので、language.orqa.models.ict_model.module_fn() に組み込んでます。

def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # [local_batch_size, block_seq_len]
  block_ids = features["block_ids"]
  block_mask = features["block_mask"]
  block_segment_ids = features["block_segment_ids"]

  # [local_batch_size, query_seq_len]
  query_ids = features["query_ids"]
  query_mask = features["query_mask"]

  local_batch_size = tensor_utils.shape(block_ids, 0)
  tf.logging.info("Model batch size: %d", local_batch_size)

  is_training = mode == tf.estimator.ModeKeys.TRAIN 

  # Load config.json
  with tf.io.gfile.GFile(params["bert_config"], "r") as f:
    config_json = json.load(f)
    bert_config = modeling.BertConfig.from_dict(config_json)

  # Build BERT_Q
  bert_q = build_transformer(
    dict(
      input_ids=query_ids,
      input_mask=query_mask,
      segment_ids=tf.zeros_like(query_ids)),
    is_training,
    params["use_tpu"], 
    bert_config,
    embedding_size=bert_config.hidden_size,
    untied_embeddings=True, 
    name="bert_q")

  query_emb = get_projected_emb(bert_q, "emb_q", params, is_training)

  # Build BERT_B
  bert_b = build_transformer(
    dict(
      input_ids=block_ids,
      input_mask=block_mask,
      segment_ids=block_segment_ids),
    is_training,
    params["use_tpu"], 
    bert_config,
    embedding_size=bert_config.hidden_size,
    untied_embeddings=True, 
    name="bert_b")

  block_emb = get_projected_emb(bert_b, "emb_b", params, is_training)

  if params["use_tpu"]:
    # [global_batch_size, hidden_size]
    block_emb = tpu_utils.cross_shard_concat(block_emb)

    # [global_batch_size, local_batch_size]
    labels = tpu_utils.cross_shard_pad(tf.eye(local_batch_size))

    # [local_batch_size]
    labels = tf.argmax(labels, 0)
  else:
    # [local_batch_size]
    labels = tf.range(local_batch_size)

  tf.logging.info("Global batch size: %s", tensor_utils.shape(block_emb, 0))

  # [batch_size, global_batch_size]
  logits = tf.matmul(query_emb, block_emb, transpose_b=True)

  # []
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # Load pre-trained ELECTRA checkpoint 
  scaffold_fn = init_from_checkpoint(params)

  train_op = electra_optimization.create_optimizer(
      loss=loss,
      learning_rate=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      warmup_steps=min(10000, max(100, int(params["num_train_steps"]/10))),
      use_tpu=params["use_tpu"] if "use_tpu" in params else False)

  predictions = tf.argmax(logits, -1)

  metric_args = [query_mask, block_mask, labels, predictions,
                 features["mask_query"]]

  def metric_fn(query_mask, block_mask, labels, predictions, mask_query):
    masked_accuracy = tf.metrics.accuracy(
        labels=labels,
        predictions=predictions,
        weights=mask_query)
    unmasked_accuracy = tf.metrics.accuracy(
        labels=labels,
        predictions=predictions,
        weights=tf.logical_not(mask_query))
    return dict(
        query_non_padding=tf.metrics.mean(query_mask),
        block_non_padding=tf.metrics.mean(block_mask),
        actual_mask_ratio=tf.metrics.mean(mask_query),
        masked_accuracy=masked_accuracy,
        unmasked_accuracy=unmasked_accuracy)

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        eval_metrics=(metric_fn, metric_args))
  else:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metric_fn(*metric_args),
        predictions=predictions)  

余談ですが、features の属性は以下のようになっています。4 で紹介したコードと属性が違ってますが、language.orqa.datasets.ict_dataset.get_retrieval_examples() に加工する処理があります7

  • keep_example ([batch_size]) :
    False のサンプルは input_fn で除去されるので全て True
  • mask_query ([batch_size]) :
    block_ids から query_ids を除去したか否か。10%は原文ママです。
  • query_ids ([batch_size, query_seq_len]) :
    “[CLS]” + blocksentences からランダム選択した1文 + “[SEP]"のトークンIDです。
  • query_mask ([batch_size, query_seq_len]) :
    上記に対応する input_mask です。
  • block_ids ([batch_size, block_seq_len]) :
    ”[CLS]“ + title + ”[SEP]“ + block + ”[SEP]“ のトークンIDです。
  • block_mask ([batch_size, block_seq_len]) :
    上記に対応する input_mask です。
  • block_segment_ids ([batch_size, block_seq_len]) :
    上記に対応する segment_ids です。

今回は batch_size = 256 なので、普通なら 256 択の選択問題なのですが、if params["use_tpu"]: の分岐で面白いことをしています。Colab Free TPU だと コアが8個あって、それぞれのコアでバッチを処理しているんだと思います。で、自分以外のコアで生成した block の埋め込みをコピーしてきてくっつけてます。これで 256 x 8 = 2048 択に水増しして問題の難易度を上げている訳です。

さて、問題に戻ります。トークナイザも差し替えないといけません。

from language.orqa.utils import bert_utils
from transformers import BertJapaneseTokenizer

bert_japanese_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

def get_bert_japanese_tokenizer(bert_hub_module_path):
  return bert_japanese_tokenizer

bert_utils.get_tokenizer = get_bert_japanese_tokenizer

オリジナルのコードは language.common.utils.experiment_utils.run_experiment() を使いますが、save_checkpoints_stepsiterations_per_loop を別々に設定したかったので修正しています。

def run_experiment(model_fn,
                   train_input_fn,
                   eval_input_fn,
                   exporters=None,
                   params=None,
                   params_fname=None):
  params = params if params is not None else {}
  params.setdefault("use_tpu", FLAGS.use_tpu)

  if FLAGS.model_dir and params_fname:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    params_path = os.path.join(FLAGS.model_dir, params_fname)
    with tf.io.gfile.GFile(params_path, "w") as params_file:
      json.dump(params, params_file, indent=2, sort_keys=True)

  if params["use_tpu"]:
    if FLAGS.tpu_name:
      tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      tpu_cluster_resolver = None
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop)) # NOTE: modified from FLAGS.save_checkpoints_steps
    if "batch_size" in params:
      # Let the TPUEstimator fill in the batch size.
      params.pop("batch_size")
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        params=params,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size)
  else:
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max)
    params["batch_size"] = FLAGS.batch_size
    estimator = tf.estimator.Estimator(
        config=run_config,
        model_fn=model_fn,
        params=params,
        model_dir=FLAGS.model_dir)

  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn,
      max_steps=FLAGS.num_train_steps)
  eval_spec = tf.estimator.EvalSpec(
      name="default",
      input_fn=eval_input_fn,
      exporters=exporters,
      start_delay_secs=FLAGS.eval_start_delay_secs,
      throttle_secs=FLAGS.eval_throttle_secs,
      steps=FLAGS.num_eval_steps)

  tf.logging.set_verbosity(tf.logging.INFO)
  tf.estimator.train_and_evaluate(
      estimator=estimator,
      train_spec=train_spec,
      eval_spec=eval_spec)

最後にパラメータを dict にまとめて、

params = dict(
    batch_size=FLAGS.batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    bert_hub_module_path=FLAGS.bert_hub_module_path,
    query_seq_len=FLAGS.query_seq_len,
    block_seq_len=FLAGS.block_seq_len,
    projection_size=FLAGS.projection_size,
    learning_rate=FLAGS.learning_rate,
    examples_path=FLAGS.examples_path,
    mask_rate=FLAGS.mask_rate,
    num_train_steps=FLAGS.num_train_steps,
    num_block_records=FLAGS.num_block_records,
    num_input_threads=FLAGS.num_input_threads,
    init_check_point=FLAGS.init_check_point,
    bert_config=FLAGS.bert_config)

ようやくですが、ICT を実行します。

from language.common.utils import experiment_utils
from language.orqa.models import ict_model

run_experiment(
  model_fn=model_fn,
  train_input_fn=functools.partial(ict_model.input_fn, is_train=True),
  eval_input_fn=functools.partial(ict_model.input_fn, is_train=False),
  exporters=None,
  params=params)
# ColabTPUClusterResolver returns <tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver object at 0x7fa5c42a3048>
# INFO:tensorflow:Using config: {'_model_dir': 'gs://somewhere/orqa/ict_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
# ...
# INFO:tensorflow:Enqueue next (200) batch(es) of data to infeed.
# INFO:tensorflow:Dequeue next (200) batch(es) of data from outfeed.
# INFO:tensorflow:Outfeed finished for iteration (54, 42)
# INFO:tensorflow:Outfeed finished for iteration (54, 129)
# INFO:tensorflow:Saving checkpoints for 100000 into gs://somewhere/orqa/ict_model/model.ckpt.
# INFO:tensorflow:loss = 0.66586185, step = 100000 (178.031 sec)
# INFO:tensorflow:global_step/sec: 1.1234
# INFO:tensorflow:examples/sec: 287.59
# INFO:tensorflow:Stop infeed thread controller
# INFO:tensorflow:Shutting down InfeedController thread.
# INFO:tensorflow:InfeedController received shutdown signal, stopping.
# INFO:tensorflow:Infeed thread finished, shutting down.
# INFO:tensorflow:infeed marked as finished
# INFO:tensorflow:Stop output thread controller
# INFO:tensorflow:Shutting down OutfeedController thread.
# INFO:tensorflow:OutfeedController received shutdown signal, stopping.
# INFO:tensorflow:Outfeed thread finished, shutting down.
# INFO:tensorflow:outfeed marked as finished
# INFO:tensorflow:Shutdown TPU system.
# INFO:tensorflow:Loss for final step: 0.66586185.
# INFO:tensorflow:training_loop marked as finished

これで ICT による事前学習が完了しました。次は ICT が完了した checkpoint から BERTB を読み込んで前処理のところで記事を分割した block 毎に埋め込み表現にしていきます。

block のエンコード

ここで一旦ランタイムを再起動して作業状況を整理しておきましょう。それでは各種のインポートからやり直しです。

%tensorflow_version 2.x 
from google.colab import auth
auth.authenticate_user()

import os
import sys
sys.path.append("./language")
sys.path.append("./electra")
sys.path.append("./bert")

import time

from absl import app
from absl import flags

from language.orqa.utils import bert_utils
from language.orqa.utils import scann_utils
from language.orqa.predict.encode_blocks import input_fn

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

次にパラメータの定義と設定です。

num_block_records の値は前処理で確認した値に書き換えて下さい。

flags.DEFINE_string("model_dir", None, "")
flags.DEFINE_string("bert_config", None, "")
flags.DEFINE_integer("projection_size", 128, "Projection size.")

FLAGS = flags.FLAGS
FLAGS.mark_as_parsed()

FLAGS.model_dir="gs://somewhere/orqa/ict_model"
FLAGS.examples_path="gs://somewhere/orqa/wikipedia/examples.tfr"
FLAGS.num_blocks=3185632
FLAGS.batch_size=256 # NOTE: modified from 4096. 512 doesn't work.
FLAGS.use_tpu=True
FLAGS.retriever_module_path="DUMMY:NOT_USED"
FLAGS.tpu_name="DUMMY:Colab Free TPU"
FLAGS.bert_config="gs://somewhere/electra/max_seq_length_512/config.json"

Colab Free TPU のセットアップと差し替えです。

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

cluster = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS, zone=None, project=None)

def ColabTPUClusterResolver(tpu_name, zone, project):
  return cluster

tf.distribute.cluster_resolver.TPUClusterResolver = ColabTPUClusterResolver

transformer の構築関数を定義します。

from electra.model import modeling
import json

def build_transformer(inputs, is_training, use_tpu, bert_config, name="electra", reuse=False, **kwargs):
  """Build a transformer encoder network."""
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    return modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=inputs["input_ids"],
        input_mask=inputs["input_mask"],
        token_type_ids=inputs["segment_ids"],
        use_one_hot_embeddings=use_tpu,
        scope=name,
        **kwargs)

次に埋め込み表現の生成関数なのですが ICT の学習時、tf.keras.layers.LayerNormalization() の添え字付けルールが tf.contrib.layers.layer_norm() と異なるようで、BERTB のみチェックポイントに格納されるパラメータ名に”_1" がついていました。 なので、ここでname="layer_normalization_1" と指定して合わせ込み込みます。

def get_projected_emb(bert, scope, is_training):
  with tf.variable_scope(scope):
    reprs = bert.get_pooled_output()
    projected_emb = tf.layers.dense(reprs, FLAGS.projection_size, name="weight")
    # NOTE name="layer_normalization**_1**"
    projected_emb = tf.keras.layers.LayerNormalization(name="layer_normalization_1", axis=-1)(projected_emb) 
    if is_training:
      projected_emb = tf.nn.dropout(projected_emb, rate=0.1)
    return projected_emb

model_fn() は block をエンコードするだけなので、 BERTB だけを使います。

def model_fn(features, labels, mode, params):
  """Model function."""
  del labels, params

  # [local_batch_size, block_seq_len]
  block_ids = features["block_ids"]
  block_mask = features["block_mask"]
  block_segment_ids = features["block_segment_ids"]

  # Load config.json
  with tf.io.gfile.GFile(FLAGS.bert_config, "r") as f:
    config_json = json.load(f)
    bert_config = modeling.BertConfig.from_dict(config_json)

  # Build BERT_B
  bert_b = build_transformer(
    dict(
      input_ids=block_ids,
      input_mask=block_mask,
      segment_ids=block_segment_ids),
    False,
    FLAGS.use_tpu, 
    bert_config,
    embedding_size=bert_config.hidden_size,
    untied_embeddings=True, 
    name="bert_b")

  block_emb = get_projected_emb(bert_b, "emb_b", False)

  predictions = dict(block_emb=block_emb)
  return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)

後はトークナイザを差し替えて、

from language.orqa.utils import bert_utils

from transformers import BertJapaneseTokenizer
bert_japanese_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

def get_bert_japanese_tokenizer(bert_hub_module_path):
  return bert_japanese_tokenizer

bert_utils.get_tokenizer = get_bert_japanese_tokenizer  

ここまでの部品を組み立てて estimator にします。

if FLAGS.use_tpu and FLAGS.tpu_name:
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
else:
  tpu_cluster_resolver = None

run_config = tf.estimator.tpu.RunConfig(
  cluster=tpu_cluster_resolver,
  master=FLAGS.master,
  tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))

estimator = tf.estimator.tpu.TPUEstimator(
  use_tpu=FLAGS.use_tpu,
  model_fn=model_fn,
  config=run_config,
  model_dir=FLAGS.model_dir,
  train_batch_size=FLAGS.batch_size,
  eval_batch_size=FLAGS.batch_size,
  predict_batch_size=FLAGS.batch_size)

埋め込み表現の保存先です。

encoded_path = os.path.join(FLAGS.model_dir, "encoded", "encoded.ckpt")
tf.logging.info("Embeddings will be written to %s", encoded_path)
# INFO:tensorflow:Embeddings will be written to gs://somewhere/orqa/ict_model/encoded/encoded.ckpt

前処理のところで作った block を埋め込んでいきます。前処理で作った examples.tfr を読み込んで(FLAGS.examples_pathで設定してます)、 BERTB で埋め込みます。

start_time = time.time()
all_block_emb = None
i = 0
for outputs in estimator.predict(input_fn=input_fn):
  if i == 0:
    all_block_emb = np.zeros(
        shape=(FLAGS.num_blocks, outputs["block_emb"].shape[-1]), dtype=np.float32)
  if i >= FLAGS.num_blocks:
    break
  all_block_emb[i, :] = outputs["block_emb"]
  i += 1
  if i % 1000 == 0:
    elapse_time = time.time() - start_time
    examples_per_second = i / elapse_time
    remaining_minutes = ((FLAGS.num_blocks - i) / examples_per_second) / 60
    tf.logging.info(
        "[%d] examples/sec: %.2f, "
        "elapsed minutes: %.2f, "
        "remaining minutes: %.2f", i, examples_per_second, elapse_time / 60,
        remaining_minutes)
tf.logging.info("Expected %d rows, found %d rows", FLAGS.num_blocks, i)
tf.logging.info("Saving block embedding to %s...", encoded_path)
scann_utils.write_array_to_checkpoint("block_emb", all_block_emb, encoded_path)
tf.logging.info("Done saving block embeddings.")
# ...
# INFO:tensorflow:[3185000] examples/sec: 1583.60, elapsed minutes: 33.52, remaining minutes: 0.01
# INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
# INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
# INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
# INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
# INFO:tensorflow:prediction_loop marked as finished
# WARNING:tensorflow:Reraising captured error
# INFO:tensorflow:Expected 3185632 rows, found 3185632 rows
# INFO:tensorflow:Saving block embedding to gs://somewhere/orqa/ict_model/encoded/encoded.ckpt...
# WARNING:tensorflow:From ./language/language/orqa/utils/scann_utils.py:30: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.
# Instructions for updating:
# tf.py_func is deprecated in TF V2. Instead, there are two
#    options available in V2.
#    - tf.py_function takes a python function which manipulates tf eager
#    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
#    an ndarray (just call tensor.numpy()) but having access to eager tensors
#    means `tf.py_function`s can use accelerators such as GPUs as well as
#    being differentiable using a gradient tape.
#    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
#    (it is not differentiable, and manipulates numpy arrays). It drops the
#    stateful argument making all functions stateful.
# 
# 
# Exception ignored in: <generator object TPUEstimator.predict at 0x7fce4b51fca8>
# Traceback (most recent call last):
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3078, in predict
#    rendezvous.raise_errors()
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 147, in raise_errors
#    six.reraise(typ, value, traceback)
#   File "/usr/local/lib/python3.6/dist-packages/six.py", line 703, in reraise
#    raise value
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 116, in catch_errors
#    yield
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 536, in _run_infeed
#     session.run(self._enqueue_ops)
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 960, in run
#     run_metadata_ptr)
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1183, in _run
#     feed_dict_tensor, options, run_metadata)
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1361, in _do_run
#     run_metadata)
#   File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1386, in _do_call
#     raise type(e)(node_def, op, message)
# tensorflow.python.framework.errors_impl.CancelledError: Step was cancelled by an explicit call to `Session::Close()`.
# 
# ERROR:tensorflow:Closing session due to error Step was cancelled by an explicit call to `Session::Close()`.
# INFO:tensorflow:Done saving block embeddings.

最後の終わり方がちょっとアレですが、埋め込みが完了しました。ここからは ORQA タスクでファインチューニングをしていきます。

4. ORQA タスクによるファインチューニング

ここからは質問 - 解答のペアでファインチューニングしていきます。

記事内のコードスニペットは Colab で動かす風に記述してますが、 Colab の GPU 環境ではメモリが足らず動かせませんでした。論文では “a single machine with a 12GB GPU” とあるので、ぎりぎり足らなかった感じかもしれません。仕方ないので社内のサーバに jupyter を入れて作業しています。

ちなみに Docker イメージとして tensorflow/tensorflow:2.1.0-gpu-py3-jupyter を使ってます。このイメージで jupyter notebook を起動、新規のノートブックで wgetgit をインストールしてください。

!apt-get update
!apt-get install -y wget
!apt-get install -y git

後は、ICT タスクのセットアップと同じ手順を実行します。 ORQA タスクでは学習中に MIPS 検索を行うので、ScaNN のカスタム op をインストールします。

%%bash
cd language
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared language/orqa/ops/orqa_ops.cc -o language/orqa/ops/orqa_ops.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

テストを動かします。1個とばしてますが、大丈夫そうですね(というか大丈夫でした)。

!cd language && python -m language.orqa.ops.orqa_ops_test
# ...
# [       OK ] OrqaOpsTest.test_has_answer
# [ RUN      ] OrqaOpsTest.test_reader_inputs
# [       OK ] OrqaOpsTest.test_reader_inputs
# [ RUN      ] OrqaOpsTest.test_session
# [  SKIPPED ] OrqaOpsTest.test_session
# ----------------------------------------------------------------------
# Ran 3 tests in 3.249s
# 
# OK (skipped=1)

tensorflow models も取得します。

!git clone https://github.com/tensorflow/models
!cd models && git checkout 06be7fb4be4d8a56eb
# ...
# HEAD is now at 06be7fb... Internal change

tensorflow-hub も入れておきます。

!pip install tensorflow-hub==0.9.0

これでセットアップは完了です。次は質問 - 解答ペアのデータセットです。

JAQKET データセット

今回はクイズを題材にした日本語 QA データセットである JAQKET データセット8を使うことにしました。

jaqket

JAQKET データセットは解答が必ず Wikipedia の記事名になるよう正規化されており、以下のようにクイズ問題、解答及び解答候補からなる選択問題のデータセットになっています。

jaqket

現在公開されているデータは訓練データ、開発用データ1、開発用データ2 に分かれていてそれぞれ以下の件数です。

  • 訓練データ : 13,061 件
  • 開発用データ1 : 995 件
  • 開発用データ2 : 997 件

今回は訓練データで学習し、開発用データ1を検証セット、開発用データ2をテストセットとして実験します。 また、前述のとおり質問 - 解答ペアのみでの学習としますので questionanswer_entity のみ用い、answer_candidates は使わないことにします。

それではデータを取得して、

!wget https://jaqket.s3-ap-northeast-1.amazonaws.com/data/train_questions.json
!wget https://jaqket.s3-ap-northeast-1.amazonaws.com/data/dev1_questions.json
!wget https://jaqket.s3-ap-northeast-1.amazonaws.com/data/dev2_questions.json

加工します。

import json
def convert_jaqket(filename, split, dataset_name="jaqket"):
  with open(filename, "r") as f:
    lines = f.readlines()
    lines = [line.strip() for line in lines]
    examples = [json.loads(line) for line in lines]
    examples = [{"qid": example["qid"], 
                 "question": example["question"], 
                 "answer": [example["answer_entity"]] } for example in examples]
  with open("%s.resplit.%s.jsonl" % (dataset_name, split), "w") as f:
    lines = []
    for example in examples:
      line = json.dumps(example, ensure_ascii=False)
      lines.append(line)
    f.write("\n".join(lines))

convert_jaqket("train_questions.json", "train")
convert_jaqket("dev1_questions.json", "dev")
convert_jaqket("dev2_questions.json", "test")

加工後のデータはこんな感じです。answer"[]" で括られていますが、これは正解が複数あるパターンを想定したコードになっているからです。

!head -5 jaqket.resplit.train.jsonl
# {"qid": "ABC01-01-0003", "question": "格闘家ボブ・サップの出身国はどこでしょう?", "answer": ["アメリカ合衆国"]}
# {"qid": "ABC01-01-0004", "question": "ロシア語で「城」という意味がある、ロシアの大統領府の別名は何でしょう?", "answer": ["クレムリン"]}
# {"qid": "ABC01-01-0005", "question": "織田信長、豊臣秀吉、徳川家康という3人の戦国武将の性格を表現するのに用いられる鳥は何でしょう?", "answer": ["ホトトギス"]}
# {"qid": "ABC01-01-0006", "question": "人気タレント・タモリの本名は何でしょう?", "answer": ["タモリ"]}
# {"qid": "ABC01-01-0008", "question": "「国際連合」の旗に描かれている植物といったら何でしょう?", "answer": ["オリーブ"]}

あ、タモさんの本名が違ってますね。。。直しておきましょう。9

!sed -i '/^{"qid": "ABC01-01-0006"/s/"タモリ"/"森田一義"/' jaqket.resplit.train.jsonl

加工後のデータ件数です。

!wc -l *.jsonl 
#    994 jaqket.resplit.dev.jsonl
#    996 jaqket.resplit.test.jsonl
#  13060 jaqket.resplit.train.jsonl
#  15050 total

各種関数の差し替え

それでは、ここで一旦ノートブックを再起動して、いつものようにアレコレ差し替えていきます。

まずは、必要なものをインポートします。

import os
import sys
sys.path.append("./language")
sys.path.append("./electra")
sys.path.append("./bert")
sys.path.append("./models")

import tensorflow.compat.v1 as tf

import functools
from absl import app
from absl import flags
from language.common.utils import experiment_utils
from language.orqa.models import orqa_model

tf.disable_v2_behavior()

続いてパラメータを定義していきます。

flags.DEFINE_integer("retriever_beam_size", 5000,
                     "Retriever beam size.")
flags.DEFINE_integer("reader_beam_size", 5, "Reader beam size.")
flags.DEFINE_float("learning_rate", 1e-5, "Initial learning rate.")
flags.DEFINE_integer("span_hidden_size", 256, "Span hidden size.")
flags.DEFINE_integer("max_span_width", 10, "Maximum span width.")
flags.DEFINE_integer("num_block_records", 13353718, "Number of block records.")
flags.DEFINE_integer("query_seq_len", 64, "Query sequence length.")
flags.DEFINE_integer("block_seq_len", 288, "Document sequence length.")
flags.DEFINE_integer("reader_seq_len", 288 + 64, "Reader sequence length.") 
flags.DEFINE_string("reader_module_path",
                    "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
                    "Path to the reader TF-Hub module.")
flags.DEFINE_string("retriever_module_path", None,
                    "Path to the retriever TF-Hub module.")
flags.DEFINE_string("data_root", None, "Data root.")
flags.DEFINE_string("block_records_path", None, "Block records path.")
flags.DEFINE_string("dataset_name", None, "Name of dataset.")

ELECTRA の事前学習済みモデル( BERTR の初期化に使います)と BERT の設定ファイルのパスを設定するパラメータです。

flags.DEFINE_string("init_check_point", None, "")
flags.DEFINE_string("retriever_check_point", None, "")
flags.DEFINE_string("bert_config", None, "")
flags.DEFINE_integer("projection_size", None, "Projection size.")

FLAGS = flags.FLAGS
FLAGS.mark_as_parsed()

実際の設定値を代入します。社内のサーバで動かすので "gs://somewhere/orqa" の内容が 前もって "./data" にコピーしてある前提の記述です。

num_block_records の値は前処理で確認した値に書き換えて下さい。

FLAGS.num_block_records=3185632
FLAGS.retriever_module_path="./data/ict_model"
FLAGS.block_records_path="./data/wikipedia/blocks.tfr"
FLAGS.data_root="./"
FLAGS.model_dir="./orqa_model"
FLAGS.dataset_name="jaqket"
FLAGS.num_train_steps=13060 * 20
FLAGS.save_checkpoints_steps=1000

FLAGS.init_check_point="./data/electra_base_wiki_ja"
FLAGS.retriever_check_point="./data/ict_model"
FLAGS.bert_config="./data/config.json"

FLAGS.projection_size=128
FLAGS.retriever_beam_size = 5000

さらにインポートです。

from electra.model import optimization as electra_optimization
from electra.model import modeling
from language.orqa import ops as orqa_ops
import json
import numpy as np
import re

さて、いつも BERT がらみの学習をするときは、学習データを準備する段階でトークナイズされ、 input_ids, input_mask, segment_ids のような形になっているのですが、ORQA では features["question"] には質問文字列、labels には解答文字列の集合(複数解答アリ)のテンソルが入ってきます。

なので、model_fn() でトークナイズしないといけないのですが、「あ、トークナイザは transformers からパクったんだった。あぅっ。。。」となりました。オリジナルのコードは bert_utils.get_tf_tokenizer() とか気楽にやってますが、こちらはそんな訳に行かないので、tf.py_func() を駆使して、ほぼ互換関数を作って差し替えることになりました10

まずは、transformers のトークナイザを生成して、

from transformers import BertJapaneseTokenizer
bert_japanese_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

そして “tf.py_func()” を経由して渡された numpy array をトークナイズする関数です。

def bert_japanese_tokenize_func(array):
  if np.array_equal(array, [b'']) :
    return np.array([], dtype=np.int64), np.array([0,0], dtype=np.int64)
  token_ids = []
  row_splits = [0]
  for text in array:
    tokens_of_text = bert_japanese_tokenizer.tokenize(text.decode("utf-8"))
    token_ids.extend(bert_japanese_tokenizer.convert_tokens_to_ids(tokens_of_text))
    row_splits.append(row_splits[-1] + len(tokens_of_text))
  return np.array(token_ids), np.array(row_splits)

返り値が分かりにくいですが、array で複数の文字列を渡されると、全てのトークナイズを連結したトークンIDのリストとその連結を分割して基に戻す為のインデックスのリストを返します。 動かすとこうなります。

bert_japanese_tokenize_func(["本日は晴天なり".encode("utf-8"), "吾輩は猫である".encode("utf-8")])
# (array([  108, 28486,     9,  4798, 28849,   297,  7184, 30046,     9,
#          6040,    12,    31]), array([ 0,  6, 12]))

2つの文字列のトークンIDが一つのリストに格納され [0:6][6:12] に分かれる感じです。

次は tf.py_func() の呼び出し元になる関数です。こちらは入力された文字列を先ほどの関数に tf.py_func() 経由で通し、返された情報を使って RaggedTensor に変換して返します。RaggedTensor は次元の長さが不揃いのテンソルです。

def bert_japanese_tokenize(input):
  token_ids, row_splits = tf.py_func(bert_japanese_tokenize_func, [input], [tf.int64, tf.int64], stateful=False)
  return tf.RaggedTensor.from_row_splits(token_ids, row_splits=row_splits)

試しに動かすとこんな感じになります。

with tf.Session() as sess:
  question = "ボブ・サップの生まれた国はどこですか"
  tf_question = tf.convert_to_tensor(question)
  question_token_ids = bert_japanese_tokenize(tf.expand_dims(tf_question, 0))
  print(question_token_ids.to_tensor().eval())
# [[10764    35   117   272     5  1115    10    79     9  5359  2992    29]]

次はトークン文字列に対応するトークンIDを返す関数です。

def bert_japanese_vocab_lookup_func(token):
  return bert_japanese_tokenizer.vocab[token.decode("utf-8")]

def bert_japanese_vocab_lookup(tf_token):
  return tf.py_func(bert_japanese_vocab_lookup_func, [tf_token], tf.int64, stateful=False)

こんな感じですね。

with tf.Session() as sess:
  cls_token_id = bert_japanese_vocab_lookup(tf.constant("[CLS]"))
  print(cls_token_id.eval())
# 2

さらに 3 つ関数を定義します。どんな感じの挙動なのかは後でまとめて説明しますね。

まず、サブワードに分割されたトークン列を受け取り、単語単位のトークン列と各サブワードと単語単位のトークン列の対応を保持するリストを返す関数です。

def joint_sub_word_tokens(sub_tokens):
  word_level_tokens = []
  block_token_map = []
  for sub_token in sub_tokens:

    if sub_token.startswith('##'):
      sub_token = re.sub("^##", "", sub_token)
      if len(word_level_tokens) > 0:
        word_level_tokens[-1]+=sub_token
      else:
        word_level_tokens.append(sub_token)
    else:
      word_level_tokens.append(sub_token)
    block_token_map.append(len(word_level_tokens)-1)
  return word_level_tokens, block_token_map

次に文字列のリストを受け取り、サブワードに分割して上記の関数を適用する関数です。

def bert_japanese_tokenize_with_original_mapping_func(blocks):
  orig_tokens = []
  block_token_ids = []
  block_token_map = []
  row_splits = [0]
  orig_row_splits = [0]

  for text in blocks:
    tokens_of_text = bert_japanese_tokenizer.tokenize(text.decode("utf-8"))
    block_token_ids.extend(bert_japanese_tokenizer.convert_tokens_to_ids(tokens_of_text))
    orig_tokens_of_text, block_token_map_of_text = joint_sub_word_tokens(tokens_of_text)
    orig_tokens.extend(orig_tokens_of_text)
    block_token_map.extend(block_token_map_of_text)
    row_splits.append(row_splits[-1] + len(tokens_of_text))
    orig_row_splits.append(orig_row_splits[-1] + len(orig_tokens_of_text)) 

  return (np.array([token.encode("utf-8") for token in orig_tokens]), 
          np.array(block_token_map), 
          np.array(block_token_ids), 
          np.array(row_splits), 
          np.array(orig_row_splits))

これを tf.py_func() でラップして model_fn() から呼び出せるようにします。

def bert_japanese_tokenize_with_original_mapping(blocks):
  orig_tokens, block_token_map, block_token_ids, row_splits, orig_row_splits = tf.py_func(
      bert_japanese_tokenize_with_original_mapping_func, [blocks], 
               [tf.string, tf.int64, tf.int64, tf.int64, tf.int64], stateful=False)
  orig_tokens = tf.RaggedTensor.from_row_splits(orig_tokens, row_splits=orig_row_splits)
  block_token_ids = tf.RaggedTensor.from_row_splits(block_token_ids, row_splits=row_splits)
  block_token_map = tf.RaggedTensor.from_row_splits(block_token_map, row_splits=row_splits)
  return orig_tokens, block_token_map, block_token_ids, blocks

私はこんなの見せられると真面目に読む気が失せるタイプです。なので「こんな感じで動きます」というのを示して説明にしたことにさせて下さい。

with tf.Session() as sess:
  blocks = tf.convert_to_tensor(["本日は晴天なり",  "吾輩は猫である"])
  (orig_tokens, block_token_map, block_token_ids, blocks) = (
      bert_japanese_tokenize_with_original_mapping(blocks))

  print("=== orig_tokens ===")
  for i, row in enumerate(orig_tokens.to_tensor().eval()):
    print("  idx=%d, length=%d, org_tokens=%s" % (i, len(row), [token.decode("utf-8") for token in row]))

  print("=== block token map ===")
  for i, row in enumerate(block_token_map.to_tensor().eval()):
    print("  idx=%d, length=%d, token_map=%s" % (i, len(row), row))

  print("=== block token ids ===")
  for i, row in enumerate(block_token_ids.to_tensor().eval()):
    print("  idx=%d, length=%d, token_ids=%s" % (i, len(row), row))

  print("=== blocks ===")
  for i, block in enumerate(blocks.eval()):
    print("  idx=%d, block=%s" % (i, block.decode("utf-8")))

# === orig_tokens ===
#   idx=0, length=5, org_tokens=['本日', 'は', '晴天', 'なり', '']
#   idx=1, length=5, org_tokens=['吾輩', 'は', '猫', 'で', 'ある']
# === block token map ===
#   idx=0, length=6, token_map=[0 0 1 2 2 3]
#   idx=1, length=6, token_map=[0 0 1 2 3 4]
# === block token ids ===
#   idx=0, length=6, token_ids=[  108 28486     9  4798 28849   297]
#   idx=1, length=6, token_ids=[ 7184 30046     9  6040    12    31]
# === blocks ===
#   idx=0, block=本日は晴天なり
#   idx=1, block=吾輩は猫である    

orig_tokensblock_token_map, block_token_ids は長さが違いますが、前者が単語単位、後者二つがサブワード単位だからです。"本日“と"吾輩"がそれぞれ、"本”, “##日” と“吾”, “##輩” に分かれたことが block_token_map を見ればわかります。

表示上「本日は晴天なり」が 5 トークンになっていますが、これは RaggedTensor.to_tensor() で通常のテンソルに変換する際に最長のものに長さを合わせる必要があって “” でパディングされたからであり、内部的には 4 トークンに分割されてます。

さて、ここからはだいたい今までどおりです。transformer の構築関数を定義して、

def build_transformer(inputs, is_training, use_tpu, bert_config, name="electra", reuse=False, **kwargs):
  """Build a transformer encoder network."""
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    return modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=inputs["input_ids"],
        input_mask=inputs["input_mask"],
        token_type_ids=inputs["segment_ids"],
        use_one_hot_embeddings=use_tpu,
        scope=name,
        **kwargs)

埋め込み表現の生成関数。この関数を適用するのは BERTQ の方なので “_1” は気にしなくて大丈夫。

def get_projected_emb(bert, scope, params, is_training):
  with tf.variable_scope(scope):
    reprs = bert.get_pooled_output()
    projected_emb = tf.layers.dense(reprs, params["projection_size"], name="weight")
    projected_emb = tf.keras.layers.LayerNormalization(axis=-1)(projected_emb)
    if is_training:
      projected_emb = tf.nn.dropout(projected_emb, rate=0.1)
    return projected_emb

もう一つだけ。今回は BERT の部分に ELECTRA のコードを使っているのですが、何も考えずにパディングして動かすと TypeError: %d format: a number is required, not NoneType になってしまい。。。どうやら ELECTRA の get_shape_list() は入力の shape が <unknown> だと、そうなってしまうようで。 どうしたものかと思案した挙句、「 BERT への入力になるパディング後の shape なんて固定で既知なんだから、tf.zeros() で既知の固定サイズのテンソル作って足しちゃえば良くない?」と思いつき、以下のような回避コードになりました。

def pad_to_axis1(tensor2d, padded_seq_len):
  pad_token_id = bert_japanese_vocab_lookup(tf.constant("[PAD]"))
  batch_len = tf.size(tensor2d[:,0])
  tensor_seq_len = tf.size(tensor2d[0,:])  
  zeros = tf.zeros([batch_len, padded_seq_len], dtype=tf.int32)
  padding_seq_len = padded_seq_len - tensor_seq_len
  padding = tf.zeros([batch_len, padding_seq_len], tf.int32)
  padded_tensor2d = tf.concat([tensor2d, padding], 1)
  return padded_tensor2d + zeros

それでは、ここまで作ってきた関数を組み込んでいきます。まずは、reader に相当する retrieve() 関数です。

質問を埋め込み表現にして ScaNN を使って MIPS 検索し、検索結果の TOP-K(=retriever_beam_size) の文字列を読み込んで logit と一緒に返しています。

def retrieve(features, retriever_beam_size, mode, params):
  question_token_ids = bert_japanese_tokenize(
      tf.expand_dims(features["question"], 0))
  question_token_ids = tf.cast(question_token_ids.to_tensor(), tf.int32)
  cls_token_id = bert_japanese_vocab_lookup(tf.constant("[CLS]"))
  sep_token_id = bert_japanese_vocab_lookup(tf.constant("[SEP]"))
  question_token_ids = tf.concat(
      [[[tf.cast(cls_token_id, tf.int32)]], question_token_ids,
       [[tf.cast(sep_token_id, tf.int32)]]], -1)

  input_mask = tf.ones_like(question_token_ids)   
  segment_ids = tf.zeros_like(question_token_ids) 
  question_token_ids = pad_to_axis1(question_token_ids, params["query_seq_len"])
  input_mask = pad_to_axis1(input_mask, params["query_seq_len"])
  segment_ids = pad_to_axis1(segment_ids, params["query_seq_len"])

  is_training = mode == tf.estimator.ModeKeys.TRAIN 

  # Load config.json
  with tf.io.gfile.GFile(params["bert_config"], "r") as f:
    config_json = json.load(f)
    bert_config = modeling.BertConfig.from_dict(config_json)

  # Build BERT_Q
  bert_q = build_transformer(
    dict(
          input_ids=question_token_ids,
          input_mask=input_mask,
          segment_ids=segment_ids),
    is_training,
    params["use_tpu"], 
    bert_config,
    embedding_size=bert_config.hidden_size,
    untied_embeddings=True, 
    name="bert_q")

  question_emb = get_projected_emb(bert_q, "emb_q", params, is_training)

  block_emb, searcher = scann_utils.load_scann_searcher(
      var_name="block_emb",
      checkpoint_path=os.path.join(params["retriever_module_path"], "encoded",
                                   "encoded.ckpt"),
      num_neighbors=retriever_beam_size)

  # [1, retriever_beam_size]
  retrieved_block_ids, _ = searcher.search_batched(question_emb)

  # [1, retriever_beam_size, projection_size]
  retrieved_block_emb = tf.gather(block_emb, retrieved_block_ids)

  # [retriever_beam_size]
  retrieved_block_ids = tf.squeeze(retrieved_block_ids)

  # [retriever_beam_size, projection_size]
  retrieved_block_emb = tf.squeeze(retrieved_block_emb)

  # [1, retriever_beam_size]
  retrieved_logits = tf.matmul(
      question_emb, retrieved_block_emb, transpose_b=True)

  # [retriever_beam_size]
  retrieved_logits = tf.squeeze(retrieved_logits, 0)

  blocks_dataset = tf.data.TFRecordDataset(
      params["block_records_path"], buffer_size=512 * 1024 * 1024)
  blocks_dataset = blocks_dataset.batch(
      params["num_block_records"], drop_remainder=True)
  blocks = tf.get_local_variable(
      "blocks",
      initializer=tf.data.experimental.get_single_element(blocks_dataset))
  retrieved_blocks = tf.gather(blocks, retrieved_block_ids)
  return RetrieverOutputs(logits=retrieved_logits, blocks=retrieved_blocks)

次に reader に相当する read() 関数です。処理の内容を分かりやすくする為、途中で出てくる変数の意味合いについて補足します。

def read(features, retriever_logits, blocks, mode, params, labels):
  """Do reading."""
  orig_blocks = blocks

  (orig_tokens, block_token_map, block_token_ids, blocks) = (
      bert_japanese_tokenize_with_original_mapping(blocks))

  question_token_ids = bert_japanese_tokenize(
      tf.expand_dims(features["question"], 0)) 
  question_token_ids = tf.cast(question_token_ids.flat_values, tf.int32)

  orig_tokens = orig_tokens.to_tensor(default_value="")
  block_lengths = tf.cast(block_token_ids.row_lengths(), tf.int32)
  block_token_ids = tf.cast(block_token_ids.to_tensor(), tf.int32)
  block_token_map = tf.cast(block_token_map.to_tensor(), tf.int32)

  answer_token_ids = bert_japanese_tokenize(labels)
  answer_lengths = tf.cast(answer_token_ids.row_lengths(), tf.int32)
  answer_token_ids = tf.cast(answer_token_ids.to_tensor(), tf.int32)

  cls_token_id = bert_japanese_vocab_lookup(tf.constant("[CLS]"))
  sep_token_id = bert_japanese_vocab_lookup(tf.constant("[SEP]"))

  concat_inputs = orqa_ops.reader_inputs(
      question_token_ids=question_token_ids,
      block_token_ids=block_token_ids,
      block_lengths=block_lengths,
      block_token_map=block_token_map,
      answer_token_ids=answer_token_ids,
      answer_lengths=answer_lengths,
      cls_token_id=tf.cast(cls_token_id, tf.int32),
      sep_token_id=tf.cast(sep_token_id, tf.int32),
      max_sequence_len=params["reader_seq_len"])

  tf.summary.scalar("reader_nonpad_ratio",
                    tf.reduce_mean(tf.cast(concat_inputs.mask, tf.float32)))

  is_training = mode == tf.estimator.ModeKeys.TRAIN 

  # Load config.json
  with tf.io.gfile.GFile(params["bert_config"], "r") as f:
    config_json = json.load(f)
    bert_config = modeling.BertConfig.from_dict(config_json)

  # Build BERT_R
  bert_r = build_transformer(
    dict(
          input_ids=concat_inputs.token_ids,
          input_mask=concat_inputs.mask,
          segment_ids=concat_inputs.segment_ids),
    is_training,
    params["use_tpu"], 
    bert_config,
    embedding_size=bert_config.hidden_size,
    untied_embeddings=True, 
    name="bert_r")

  concat_token_emb = bert_r.get_sequence_output()

  # [num_spans], [num_spans], [reader_beam_size, num_spans]
  candidate_starts, candidate_ends, candidate_mask = span_candidates(
      concat_inputs.block_mask, params["max_span_width"])

  # Score with an MLP to enable start/end interaction:
  # score(s, e) = w·σ(w_s·h_s + w_e·h_e)
  kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

  # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
  projection = tf.layers.dense(
      concat_token_emb,
      params["span_hidden_size"] * 2,
      kernel_initializer=kernel_initializer)

  # [reader_beam_size, max_sequence_len, span_hidden_size]
  start_projection, end_projection = tf.split(projection, 2, -1)

  # [reader_beam_size, num_candidates, span_hidden_size]
  candidate_start_projections = tf.gather(
      start_projection, candidate_starts, axis=1)
  candidate_end_projection = tf.gather(end_projection, candidate_ends, axis=1)
  candidate_hidden = candidate_start_projections + candidate_end_projection

  candidate_hidden = tf.nn.relu(candidate_hidden)
  candidate_hidden = tf.keras.layers.LayerNormalization(axis=-1)(
      candidate_hidden)

  # [reader_beam_size, num_candidates, 1]
  reader_logits = tf.layers.dense(
      candidate_hidden, 1, kernel_initializer=kernel_initializer)

  # [reader_beam_size, num_candidates]
  reader_logits = tf.squeeze(reader_logits)
  reader_logits += mask_to_score(candidate_mask)
  reader_logits += tf.expand_dims(retriever_logits, -1)


  # [reader_beam_size, num_candidates]
  candidate_orig_starts = tf.gather(
      params=concat_inputs.token_map, indices=candidate_starts, axis=-1)
  candidate_orig_ends = tf.gather(
      params=concat_inputs.token_map, indices=candidate_ends, axis=-1)

  return ReaderOutputs(
      logits=reader_logits,
      candidate_starts=candidate_starts,
      candidate_ends=candidate_ends,
      candidate_orig_starts=candidate_orig_starts,
      candidate_orig_ends=candidate_orig_ends,
      blocks=blocks,
      orig_blocks=orig_blocks,
      orig_tokens=orig_tokens,
      token_ids=concat_inputs.token_ids,
      gold_starts=concat_inputs.gold_starts,
      gold_ends=concat_inputs.gold_ends)

内容がわかりにくいのは以下の変数かと思います。

  • concat_inputs.block_mask([reader_beam_size,reader_seq_len]) :
    BERTR への入力系列(長さ reader_seq_len) は “[CLS]” + query + “[SEP]” + block + “[SEP]” + “[PAD]”* の形式になっています。ですので block 部分だけ抜けるよう、 block 部分だけ “1” 、その他が “0” のマスクになります。
  • candidate_starts/ends([num_spans]) :
    系列の長さ(reader_seq_len)と候補スパンの最大長(max_span_width) が決まれば自動的に候補スパンの数(num_spans) は決まります。BERTR への入力系列における num_span 個の候補スパンの開始/終了インデックスが格納されています。 block の先頭からの位置ではないことに注意して下さい。
  • candidate_mask([reader_beam_size, num_spans]) :
    candidate_starts/ends は BERTR への入力系列に対して機械的に生成した開始・終了ペアです。当然、開始もしくは終了位置が系列中に含まれる block の範囲の外になるものが含まれています。この変数は開始・終了位置が block の範囲に収まる候補スパンのみ “1”, その他が “0” のマスクになります。
  • projection([reader_beam_size, max_sequence_len, span_hidden_size * 2]) :
    候補スパンの開始・終了位置に対して MLP を掛ける処理に相当します。正確には BERTR の出力全体をここで写像しておいて、当該位置の値を後(reader_logits のところ)から抜く感じです。 “* 2” は「開始,終了に対してそれぞれ span_hidden_size 個のユニットを使うことにしているのですが、一度で計算したほうが効率よい」ということでしょう。
  • candidate_start/end_projections([reader_beam_size, num_candidates, span_hidden_size]) : projection を前後で二つに割って(start/end_projection) から各候補スパンの開始/終了位置(candidate_starts)の埋め込み表現を集めてきます。 num_candidates はコード中のコメントに合わせてますが、数値的には num_spans と同じだと思います。
  • candidate_hidden([reader_beam_size, num_candidates, span_hidden_size]) :
    各候補スパンの開始と終了位置の埋め込みを加算し、ReLu、Layer Normalization したものです。
  • reader_logits([reader_beam_size, num_candidates]) :
    candidate_hidden を dense に通してスカラー値にして、candidate_mask を掛けて(範囲外の候補スパンは -10000 する)、そして retriever の logits の先頭 reader_beam_size 分を抜いて足しこんでいます。
  • candidate_orig_start/ends ([reader_beam_size, num_candidates]) :
    サブトークン単位-> トークン単位のインデックスのマッピングとcandidate_start/ends で候補スパンの単語単位での開始/終了位置を出しています。

あれ?、 reader での処理では開始/終了の埋め込みを「連結(concatenation)して」と説明したのですが、ソースコードはcandidate_hidden のところで加算(“+”)してるからちょっと違ってますね。。。

model_fn() は ICT と ELECTRA の事前学習済みモデルのパラメータを読み込む処理を追加したのと、optimizer を ELECTRA 由来のコードに差し替えた以外はそのままです。

def model_fn(features, labels, mode, params):
  """Model function."""
  if labels is None:
    labels = tf.constant([""])

  reader_beam_size = params["reader_beam_size"]
  if mode == tf.estimator.ModeKeys.PREDICT:
    retriever_beam_size = reader_beam_size
  else:
    retriever_beam_size = params["retriever_beam_size"]
  assert reader_beam_size <= retriever_beam_size

  with tf.device("/cpu:0"):
    retriever_outputs = retrieve(
        features=features,
        retriever_beam_size=retriever_beam_size,
        mode=mode,
        params=params)

  with tf.variable_scope("reader"):
    reader_outputs = read(
        features=features,
        retriever_logits=retriever_outputs.logits[:reader_beam_size],
        blocks=retriever_outputs.blocks[:reader_beam_size],
        mode=mode,
        params=params,
        labels=labels)

  predictions = get_predictions(reader_outputs, params)

  if mode == tf.estimator.ModeKeys.PREDICT:
    loss = None
    train_op = None
    eval_metric_ops = None
  else:
    # [retriever_beam_size]
    retriever_correct = orqa_ops.has_answer(
        blocks=retriever_outputs.blocks, answers=labels)

    # [reader_beam_size, num_candidates]
    reader_correct = compute_correct_candidates(
        candidate_starts=reader_outputs.candidate_starts,
        candidate_ends=reader_outputs.candidate_ends,
        gold_starts=reader_outputs.gold_starts,
        gold_ends=reader_outputs.gold_ends)

    eval_metric_ops = compute_eval_metrics(
        labels=labels,
        predictions=predictions,
        retriever_correct=retriever_correct,
        reader_correct=reader_correct)

    # []
    loss = compute_loss(
        retriever_logits=retriever_outputs.logits,
        retriever_correct=retriever_correct,
        reader_logits=reader_outputs.logits,
        reader_correct=reader_correct)

    # Load pre-trained checkpoint 
    init_from_checkpoint(target="bert_q", pattern="^bert_q",
                         init_checkpoint=params["retriever_check_point"])
    init_from_checkpoint(target="emb_q", pattern="^emb_q",
                         init_checkpoint=params["retriever_check_point"])
    init_from_checkpoint(target="reader/bert_r", pattern="^electra",
                         init_checkpoint=params["init_check_point"])
    init_from_checkpoint(target="reader/reader/bert_r", pattern="^electra",
                         init_checkpoint=params["init_check_point"]) 

    train_op = electra_optimization.create_optimizer(
        loss=loss,
        learning_rate=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        warmup_steps=min(10000, max(100, int(params["num_train_steps"]/10))),
        use_tpu=False)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      eval_metric_ops=eval_metric_ops)

get_predictions() は最終的な回答文字列を生成する関数です。" “ を挟んで単語を連結していた部分を ”“ に修正しています。

def get_predictions(reader_outputs, params):
  vocab = tf.convert_to_tensor(list(bert_japanese_tokenizer.vocab.keys()))

  # []
  predicted_block_index = tf.argmax(tf.reduce_max(reader_outputs.logits, 1))
  predicted_candidate = tf.argmax(tf.reduce_max(reader_outputs.logits, 0))

  predicted_block = tf.gather(reader_outputs.blocks, predicted_block_index)
  predicted_orig_block = tf.gather(reader_outputs.orig_blocks,
                                   predicted_block_index)
  predicted_orig_tokens = tf.gather(reader_outputs.orig_tokens,
                                    predicted_block_index)
  predicted_orig_start = tf.gather(
      tf.gather(reader_outputs.candidate_orig_starts, predicted_block_index),
      predicted_candidate)
  predicted_orig_end = tf.gather(
      tf.gather(reader_outputs.candidate_orig_ends, predicted_block_index),
      predicted_candidate)
  predicted_orig_answer = tf.reduce_join(
      predicted_orig_tokens[predicted_orig_start:predicted_orig_end + 1],
      separator="")

  predicted_token_ids = tf.gather(reader_outputs.token_ids,
                                  predicted_block_index)
  predicted_tokens = tf.gather(vocab, predicted_token_ids)
  predicted_start = tf.gather(reader_outputs.candidate_starts,
                              predicted_candidate)
  predicted_end = tf.gather(reader_outputs.candidate_ends, predicted_candidate)
  predicted_normalized_answer = tf.reduce_join(
      predicted_tokens[predicted_start:predicted_end + 1], separator="")

  def _get_final_text(pred_text, orig_text):
    pred_text = six.ensure_text(pred_text, errors="ignore")
    orig_text = six.ensure_text(orig_text, errors="ignore")
    return squad_lib.get_final_text(
        pred_text=pred_text,
        orig_text=orig_text,
        do_lower_case=False)

  predicted_answer = tf.py_func(
      func=_get_final_text,
      inp=[predicted_normalized_answer, predicted_orig_answer],
      Tout=tf.string)

  return dict(
      block_index=predicted_block_index,
      candidate=predicted_candidate,
      block=predicted_block,
      orig_block=predicted_orig_block,
      orig_tokens=predicted_orig_tokens,
      orig_start=predicted_orig_start,
      orig_end=predicted_orig_end,
      answer=predicted_answer)

先ほど組み込んだパラメータ読み込み関数です。

def init_from_checkpoint(target, init_checkpoint, pattern):
  tf.logging.info("Using checkpoint: %s", init_checkpoint)
  tvars = tf.trainable_variables()
  if init_checkpoint:
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tvars, init_checkpoint, target=target, pattern=pattern)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

こちらはアサイメントマップの生成関数。Tensorflow 2.x で動かすための修正とパラメータに ”_1" が付く問題の回避コードを入れました。

def get_assignment_map_from_checkpoint(tvars, init_checkpoint, target, prefix="", pattern="^electra"):
  """Compute the union of the current variables and checkpoint variables."""
  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  initialized_variable_names = {}
  assignment_map = collections.OrderedDict()

  for x in tf.train.list_variables(init_checkpoint):
    (ckpt_name, ckpt_var) = (x[0], x[1])

    m = re.match(pattern, ckpt_name)
    if m is not None:
      # "electra/weight": "bert_q/weight"
      target_variable_name = re.sub(pattern, prefix + target, ckpt_name)
      if target_variable_name in name_to_variable:
        # Workaround for init_from_checkpoint in 2.0
        assignment_map[ckpt_name] = name_to_variable[target_variable_name]
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ":0"] = 1
      # Workaround for extra scope generation("embeddings_1")
      #   caused by tf.keras.layers.LayerNormalization (?)
      ckpt_name_emb1 = ckpt_name.replace("embeddings", "embeddings_1")
      target_variable_name = re.sub(pattern, prefix + target, ckpt_name_emb1)
      if target_variable_name in name_to_variable and ckpt_name not in initialized_variable_names:
        # Workaround for init_from_checkpoint in 2.0
        assignment_map[ckpt_name] = name_to_variable[target_variable_name]
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ":0"] = 1

      #if ckpt_name not in initialized_variable_names and (
      #    not ckpt_name.endswith("adam_m") and not ckpt_name.endswith("adam_v")):
      #  print(ckpt_name + " is ignored.")

  return assignment_map, initialized_variable_names

ようやくですが準備ができたので ORQA の学習を回します。

ORQA の実行

まずはインポートと namedtuple の定義です。

import collections
import six
from official.nlp.data import squad_lib
from language.orqa.utils import scann_utils
from language.orqa.models.orqa_model import span_candidates, mask_to_score, \
  marginal_log_loss, compute_correct_candidates, compute_loss, \
  compute_eval_metrics

RetrieverOutputs = collections.namedtuple("RetrieverOutputs", ["logits", "blocks"])
ReaderOutputs = collections.namedtuple("ReaderOutputs", [
    "logits", "candidate_starts", "candidate_ends", "candidate_orig_starts",
    "candidate_orig_ends", "blocks", "orig_blocks", "orig_tokens", "token_ids",
    "gold_starts", "gold_ends"
])

各種パラメータを dict にして、

params = dict(
  data_root=FLAGS.data_root,
  batch_size=FLAGS.batch_size,
  eval_batch_size=FLAGS.batch_size,
  query_seq_len=FLAGS.query_seq_len,
  block_seq_len=FLAGS.block_seq_len,
  learning_rate=FLAGS.learning_rate,
  num_train_steps=FLAGS.num_train_steps,
  retriever_module_path=FLAGS.retriever_module_path,
  reader_module_path=FLAGS.reader_module_path,
  retriever_beam_size=FLAGS.retriever_beam_size,
  reader_beam_size=FLAGS.reader_beam_size,
  reader_seq_len=FLAGS.reader_seq_len,
  span_hidden_size=FLAGS.span_hidden_size,
  max_span_width=FLAGS.max_span_width,
  block_records_path=FLAGS.block_records_path,
  num_block_records=FLAGS.num_block_records,
  init_check_point=FLAGS.init_check_point,
  retriever_check_point=FLAGS.retriever_check_point,
  bert_config=FLAGS.bert_config,
  projection_size=FLAGS.projection_size)

train/eval_input_fn() を定義します。

train_input_fn = functools.partial(orqa_model.input_fn,
                                    name=FLAGS.dataset_name,
                                    is_train=True)
eval_input_fn = functools.partial(orqa_model.input_fn,
                                  name=FLAGS.dataset_name,
                                   is_train=False)

あとは実行するだけです。

experiment_utils.run_experiment(
    model_fn=model_fn,
    params=params,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    exporters=None,
    params_fname="params.json")
# INFO:tensorflow:Using config: {'_model_dir': './orqa_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
# graph_options {
#   rewrite_options {
#     meta_optimizer_iterations: ONE
#   }
# }
# , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
# INFO:tensorflow:Not using Distribute Coordinator.
# INFO:tensorflow:Running training and evaluation locally (non-distributed).    
# ...

長かったですが、これで学習が完了しました。それではテストデータを使ってクイズを回答させてみましょう。

5. テストデータの評価

それではテストデータで評価してみましょう。ファインチューニングで使ったノートブックを続きから実行しています。

estimator を作り直して、

run_config = tf.estimator.RunConfig(
    model_dir=FLAGS.model_dir,
    tf_random_seed=FLAGS.tf_random_seed,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    keep_checkpoint_max=FLAGS.keep_checkpoint_max)

params["batch_size"] = FLAGS.batch_size

estimator = tf.estimator.Estimator(
    config=run_config,
    model_fn=model_fn,
    params=params,
    model_dir=FLAGS.model_dir)

テストデータを読み込む関数です。

from language.orqa.datasets import orqa_dataset
def test_input_fn(params):
  testset = orqa_dataset.get_dataset(data_root=params["data_root"], name=FLAGS.dataset_name, split="test")
  def _extract_labels(d):
    return d, d.pop("answers")
  testset = testset.map(_extract_labels)
  testset = testset.prefetch(10)
  return testset

それでは、実行します。

estimator.evaluate(input_fn=test_input_fn)
# ...
# INFO:tensorflow:Saving dict for global step 261200: exact_match = 0.31594783, global_step = 261200, loss = 5.097063, official_exact_match = 0.30491474, reader_oracle = 0.44934803, top_1000_match = 0.9568706, top_100_match = 0.81945837, top_10_match = 0.554664, top_5000_match = 0.9879639, top_500_match = 0.92878634, top_50_match = 0.7552658, top_5_match = 0.45937812
# INFO:tensorflow:Saving 'checkpoint_path' summary for global step 261200: ./orqa_model/model.ckpt-261200
# 
# {'exact_match': 0.31594783,
#  'loss': 5.097063,
#  'official_exact_match': 0.30491474,
#  'reader_oracle': 0.44934803,
#  'top_1000_match': 0.9568706,
#  'top_100_match': 0.81945837,
#  'top_10_match': 0.554664,
#  'top_5000_match': 0.9879639,
#  'top_500_match': 0.92878634,
#  'top_50_match': 0.7552658,
#  'top_5_match': 0.45937812,
#  'global_step': 261200}

exact_match が block に含まれる正解スパンと予測のスパンが合致した率、official_exact_match は ORQA が予測したスパンを文字列処理した解答文字列と正解文字列が合致した率ですので、official_exact_match が最終的な正答率です。予測したスパンを文字列にするところは雑に “” で連結しただけなので「スパンの予測は当てたけど、文字列にするところで解答と違っちゃった」みたいなのが、ある程度あったのでしょう。

そんな訳で最終的な正答率は約30.5% になりました。

クイズ王を名乗るには厳しいですが、ORQA の論文に記載されていたスコアも30%台ですので、とりあえず意図したとおりにには動いていそうです。 アレコレと妥協してチューニング等もしていないので、こんなものでしょう。

orqa

厳しいのは retriever の部分のようです。top_5_match(retriever の検索結果 Top 5 に正解が含まれる率) の時点で 0.45 なので、ここをもう少し何とかしたいところです。じつは正にこの部分の改良として REALM 11というのがありますので、機会があれば試してみたいと思います。

正答率だけではつまらないので、いくつかサンプルを実行してみましょう。ランタイムを再起動して、experiment_utils.run_experiment() の直前まで実行しておいてください。

推論を実行する predictor を取得する関数です。ちなみに ORQA には tf.py_func() がバシバシ入っているので SavedModel にするのはムリらしいです。

DATASET_PATH="./jaqket.resplit.test.jsonl"
PRINT_PREDICTION_SAMPLES=True

from language.orqa.models.orqa_model import serving_fn
from language.orqa.utils import eval_utils
from absl import logging

def get_predictor(model_dir):
  with tf.io.gfile.GFile(os.path.join(model_dir, "params.json")) as f:
    params = json.load(f)

  best_checkpoint_pattern = os.path.join(model_dir, "*.index")
  best_checkpoint = tf.io.gfile.glob(
      best_checkpoint_pattern)[0][:-len(".index")]
  serving_input_receiver = serving_fn()
  estimator_spec = model_fn(
      features=serving_input_receiver.features,
      labels=None,
      mode=tf.estimator.ModeKeys.PREDICT,
      params=params)
  question_tensor = serving_input_receiver.receiver_tensors["question"]
  session = tf.train.MonitoredSession(
      session_creator=tf.train.ChiefSessionCreator(
          checkpoint_filename_with_path=best_checkpoint))

  def _predict(question):
    return session.run(
        estimator_spec.predictions, feed_dict={question_tensor: question})

  return _predict

では実行してみましょう。

predictor = get_predictor(FLAGS.model_dir)
example_count = 0
correct_count = 0
with tf.io.gfile.GFile(DATASET_PATH) as dataset_file:
  for i, line in enumerate(dataset_file):
    example = json.loads(line)
    question = example["question"]
    answers = example["answer"]
    predictions = predictor(question)
    predicted_answer = six.ensure_text(predictions["answer"], errors="ignore")
    is_correct = eval_utils.is_correct(
          answers=[six.ensure_text(a) for a in answers],
          prediction=predicted_answer,
          is_regex=False)
    correct_count += int(is_correct)
    example_count += 1
    if PRINT_PREDICTION_SAMPLES and i & (i - 1) == 0:
      logging.info("[%d] '%s' -> '%s'", i, question, predicted_answer)

logging.info("Accuracy: %.4f (%d/%d)", correct_count / float(example_count),
               correct_count, example_count)
# INFO:tensorflow:Graph was finalized.
# INFO:tensorflow:Restoring parameters from ./orqa_model/model.ckpt-261200
# INFO:tensorflow:Running local_init_op.
# INFO:tensorflow:Done running local_init_op.
# 
# INFO:absl:[0] '和名をハダカカメガイといい、実は巻き貝の一種とされている、その姿から「流氷の天使」と呼ばれる動物は何でしょう?' -> 'ムツゴロウ'
# INFO:absl:[1] '作家のルスティケロが、マルコ・ポーロから聞いた話をまとめた作品といえば何でしょう?' -> '東方見聞録'
# INFO:absl:[2] '『騎士団長殺し』『1Q84』『ノルウェイの森』といった小説の作者は誰でしょう?' -> 'エーリヒ・マリア・レマルク'
# INFO:absl:[4] '今から約140億年前に起こったとされる、宇宙の始まりの大爆発のことを何というでしょう?' -> 'ビッグバン'
# INFO:absl:[8] 'ウォーターゲート事件によって辞任に追い込まれた、時のアメリカ大統領は誰でしょう?' -> 'リチャード・ニクソン'
# INFO:absl:[16] '提唱したドイツの統計学者の名に由来する、家計の支出に占める飲食費の割合を何というでしょう?' -> 'エンゲル係数'
# INFO:absl:[32] '一段進めた香車の下に王を入れて、周りを金や銀で守り固める将棋の戦法を、ある動物の名前をとって何というでしょう?' -> 'カメレオン戦法'
# INFO:absl:[64] '秋田の「西馬音内盆踊り」、岐阜の「郡上おどり」と並んで日本三大盆踊りと称される、毎年夏に徳島県で行われる盆踊りといえば何でしょう?' -> '阿波踊り'
# INFO:absl:[128] 'そのタイトルは「ファーストキスの長さ」を意味している、2015年にリリースされたHKT48の5枚目のシングルは何でしょう?' -> '早送りカレンダー'
# INFO:absl:[256] 'ヤクザやピカイチといった言葉の語源となったカードゲームは何でしょう?' -> '花札'
# INFO:absl:[512] 'アゴヒゲ、タテゴト、ゼニガタ、ゴマフなどの種類がある、海に棲むほ乳類は何?' -> 'ナマコ'
# INFO:absl:Accuracy: 0.3049 (304/997)               

どうでしょう?間違えた解答も人名のところには人名、動物のところには動物の名前を回答しているのがいい感じです。「早送りカレンダー」もちゃんと HKT48 の 11枚目のシングルらしいです。センターは矢吹奈子と田中美久(Wikipedia調べ)。

ORQA のコードにはデモもついています。predicor が動くところまできていれば後は簡単なので遊んでみてください。画面はこんな感じです。

orqa

「ファン・ゴッホ」で正解!と思いきや、データセット的には「フィンセント・ファン・ゴッホ」で不正解になってました。。。 しかし、よく見ると block の中には「アルルの跳ね橋」も「ひまわり」も「オランダ」も登場してませんよね。これで「ファン・ゴッホ」を抜いたのはけっこうスゴイ気がしますね。

6. おわりに

今回は 複数の BERT を組み合わせて利用するモデルを試してみようということで、 ORQA を試してみました。文字列処理が多い分だけ修正箇所が増えてしまいましたが、雰囲気をつかんでもらえれば嬉しいです。しかし最近のモデルは大きすぎて12 Colab で動かすのは厳しくなってきましたね。。。

さて次回は今回紹介できなかった ScaNN の話にするか、Huggingface Transformers で T5 を動かす話にするか、はたまた REALM の話にするか。。。どうしよっかなーと思っているところです。


  1. https://arxiv.org/abs/1906.00300 別に Google 大好き人間ではないのですが、 Colab の TPU を使おうとすると Google さんが公開してくれるコードが都合よいので、なんとなくそうなってしまうというか。 

  2. https://rajpurkar.github.io/SQuAD-explorer/ 

  3. https://github.com/google-research/language/blob/dd73e14ae89735c9c05424dd6096550dad273ee2/language/orqa/models/orqa_model.py#L56-L62 

  4. https://github.com/google-research/language/blob/dd73e14ae89735c9c05424dd6096550dad273ee2/language/orqa/preprocessing/wiki_preprocessor.py#L48-L89 

  5. https://arxiv.org/abs/1908.10396 

  6. https://github.com/google-research/google-research/tree/master/scann 

  7. https://github.com/google-research/language/blob/dd73e14ae89735c9c05424dd6096550dad273ee2/language/orqa/datasets/ict_dataset.py#L23-L86 

  8. https://www.nlp.ecei.tohoku.ac.jp/projects/jaqket/ 

  9. 配布元に連絡しておきましたが、コンペしているので、直すといろいろ話がややこしくなっちゃうかもですね。 

  10. tf.py_func() は計算グラフの処理を途中で普通の python 関数を呼び出せる奴ですね。本当は完全に互換の関数を作りたかったのですが、ちょっと厳しかったのでほぼ互換関数にしたような。どの辺が「ほぼ」だったかは忘れてしまいました。 

  11. https://arxiv.org/abs/2002.08909 

  12. https://arxiv.org/abs/2101.03961 とうとうパラメータ数が兆を超えてしまいました。お豆腐屋さんじゃないんだから!ってやつですね。