前回はテキストマイニングの手法と OSS を用いた実践について紹介しました。今回は、Google の T5(Text-to-Text Transfer Transformer) によるテキスト生成について、学習や推論のコード例と実験結果を交えてご紹介します。
1. はじめに
本記事では Google の T5(Text-to-Text Transfer Transformer) 1によるテキスト生成について、学習や推論のコード例と実験結果を交えてご紹介します。実験としては livedoor ニュースコーパス2での文章分類、やさしい日本語コーパス3及びやさしい日本語拡張コーパス4を用いたやさしい日本語変換を行いました。今回も Google Colaboratory で動かすことを想定したコードスニペットを入れていきますので、実際に動かしたり対象を変えてみたりして試して頂けると良いかと思います。 以後、本記事で特に説明なく「論文」と記載があれば、T5の論文1を指すものとします。
(前回予告の極性辞書はまた機会がありましたら、ということで。関心の対象がこっちに移ってしまい。。。)
2. T5 : Text-To-Text Transfer Transformer
Text-to-Text Transfer Transformer(以後単に“T5”) は分類、翻訳、要約といった様々な自然言語処理タスクを “Text-to-Text” で解くモデルです。
“Text-to-Text” とは入力を"タスク:問題"、出力を"回答"の形式として、全てのタスクを同じモデルで解いてしまおうという訳です。以下がそのイメージになります。学習データだけ変えれば同じモデルで様々なタスクが解けるというのは魅力的ですね。
“Transfer” と “Transformer” は第3回の BERT の記事で紹介した、大規模データによる事前学習済みモデルからの転移と Self-Attention を用いた Transformer のことです。
T5 は Transfomer の技術をベースにモデルの構成、事前学習の目的関数、事前学習のデータセット、学習方法、モデルのサイズなど様々パターンについて比較検証し GLUE, SuperGLUE で当時の SOTA を達成しました。
論文中の比較検証におけるベースラインの概要は以下のとおりです。この構成を基準とし各評価軸について条件を変えて実験しています。
- モデルの構成は“Encoder-Decoder”で、Encoder/Decoderともに Transformerスタックのサイズは BERTBASE 相当(パラメータ数:約2.2億)。
- 1バッチあたり 65,536トークンで 524,288 ステップ学習。
- トークン化には SentencePiece を用い、語彙数は 32,000。
- 事前学習の目的関数は “Random spans”。
ここからは各評価軸についてその概要を見ていきましょう。
2.1 モデルの構成
モデルの構成については論文の 3.2 節に記述されており、Encoder-Decoder, Language model, Prefix LM の比較になっています。
Encoder-Decoder
Encoder-Decoder は “Attention is All You Need” 5 で提案されたモデルの構造をほぼ踏襲したものです。以下は“Attention is All You Need"で用いられたモデルの構造です。
縦2列の左側が Encoder、右側が Decoder です。ちなみにこの構造から Encoder 部分を抜き出したのが BERT の構造になります。
上図とT5の構造上の相違点はトークンの位置情報の埋め込みに関する部分です。上図では "Positional Encoding” がトークン位置の関数になっており、この値を入力に加算しています。
しかし、T5 では「先頭からx番目」という絶対的な位置に対する埋め込み表現は使いません。最近の Transformer ベースの研究のトレンドにのっとり Self-Attention の計算において 「Key と Query の相対位置」に対応する埋め込み表現を学習、これを Attention の重みを算出する softmax の直前で Query と Key の内積にバイアスとして加算します。「Key と Query の相対位置」に対する埋め込みは、Multi-Head Attention の Head 毎に異なる埋め込みが使われますが、スタック(上図の縦1列)の全ての層で共有されています。言葉では少々分かりにくいですが、ソースコード的には、6, 7の辺り、"relative_attention_type"
には "bias_shared"
が使われます。
Language model
Language model では、Encoder-Decoderの図の Decoder (右側の列)のみを使います。ステップ i の出力から単語をサンプリングしてステップ i+1 の入力にするという具合に自己再帰的に出力を生成します。OpenAI GPT 8 やフェイクニュースへの悪用を懸念してフルサイズの学習済みモデルを公開しなかったことで話題となった GPT-2 9がこのタイプになります。
Prefix LM
Language model を “Text-to-Text” で使う際の欠点は、先頭から現在位置までのトークンの並びを見て、次のトークンを予測するという問題設定の為、BERT のような双方向の依存性を学習できないことです。Prefix LM は Attention のマスクのかけ方を工夫して入力文に相当する部分(=Prefix)には双方向、出力文に相当する部分には単方向の可視性を持たせたものです。
以下は Encoder-Decoder, Language model, Prefix LM における入出力シーケンスの要素と依存関係のイメージ図です。
図中の各角丸正方形は入出力シーケンスの要素を示しており、線が要素間の依存関係です。色は Transformer スタックの違いです。 スタック毎のパラメータ数を揃えて比較したところ(Language model, Prefix LM のパラメータ数を P とすれば、Encoder-Decoder はスタックが2つあるので 2P)、Encoder-Decoder が最良の性能だったとのことです10。
次は事前学習の目的関数について見ていきましょう。
2.2 事前学習の目的関数
事前学習の目的関数としては、"Prefix language modeling", “Masked language modeling”, “Deshuffling” の3つの手法が検討され、"Masked language modeling" については更に細かいバリエーションが検討されています。 論文の Table.3 に検討された各目的関数の処理のイメージが分かりやすく示されています。
それでは、上から順番に見ていきましょう。
- Prefix language modeling : これは普通の言語モデルというか、文章の出だしの部分が示され、その後にどう続くかを予測するものです。
- BERT-Style : BERT の事前学習ですね。15% のトークンを伏字にして、そのうち 90% を
"<M>"
に、残り10% をランダムなトークン(図中では灰色文字の“apple”)に置き換えて、元の文章を復元します。 - Deshuffling :トークンの順番を並び替えて、元の文章を復元します。
ここまでの 3つで比較すると “BERT-Style” が最良だったとのことで、以下は“BERT-Style"をベースに事前学習の高速/軽量化を狙った工夫になります。
- I.i.d noise, mask tokens : BERT-Style からランダムトークンへの置き換え(灰色文字の"apple”)を廃したものです。
- I.i.d noise, replace spans : 連続した伏字(=伏字スパン)を単一の特殊トークン(
"<X>"
や"<Y>"
)で置き換え、特殊トークンの該当部が何であったかを予測します。 - I.i.d noise, drop tokens : 伏字部分を削除し、削除された部分を予測します。
- Random spans: 単語単位に伏字化すると連続した伏字部分ができにくい為、トークンの何%を伏字にするか、伏字スパンの数を幾つにするかを指定します。例えば、500トークンで伏字化率=15%、伏字スパン数=25であれば伏字スパンの平均長は3となります。
上記について実験した結果、Random spans で伏字化率=15%、伏字スパンの平均長=3 が最良だったとのことです11。
さて、次は事前学習に使うデータセットについてです。
2.3 事前学習のデータセット
この論文のために Google は Colossal Clean Crawled Corpus(以下 “C4”) と名付けられた巨大なデータセットを作りました。
世界中のWebサーバをクロールして収集されたペタバイト級のコーパスとして Common Crawl 12があり、今も毎月(!)、20TBのデータが公開されています。ただ、 Common Crawl はマークアップ等は取り除かれているものの、自然言語でない内容やエラーメッセージ、メニュー、重複テキスト、ソースコード等がある為、Common Crawlの1月分に様々なクリーニング処理を行って作られたのが、C4 です。データ量は 745GB あり、Wikipedia 英語版の 46倍のサイズです。
論文の Table.8 に C4 を含めた6つのデータセットの比較結果が示されています。
比較対象とされたデータセットは以下のような内容です(簡単のためだいぶざっくり書いてます)。
- C4 : Common Crawl に様々なクリーニング処理13を施して作られたデータセットです。
- C4, unfiltered : C4 から"英語"以外のフィルタリング処理を排除したもの。
- RealNews-like : C4 にニュース記事のコンテンツのみ抽出する処理を追加したもの。
- WebText-like : Common Crawlの12か月分に C4 と同様のクリーニング処理を施し Reddit で 3 以上の評価を受けたもののみ抽出したもの。
- Wikipedia : Tensorflow Datasets の Wikipedia英語版のデータです。
- Wikipedia+TBC : Wikipedia では内容のドメインが百科事典限定になるので、 Tronto Books Corpus(TBC) の様々な電子書籍のデータを加えたものです。
ぱっと見ると、「 C4 すごいな」という印象はないですが、論文では以下の点が指摘されています。
- “C4, unfiltered” の結果を見るとデータの質が結果に大きく影響することが見て取れる。
- “Wikipedia+TBC"、"RealNews-like”, “Wikipedia"の結果からは後続タスクの対象ドメインに合ったデータセットによる事前学習で精度が向上すると言える14。
また、表に "Size” があるので誤解しそうですが、この比較は事前学習における学習トークン数を 235 ≒ 350億トークンで統一しています。この為、事前学習過程において、"C4", “C4, unfiltered” ではデータセットを1周回りきっていませんし、"Wikipedia" では複数周回っていることになる点についてご注意ください。
しかし、日本語で学習することを考えると必要なデータ量が気になるところです。事前学習のデータ量と精度の関係については、論文の Table.9 に記載があります。
これらは異なるサイズのデータセットで 235 トークン学習した結果です。学習トークン数が固定なのでデータ量が小さくなると学習中の周回数が増えていきます。事前学習中の周回数が増えると精度が低下していくわけですが、229≒5.4億トークンくらいまでは、あまり大きな低下はないようです。
筆者が Tensorflow Datasets の Wikipedia 日本語版のデータを 語彙数 32,000 の SentencePiece モデルを使ってトークン数を数えたところ約10億トークンになりました。これなら Wikipedia 日本語版を使った事前学習でもそこそこのモデルが作れそうです。
次は学習方法についてです。
2.4 学習方法
論文の 3.5 ではファインチューニングとマルチタスク学習について比較検討がなされています。
ファインチューニング
ファインチューニングの方法については以下のパターンが比較検討されていますが、結果としては “All parameters” が最良であったとのことです15。
- All parameters : ファインチューニング時に全パラメータを更新します。
- Adapter layers : Transformer の各ブロック末端に Adapter layer(dense-ReLU-dense のブロック)を挿入し、ファインチューニング時は Adapter layer と layer normalization のパラメータのみ更新します。Adapter layer の dense の次元数は複数比較されています。
- Gradual unfreezing : ファインチューニング開始時はスタックの最終層のパラメータのみ更新、学習が進むに合わせ徐々にパラメータ更新する層を先頭に向かって拡大し最終的には全パラメータを更新します。
マルチタスク学習
マルチタスク学習は複数のタスクを同時に学習させ、単一のモデルが同時に複数タスクを解けるようにする手法です。T5 の場合は全てのタスクが “Text-to-Text” 形式なので、単に複数タスクの学習データをどういう割合で混ぜるか?という話になります。論文では以下の 3つの戦略が比較されています。結果としてはどの戦略もファインチューニングには及ばなかったとのことです16。
- Examples-proportional mixing : 各タスクのデータセットサイズに比例した確率で学習データをサンプリングします。ただし、極端にデータ数が多いタスク(=事前学習タスク)の影響を抑える為にリミットを設定します。リミットのパラメータは複数比較されています。
- Temperature-scaled mixing : 各タスクの混合比を各タスクのサンプル数を 1/T 乗した上で、全タスクの合計が 1 になるように正規化します。T = 1 なら “Examples-proportional mixing” と等価、T を増加するに従い、"Equal mixing" に近づきます。T の値は複数比較されています。
- Equal mixing : 各タスクから同じ確率で学習データをサンプリングします。
また、マルチタスクでの事前学習の後に各タスク毎にファインチューニングする戦略も試していますが、こちらも事前学習+ファインチューニングに及ばなかったようです。
最後の評価軸はモデルのサイズです。
2.5 モデルのサイズ
モデルの構成としてベースラインとベースラインの約4倍の計算量になる以下のパータンを比較しています。
- 1 x size, 4 x training step : ベースラインのモデルでステップ数を4倍にして学習。
- 1 x size, 4 x batch size : ベースラインのモデルでバッチサイズを4倍にして学習。
- 2 x size, 2 x training step : ベースラインの2倍のパラメータ数のモデルでステップ数を2倍にして学習。
- 4 x size, 1 x training step : ベースラインの4倍のパラメータ数のモデルでステップ数はそのままで学習。
- 4x emsembled : ベースラインのモデルを4つ個別に事前学習+ファインチューニングしてアンサンブル。
- 4x emsembled, fine-tune only : ベースラインのモデルで共通の事前学習済みモデルから4つ個別にファインチューニングしてアンサンブル。
2倍、4倍サイズのモデルは Transformer スタックの各ブロックの構成を BERTLARGE相当にして、層数を 16層、32層にしたものです。また、学習ステップ数を増やしている場合は、事前学習時にモデルが見るデータのバリエーションが増えることに注意してください(ベースラインの 235 トークンの学習では C4 データセットの一部しか利用されていない為)。結果は論文の Table. 13 です。
どのパターンもそれなりにベースラインから精度が向上していますが、やはりモデルサイズを大きくすることが比較的効いているようです。モデルサイズを大きくすると、その分必要な学習データ量も増えるのではないかと思いましたが、"Baseline" と “4 x size, 1 x training step” の比較ではモデルサイズを4倍に拡大し、データ量は据え置きで性能向上していますね。ただし Wikipedia 日本語版あたりで事前学習すると、ベースラインの 235 トークンの学習過程でデータセット全体を35周くらいすることになるので、話は違ってくるかもしれません。
2.6 ここまでのまとめ
論文ではここまでの検証を踏まえて、パラメータ数 110億のモデルで SOTA に向けた実験に突入していきます。 T5 は何かしら新しいモデルの構造や手法の発明というよりは、Transformer の技術をベースに最新のトレンドや目的関数や学習の工夫を組み合わせたものであるということが分かってもらえたかと思います。
ここからは前述の知見を踏まえ、以下の方針で日本語データを使った検証をしていきたいと思います。
- モデルの構成は Encoder-Decoder。
- 事前学習の目的関数は Random spans。
- 学習方法は事前学習+ファインチューニング。
- 事前学習のデータセットは Wikipedia日本語版。
- モデルのサイズは事前学習のデータサイズも考慮してベースライン相当。
3. 日本語データを使った検証
ようやくですが、ここから日本語データを使った検証をしていきます。
T5 のソースは github 17 で公開されています(以降、"T5" と記述すれば論文、"t5" と記述すればソースコードやライブラリを指すものとします。)。ライブラリ的な使い方ができるようになっており、論文中の実験を再現できる各種設定ファイルも同梱されている為、基本的な扱い方を覚えてしまえば、BERT のコード 18よりも取り回しが良さそうな雰囲気です。
また、記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定です。
3.1 環境のセットアップ
まずは、"ランタイム“ -> "ランタイムタイプの変更” で “ハードウェアアクセラレータ” を “TPU” にしてください。せっかくなので Colab の Free TPU を利用する方法を紹介します。
そろそろ Colab の Tensorflow のバージョンが変わりそうなので、マジックコマンドを実行しておきます。
%tensorflow_version 1.x
記事公開後に発生した各ライブラリの更新により、記事の内容が実行できなくなっているとのコメント頂きました。 最新版準拠で書き直すのは厳しいのですが、ライブラリを次のようにバージョン指定することで、実行可能ですのでお試しください。( 2020/6/4 追記 )
!git clone https://github.com/google-research/text-to-text-transfer-transformer !cd text-to-text-transfer-transformer && git checkout bf46737 !pip install t5[gcp]==0.1.7 !pip install mesh-tensorflow==0.1.8 !pip install tensorflow-datasets==1.3.2 !pip install tensorflow==1.15.0
t5 をインストールします。今回の実験は以下のバージョンを使用してます。
!pip install t5[gcp] # ... !pip list | grep -e "t5" -e "tensorflow " -e "tensorflow-text" # mesh-tensorflow 0.1.8 # t5 0.1.7 # tensorflow 1.15.0 # tensorflow-text 1.15.0rc0
ソースもクローンしておきます。
!git clone https://github.com/google-research/text-to-text-transfer-transformer
GCS にチェックポイントを書き出すので認証を行います。
from google.colab import auth auth.authenticate_user()
3.2 SentencePiece モデルの学習
t5 は文章のトークナイズに SentencePiece を使います。今回は Wikipedia 日本語版で事前学習をするので、このデータを使って SentencePiece のモデルを作ります。
※この節は最後に Colab だとメモリが足りないんですよね、というオチがあるので注意して下さい。
まずは必要なツールをインストールします。
!git clone https://github.com/google/sentencepiece !apt-get install cmake build-essential pkg-config libgoogle-perftools-dev
SentencePiece をビルドします。
%%bash cd ./sentencepiece mkdir build cd build cmake .. make -j 8 make install ldconfig -v
次に Tensorflow Datasets から Wikipedia 日本語版のデータをテキストファイルに出力します。
ちなみに筆者が実験したときは try_gcs=True
がないとエラーになりました。Wikipedia 側の変更が原因だったようですが、このオプションを使うと Google が GCS 上に用意したキャッシュを参照するので、この問題を回避することができます。
import tensorflow_datasets as tfds import tensorflow as tf ds = tfds.load(name='wikipedia/20190301.ja', shuffle_files=True, download=True, try_gcs=True) train_ds = ds["train"].batch(128).prefetch(10) all_titles = [] all_texts = [] for example in tfds.as_numpy(train_ds): titles, texts = example["title"], example["text"] for title, text in zip(titles, texts): all_titles.append(title.decode('utf-8')) all_texts.append(text.decode('utf-8')) with open("input.txt", "w") as f: for text in all_texts: lines = [line.strip() for line in text.split("\n")] for line in lines: if len(line) == 0: continue f.write(line + "\n")
SentencePiece の学習への入力データは一行一文なので以下のように加工します。bert-japanese の加工を参考にしました19。
%%bash cat << EOF > preprocess.sh #!/bin/bash FILE=\$1 if [ \$# -ne 1 ]; then echo "Usage: ./preprocess.sh INPUT_TEXT" exit 1 fi echo "Processing \${FILE}" sed -i -e '/^$/d; /<doc id/,+1d; s/<\/doc>//g' \${FILE} sed -i -e 's/ *$//g; s/。\([^」|)|)|"]\)/。\n\1/g; s/^[ ]*//g' \${FILE} sed -i -e '/^。/d' \${FILE} sed -i -e 's/\(.*\)/\L\1/' \${FILE} EOF chmod 744 preprocess.sh ./preprocess.sh input.txt
ようやく SentencePiece モデルの学習です。t5 のソースコード20に以下の記述がありました。
Assumes the model was built using flags in build_sentencepiece_model.sh, which reserve ID=0 is for padding, ID=1 for EOS, and ID=2 for UNK.
build_sentencepiece_model.sh
は謎のままなのですが、コメントを信じて以下のように学習をします。
!/usr/local/bin/spm_train --input=./input.txt --model_prefix="wikipedia_20190301_ja_v003" --vocab_size=32000 --character_coverage=0.9995 --model_type=unigram \ --pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1
ただし、このコードはメモリ不足でエラーになるので --input_sentence_size=<size> --shuffle_input_sentence=true
を設定して入力データを絞って我慢するか(Colab のメモリ12GB のインスタンスであれば、--input_sentence_size=12000000
くらい)、メモリが潤沢な環境で実行するかして下さい。58GBくらいあれば大丈夫そうでした。
モデルができたら GCS に置いておきましょう。
!gsutil cp ./wikipedia_20190301_ja_v003* gs://somewhere/t5/sentencepiece/
3.3 事前学習の実行
それでは Wikipedia 日本語版のデータを使って事前学習をしてみましょう。
Task に Wikipedia 日本語版を登録する
t5 には Task
という、データセット、前処理、使用する SentencePiece のモデル、精度評価のメトリクス関数等をまとめた概念があります。Task
の上位概念に Mixture
があり、これは複数の Task
をまとめたものです。
t5 で学習をするときは Task
なり Mixture
なりを指定して実行する形式なので、まずは Wikipedia 日本語版に対応する Task
を作り、t5 から使えるようにしなければなりません。
SPM_PATH = "gs://somewhere/t5/sentencepiece/wikipedia_20190301_ja_v003.model" import functools from t5.data import preprocessors from t5.data.utils import TaskRegistry from t5.data.utils import MixtureRegistry from t5.data.utils import TfdsTask task_name_wikipedia_ja = "wikipedia_20190301.ja_v003_unsupervised" TaskRegistry.add( task_name_wikipedia_ja, TfdsTask, tfds_name="wikipedia/20190301.ja:0.0.3", text_preprocessor=functools.partial( preprocessors.rekey, key_map={"inputs": None, "targets": "text"}), token_preprocessor=preprocessors.unsupervised, sentencepiece_model_path=SPM_PATH, metric_fns=[]) MixtureRegistry.add(task_name_wikipedia_ja, [(task_name_wikipedia_ja, 1.0)])
TPU アドレスの確認
Free TPU で事前学習を回すために TPU のアドレスを確認します。
import os import pprint import json import tensorflow as tf assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!' TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR'] print('TPU address is', TPU_ADDRESS) #TPU address is grpc://xx.xx.xx.xxx:xxxx
事前学習の実行
事前学習は数時間では終わりません。Colab は(当たり前ですが)長時間の計算を実行するユーザよりインタラクティブな利用をしているユーザを優先して計算資源を割り当てます。あまり無茶な使い方をすると計算資源の割り当てを拒否されてしまいますし、みんなの Colab ですから、節度ある利用をしましょう!
以下のコードはほぼ t5.models.mesh_transformer_main
からの転記になります(長くなるのでコメント部分や改行を省略しています)。筆者が試した限りでは、Colab の カーネルの中から実行しないと Free TPU に接続できなかったので、このようにしています。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import importlib import os import sys from absl import app from absl import flags import gin from mesh_tensorflow.transformer import utils import pkg_resources import t5 import tensorflow.compat.v1 as tf flags.DEFINE_string("tpu_job_name", None, "Name of TPU worker binary.") flags.DEFINE_string("model_dir", "/tmp/transformer_standalone", "Estimator model_dir") flags.DEFINE_string("tpu", None, "The Cloud TPU to use for training.") flags.DEFINE_string("gcp_project", None, "Project name for the Cloud TPU-enabled project.") flags.DEFINE_string("tpu_zone", None, "GCE zone where the Cloud TPU is located in.") flags.DEFINE_multi_string("module_import", None, "Modules to import.") flags.DEFINE_string("t5_tfds_data_dir", None, "used to store datasets prepared by TensorFlow Datasets") flags.DEFINE_list("additional_task_cache_dirs", [], "Directories to search for Tasks") FLAGS = flags.FLAGS def main(_): if FLAGS.module_import: for module in FLAGS.module_import: importlib.import_module(module) if FLAGS.t5_tfds_data_dir: t5.data.set_tfds_data_dir_override(FLAGS.tfds_data_dir) t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs) gin.add_config_file_search_path(pkg_resources.resource_filename(__name__, "gin")) tf.io.gfile.makedirs(FLAGS.model_dir) suffix = 0 command_filename = os.path.join(FLAGS.model_dir, "command") while tf.io.gfile.exists(command_filename): suffix += 1 command_filename = os.path.join(FLAGS.model_dir, "command.{}".format(suffix)) with tf.io.gfile.GFile(command_filename, "w") as f: f.write(" ".join(sys.argv)) utils.parse_gin_defaults_and_flags() utils.run(tpu_job_name=FLAGS.tpu_job_name, tpu=FLAGS.tpu, gcp_project=FLAGS.gcp_project, tpu_zone=FLAGS.tpu_zone, model_dir=FLAGS.model_dir)
FLAGS
変数には「ちゃんとパースしましたよ」とフラグを立ててあげて、
FLAGS.mark_as_parsed()
次に本来は t5_mesh_transformer
コマンドのコマンドライン引数に指定する内容を直接 FLAGS
、tf.flags.FLAGS
に設定して、事前学習を開始します。Colab のインスタンスは時間制限があるので、checkpoint の出力先(model_dir
)は GCS にしておきます。
FLAGS.tpu = TPU_ADDRESS FLAGS.model_dir = 'gs://somewhere/t5/wikipedia_20190301.ja_v003/model' tf.flags.FLAGS.gin_location_prefix=["./text-to-text-transfer-transformer/t5/models/gin"] tf.flags.FLAGS.gin_file=[ "dataset.gin", "models/bi_v1.gin", "objectives/span_3_15_u_u.gin", "learning_rate_schedules/rsqrt_no_ramp_down.gin"] tf.flags.FLAGS.gin_param=[ "utils.tpu_mesh_shape.model_parallelism = 1", "utils.tpu_mesh_shape.tpu_topology = '2x2'", "run.batch_size = ('tokens_per_batch', 65536)", "run.train_steps = 524288", "MIXTURE_NAME = 'wikipedia_20190301.ja_v003_unsupervised'" ] tf.disable_v2_behavior() tf.logging.set_verbosity(tf.logging.INFO) main([])
事前学習の構成
前述のコードで設定した内容は以下のような意味合いです。
FLAGS.tpu
に TPUアドレスを指定。FLAGS.model_dir
に checkpoint の出力先を指定。-
tf.flags.FLAGS.gin_location_prefix
に gin ファイルのサーチパスにローカルにクローンしたソースのパスを追加21 。 dataset.gin
は学習用データセット関数や使用する SentencePiece モデル等の下準備。- モデルの構成はスタックが BERTBASE 相当の Encoder-Decoder (
bi_v1_gin
)。 - 目的関数はトークンの15% をノイズに置き換え。ノイズの平均スパン長は3トークン (
span_3_15_u_u.gin
) - 学習レートは論文 3.1.2 節に従い、 イテレーション数 n の関数で 1 / sqrt(max(n,k)), k=104 とする。(
rsqrt_no_ramp_down.gin
)。 - バッチサイズ、学習ステップ数は論文の 3.1.2節 に従いそれぞれ、 216=65536 トークン、219=524288 ステップとする(
run.batch_size
,run.train_steps
)。 - 事前学習のタスクは先ほど準備した Wikipedia日本語版を指定(
MIXTURE_NAME
)。
また、学習時の並行処理に関する設定は以下の通りです。t5 の Transformer は Mesh Tensorflow 22 で実装されている為、個々の実行環境に合わせた複数TPU/GPUでの並行処理が非常に容易になっています。
- Free TPU は 8コアなので、"2x2"(8コア)を指定する23 。(
utils.tpu_mesh_shape.tpu_topology
)。 - モデル並行は 1 を設定。8コアでモデル並行が 1 なので、データ並行数 8 での実行となる(
utils.tpu_mesh_shape.model_parallelism
)。
gin 24 についても補足しておきます。gin は Python 向けのコンフィグレーションフレームワークで、t5 や Mesh Tensorflow で利用されています。Python のクラスや関数に @gin.configurable
アノテーションを付け、--gin_param
や --gin_file
でアノテーションを付与したクラスのコンストラクタや関数のパラメータに値を設定します。--gin_param
は単一パラメータの設定、--gin_file
は複数設定を記述したファイル単位での設定になります。
事前学習が終わったら次はファインチューニングです。
3.4 ファインチューニング(文章分類)
ここからは livedoor ニュースコーパスを用いて文章分類をしてみます。事前学習は TPU で動かしたので、今度は GPU でやってみましょう。Colab を使用している場合は、新しいノートブックを開いてください。3.1 環境のセットアップに示した手順で再度セットアップします。ただしランタイムアクセラレータには GPU を選んでください。
GPU でファインチューニングする際もライブラリを次のようにバージョン指定することで、本記事の内容を実行可能です。( 2020/6/4 追記 )
!pip install t5[gcp]==0.3.0 !pip install tensorflow-gpu==1.15.3
データは以下のような形式の TSV にします。また半角英字は全て小文字化しました。ラベルは ['dokujo-tsushin', 'it-life-hack', 'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax', 'sports-watch', 'topic-news']
のインデックスです。
!head -10 t5_train.tsv #大島優子がここからどう破滅していくのか? 『闇金ウシジマくん...省略... 4 #インタビュー:クリスチャン・ベール「演じることができるのは役...省略... 4 #ブラックマジックデザイン、hyperdeck ssd レコーダーに タイム ...省略... 2 #...
学習/検証/テストの分割は以下のとおりです(この分割はこの連載で何回か使用したものと同一なのですが、もはや秘伝のタレ状態でどうやって分割したのか本人も覚えていません)。
!wc -l t5_*.tsv. # 1472 t5_dev.tsv # 1472 t5_test.tsv # 4420 t5_train.tsv # 7364 total
livedoor ニュースコーパスの TextLineTask の作成
t5 にはテキストファイルを学習に使う為のクラスとして TextLineTask
が用意されていますので、こちらのクラスを使って Task
を作成します。クラス数 9 の文章分類になるので、メトリクスとして mean_multiclass_f1(num_classes=9)
を指定しています。また、Free TPU を使用しない場合は、Colab のカーネルで動かす必要がないので、ファイルに出力しておきます。
%%bash cat <<EOF > add_ldcc.py import functools from t5.evaluation import metrics from t5.data import preprocessors from t5.data.utils import TaskRegistry from t5.data.utils import TextLineTask task_name_ldcc = "ldcc" ldcc_tsv_path = { "train": "./t5_train.tsv", "validation": "./t5_dev.tsv", "test": "./t5_test.tsv", } TaskRegistry.add( task_name_ldcc, TextLineTask, split_to_filepattern=ldcc_tsv_path, text_preprocessor=[ functools.partial( preprocessors.parse_tsv, field_names=["inputs", "targets"]), ], sentencepiece_model_path="gs://somewhere/t5/sentencepiece/wikipedia_20190301_ja_v003.model", metric_fns=[metrics.mean_multiclass_f1(num_classes=9)]) EOF
ファインチューニングの実行
ファインチューニングを行うフォルダに事前学習のチェックポイントと operative_config.gin
をコピーしておきます。
operative_config.gin
は事前学習時に指定した gin の設定が全て出力されたファイルで、このファイルを読み込むことで設定内容を引き継ぐことができます。
!gsutil cp gs://somewhere//t5/wikipedia_20190301.ja_v003/model/checkpoint gs://somewhere//t5/wikipedia_20190301.ja_v003/ldcc/ !gsutil cp gs://somewhere//t5/wikipedia_20190301.ja_v003/model/model.ckpt-524288* gs://somewhere//t5/wikipedia_20190301.ja_v003/ldcc/ !gsutil cp gs://somewhere//t5/wikipedia_20190301.ja_v003/model/operative_config.gin gs://somewhere//t5/wikipedia_20190301.ja_v003/ldcc/
以下のようにしてファインチューニングを実行します。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ PRE_TRAINED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/model' && \ OPERATIVE_CONFIG=$PRE_TRAINED_MODEL_DIR'/operative_config.gin' && \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/ldcc' && \ FINE_TUNING_BATCH_SIZE=`expr 512 \* 8` && \ PRE_TRAINGING_STEPS=524288 && \ FINE_TUNING_STEPS=`expr $PRE_TRAINGING_STEPS + 14000` && \ TARGET_SEQ_LEN=3 &&\ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ echo "FINE_TUNING_BATCH_SIZE=$FINE_TUNING_BATCH_SIZE" &&\ echo "PRE_TRAINGING_STEPS=$PRE_TRAINGING_STEPS" &&\ echo "FINE_TUNING_STEPS=$FINE_TUNING_STEPS" && \ echo "TARGET_SEQ_LEN=$TARGET_SEQ_LEN" && \ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_ldcc" \ --gin_file="dataset.gin" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'ldcc'" \ --gin_file="learning_rate_schedules/constant_0_001.gin" \ --gin_param="run.train_steps=$FINE_TUNING_STEPS" \ --gin_param="run.sequence_length = {'inputs': 512, 'targets': $TARGET_SEQ_LEN}" \ --gin_param="run.save_checkpoints_steps=1000" \ --gin_param="run.batch_size=('tokens_per_batch', $FINE_TUNING_BATCH_SIZE)"
いくつかポイントになる部分を補足しておきます。
--module_import
で先ほど作成したadd_ldcc.py
をインポートして、t5 がタスクを見つけられるようにする。- Colab の GPU は 1コアなので
run.layout_rules
,run.mesh_shape
に""
を指定。これで並行実行を指定しない設定になる。 utils.get_variable_dtype.activation_dtype
はデフォルトで TPU 用のbfloat16
なのでfloat32
に切り替え。MIXTURE_NAME
はadd_ldcc.py
で付けた名前(ldcc
)を指定。- 学習レートは論文 3.1.2 節に従い、 0.001 とする(
constant_0_001.gin
)。 - 1000ステップ毎にチェックポイントを出力(
run.save_checkpoints_steps
) - 出力は “0”, “1”,…“8” のラベルとするが、SentencePiece は先頭に特殊トークン(“?”)を付けるときがあること、シーケンス終了トークン(“
</s>
”)を考慮して出力シーケンス長を 3 とした(run.sequence_length
)。
ファインチューニング結果の検証
検証は以下のように行います。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/ldcc' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_ldcc" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="eval.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'ldcc'" \ --gin_param="run.dataset_split='validation'" \ --gin_param="eval_checkpoint_step = 'all'" 2>&1 | tee eval.log
ファインチューニングの学習時と異なるポイントを押さえておくと、
- ファインチューニング時の
operative_config.gin
を読み込み。 eval.gin
を指定。add_ldcc.py
で指定したメトリクス(mean_multiclass_f1(num_classes=9)
)で精度がログに出力される。- 出力のサンプリング方法にビームサーチ25を使用 (
beam_search.gin
)。 - データセットに
validation
セットを指定、add_ldcc.py
で記述した内容に対応する。(run.dataset_split='validation'
) - 指定フォルダの全てのチェックポイントに対して検証を行う(
eval_checkpoint_step = 'all'
)。
出力結果は以下の通りで、ステップ 536288 が最良でした。
!cat eval.log | grep -e "^INFO.*f1" #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 524288: 0.000 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 525288: 95.442 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 526288: 97.015 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 527288: 97.210 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 528288: 97.044 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 529288: 97.231 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 530288: 97.382 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 531288: 97.115 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 532288: 97.169 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 533288: 97.207 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 534288: 97.188 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 535288: 97.187 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 536288: 97.487 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 537288: 96.803 #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 538288: 97.056
テストデータでの精度確認
以下のようにしてテストデータで精度を確認します。run.dataset_split='test'
と eval_checkpoint_step = 536288
以外は先ほどと同じです。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/ldcc' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_ldcc" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="eval.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'ldcc'" \ --gin_param="run.dataset_split='test'" \ --gin_param="eval_checkpoint_step = 536288" 2>&1 | tee test.log
出力は以下のとおりです。
!cat test.log | grep -e "^INFO.*f1" #INFO:tensorflow:eval/ldcc/mean_9class_f1 at step 536288: 95.508
過去の連載での livedoor ニュースコーパスの分類結果と比較してみましょう。%単位で丸めてしまうと bert-japanese と同じになりました。どちらも encoder は BERTBASE 相当サイズのモデルを Wikipedia 日本語版で事前学習しているのですが、"Text-to-Text" 形式だからと言って精度に大きな劣化がないことが確認できました(decoder があるので T5 のパラメータ数は約2倍ですが)。
推論モードでの実行
最後に推論モードで動かしてみましょう。eval.gin
の代わりに infer.gin
を指定します。input_filename
に入力テキストファイルを指定すると、変換結果を output_filename
で指定したテキストファイル(末尾に“-ステップ番号"が付きます)に出力してくれます。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/ldcc' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_ldcc" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="infer.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="infer_checkpoint_step = 536288" \ --gin_param="input_filename = './infer_inputs.txt'" \ --gin_param="output_filename = './infer_ouputs.txt'" 2>&1 | tee infer.log
入力はこんな感じで、
!head -3 ./infer_inputs.txt #新記録でロンドンに乗り込む“バタフライの女王”加藤ゆか3日に行われた競泳の日... #家電チャンネルの記事も配信!向かうところ敵なしのスマホアプリ「itニュース b... #彼にあげたい韓国メンズコスメ、韓流俳優のような美肌男へ!年末の大イベント、... #独女と上司の気になる関係人事異動の多い春は、職場の人間関係の悩みも増える時...
こちらが変換結果です。
!cat infer_ouputs.txt-536288 #7 #2 #5 #4 #0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
"0” と判定した場合に “0 0 0 0 …” と際限なく出力するようになってしまいました。"0" は “10000” とか連続する傾向があるので、その影響が出たのかもしれませんね。素直にラベルを “dokujo-tushin” とかにしておいた方が良かったかもしれません。
ここからは、検証の本題であるテキスト生成にトライしてみましょう。
3.5 ファインチューニング(テキスト生成)
ここからは やさしい日本語(拡張)データセットを使ってテキスト生成の実験をしてみましょう。Colab を使用している場合は、新しいノートブックを開いてください。3.1 環境のセットアップに示した手順で再度セットアップ、ランタイムアクセラレータは GPU です。
さらに以下のパッケージをインストールして下さい。精度の計算に MeCab と chakki-works さんの sumeval 26 を一部利用します。
!apt-get install mecab mecab-ipadic-utf8 !pip install mecab-python3 sumeval
データの加工に使うので以下もインストールします。
!apt-get install nkf !pip install python-Levenshtein
やさしい日本語(拡張)コーパス
データセットとしては、長岡技術科学大学 自然言語処理研究室さんの「SNOW T15:やさしい日本語コーパス」、「SNOW T23:やさしい日本語拡張コーパス」を使用させて頂きました。この二つのコーパスは通常の日本語をやさしい日本語語彙(2000語)で書き換えた対訳コーパスです。筆者の知る限りでは日本語対日本語の対訳コーパスとして類似のものがなく非常にありがたいです。
以下のようにしてダウンロードします。 ※年明けに修正版が公開され、URLとファイル名が変わっています。適宜読み替えて下さい。
!curl -L -o T15.xlsx "https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T15-2018.2.27.xlsx" # 現在はこちらから : https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T15-2020.1.7.xlsx !curl -L -o T23.xlsx "https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T23-2019.7.2.xlsx" # 現在はこちらから : https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T23-2020.1.7.xlsx
まずは、やさしい日本語コーパスを pandas で読み込みます。
import pandas as pd snow_t15 = pd.read_excel('T15.xlsx') snow_t15 = snow_t15.rename(columns={'#日本語(原文)': 'input', '#やさしい日本語':'target'})[['ID', 'input','target']]
このコーパス、ありがたいのですが以下のような変換も結構混じっています。
確かに「『ピザ』を使わずに言い換えろ」と言われると難しいのですが、何というか概念の説明になっている感じがしたので今回の実験ではできるだけ除外して試すことにしました。ただ二つのコーパスを合わせて8万件以上を目検でチェックも出来ないので、原文とやさしい日本語のレーベンシュタイン距離27が 10 以上のものを除外しています。また、全角英数字は半角に直して小文字化してます。
import Levenshtein def levenshtein_distance(row): return Levenshtein.distance(row["target"], row["input"]) snow_t15['levenshtein_distance'] = snow_t15.apply(levenshtein_distance, axis=1) snow_t15 = snow_t15.query('levenshtein_distance < 10')[['input', 'target']] snow_t15.to_csv("temp.tsv", sep="\t", header=False, index=False) !cat temp.tsv | nkf -m0Z1 | tr "[:upper:]" "[:lower:]" > snow_t15.tsv
やさしい日本語拡張コーパスも同様に処理します。
snow_t23 = pd.read_excel('T23.xlsx') snow_t23 = snow_t23.rename(columns={'#日本語(原文)': 'input', '#やさしい日本語':'target'})[['ID', 'input','target']] snow_t23['levenshtein_distance'] = snow_t23.apply(levenshtein_distance, axis=1) snow_t23 = snow_t23.query('levenshtein_distance < 10')[['input', 'target']] snow_t23.to_csv("temp.tsv", sep="\t", header=False, index=False) !cat temp.tsv | nkf -m0Z1 | tr "[:upper:]" "[:lower:]" > snow_t23.tsv
二つのファイルを結合して 8:1:1 に分割します。
!cat snow_t15.tsv snow_t23.tsv > snow.tsv with open("snow.tsv", "r") as f: lines = f.readlines() lines = [line.strip() for line in lines] from sklearn.model_selection import train_test_split train, dev_test = train_test_split(lines, train_size=0.8, random_state=4) dev, test = train_test_split(dev_test, train_size=0.5, random_state=7) with open("snow_t15_23_train.tsv", "w") as f: f.write("\n".join(train)+"\n") with open("snow_t15_23_dev.tsv", "w") as f: f.write("\n".join(dev)+"\n") with open("snow_t15_23_test.tsv", "w") as f: f.write("\n".join(test)+"\n")
各ファイルの行数は以下の通りです。
!wc -l snow_t15_23_*.tsv # 7040 snow_t15_23_dev.tsv # 7040 snow_t15_23_test.tsv # 56317 snow_t15_23_train.tsv # 70397 total
中身はこんな感じです。
!head -5 snow_t15_23_train.tsv # では今晩またね、さようなら。 では今日の夜にまたね、さようなら。 # 夜のこんな時間に電話をかけるものではない。 夜のこんな時間に電話をかけるものではない。 # 愛とは夢にまで彼女を見ることだ。 愛とは夢にまで彼女を見ることだ。 # そのリンゴの木はよく実がなる。 そのりんごの木はよく実がなる。 # この馬は手に負えない。 この馬は扱うことができない
中身を見るとわかりますが、原文とやさしい日本語が同一のサンプルが結構混じっています。加工後のファイル(snow_t15_23_*.tsv
)ですと、全体の 30.6 %程が素通しです。どうしようかと思ったのですが、「この単語はやさしい日本語としてOKですよ」と教える意味があるだろうと判断し、そのままにしてあります。
それでは Task
を作ってファインチューニング!と行きたいところですが、やさしい日本語変換の精度評価についておさえておきましょう。
BLEU スコア
文章分類なら正答率なり F1 スコアなりで良いのですが、テキスト生成では BLEU スコアと呼ばれる指標を使うことが一般的なようです。
BLEU は元々機械翻訳の品質を自動評価するために考案された手法で、機械翻訳の出力がお手本となる人間が作成した参照訳にどの程度近いかを 0 ~ 1 のスコアで表現します。数式で示すと以下のとおりです。
- micand は、参照訳と一致する候補訳内の i グラムのカウントです。
- miref は、参照訳内の i グラムのカウントです。
- wit は、候補訳内の i グラムの総数です。
ようするに機械翻訳が出力した候補訳に含まれる全ての 1~4-GRAM の参照訳に対する適合率の相乗平均+αです。"+α“は以下のような処理です。
- 参照訳 : "I have a pen."、 候補訳 : "I have a pen I have a pen I have a pen.” のようなケースに対応する為に micand と miref の min をとります。
- 参照訳 : “I have a great premium mervelous expensive precious pen."、候補訳: "a pen.” のようなケースに対応する為、候補訳が参照訳より短い場合に brevity penalty をかけます。
また以下のような注意点があります。
- コーパス単位の指標として設計されているので文単位での評価は非推奨。
- コーパスを跨ったスコアの比較も非推奨。
- 他にも単語の意味合いを考慮していない("a" と “the” の違いと “pen” と “dog” の違いが同じ)。
- 文法の評価が十分でない(長めの N-GRAM を使ってある程度担保します。1~4-GRAMを使うのが一般的なようです)。
やさしい日本語変換のような文章平易化では BLEU に加え SARI という指標を併用するようです。SARI についてはテストデータでの精度評価のところで説明します。
それでは TextLineTask
を作成してみましょう。
やさしい日本語 TextLineTask の作成
まず、t5 には t5.metrics.bleu()
という関数が用意されていますが、当然のことながら日本語には対応していません。日本語での BLEU スコアの計算となると sumeval なのですが、API 的には文単位の計算のように見えます。文単位での評価は非推奨と説明した手前、ちょっと気になります。
#https://github.com/chakki-works/sumeval/blob/a5188a8304fd6eb89af9071c227964643670eb4d/sumeval/metrics/bleu.py#L34 def bleu(self, summary, references, score_only=True): """ Calculate BLEU score by sacrebleu. Parameters ---------- summary: str summary text references: str or str[] reference or references to evaluate summary score_only: bool when True, return only score
幸い、t5 も sumeval も計算の本体は sacreBLEU 28 を使用しています。ですので、日本語のトークナイズの部分を sumeval から借りて、t5 の実装に寄せる形で sacreBLEU を呼び出しコーパス単位で BLEU スコアを算出することにしました。
細かく言うと smooth_method
を sumeval の "floor"
ではなく t5.metrics.bleu()
の "exp"
にしています。日本語を対象とする場合は "floor"
が良いとかあるかもしれませんが、そこまで調べられてません。smooth_method
の "exp"
はこちらの論文29 で提案されているスムージング手法の Method 3 に相当します。先ほど説明した BLEU の計算では、1~3-GRAM でどんなに高スコアでも 4-GRAM の適合率が 0 だと最終的な BLEU スコアも 0 になってしまいます。smooth_method
はこの課題に対応する為のものです。
コードとしてはこんな感じになります。
%%bash cat <<EOF > add_snow_t15_23.py import functools import tensorflow as tf from t5.evaluation import metrics from t5.data import preprocessors from t5.data.utils import TaskRegistry from t5.data.utils import TextLineTask from sumeval.metrics.lang.lang_ja import LangJA from sacrebleu import corpus_bleu, TOKENIZERS lang_ja = LangJA() def tokenizer_ja(text): words = lang_ja.tokenize_with_preprocess(text) return " ".join(words) TOKENIZERS["ja"] = tokenizer_ja def bleu(targets, predictions): predictions = [tf.compat.as_text(x) for x in predictions] if isinstance(targets[0], list): targets = [[tf.compat.as_text(x) for x in target] for target in targets] else: targets = [tf.compat.as_text(x) for x in targets] targets = [targets] bleu_score = corpus_bleu(predictions, targets,smooth_method="exp", smooth_value=0.0, force=False,lowercase=False,tokenize="ja", use_effective_order=False) return {"bleu": bleu_score.score} task_name_snow = "snow_t15_23" snow_tsv_path = { "train": "./snow_t15_23_train.tsv", "validation": "./snow_t15_23_dev.tsv", "test": "./snow_t15_23_test.tsv", } TaskRegistry.add( task_name_snow, TextLineTask, split_to_filepattern=snow_tsv_path, text_preprocessor=[ functools.partial( preprocessors.parse_tsv, field_names=["inputs", "targets"]), ], sentencepiece_model_path="gs://somewhere/t5/sentencepiece/wikipedia_20190301_ja_v003.model", metric_fns=[bleu]) EOF
ファインチューニングの実行
livedoor ニュースコーパスの時と同様の手順でファインチューニングを実行していきます。
!gsutil cp gs://somewhere/t5/wikipedia_20190301.ja_v003/model/checkpoint gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23/ !gsutil cp gs://somewhere/t5/wikipedia_20190301.ja_v003/model/model.ckpt-524288* gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23/ !gsutil cp gs://somewhere/t5/wikipedia_20190301.ja_v003/model/operative_config.gin gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23/
タスクの内容は異なるのですが、"Text-to-Text" と t5 ライブラリのおかげで、ほぼそのままで大丈夫です。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ PRE_TRAINED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/pre_trained' && \ OPERATIVE_CONFIG=$PRE_TRAINED_MODEL_DIR'/operative_config.gin' && \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23' && \ FINE_TUNING_BATCH_SIZE=`expr 512 \* 8` && \ PRE_TRAINGING_STEPS=524288 && \ FINE_TUNING_STEPS=`expr $PRE_TRAINGING_STEPS + 2000` && \ INPUT_SEQ_LEN=64 &&\ TARGET_SEQ_LEN=64 &&\ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ echo "FINE_TUNING_BATCH_SIZE=$FINE_TUNING_BATCH_SIZE" &&\ echo "PRE_TRAINGING_STEPS=$PRE_TRAINGING_STEPS" &&\ echo "FINE_TUNING_STEPS=$FINE_TUNING_STEPS" && \ echo "INPUT_SEQ_LEN=$INPUT_SEQ_LEN" && \ echo "TARGET_SEQ_LEN=$TARGET_SEQ_LEN" && \ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_snow_t15_23" \ --gin_file="dataset.gin" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'snow_t15_23'" \ --gin_file="learning_rate_schedules/constant_0_001.gin" \ --gin_param="run.train_steps=$FINE_TUNING_STEPS" \ --gin_param="run.sequence_length = {'inputs': $INPUT_SEQ_LEN, 'targets': $TARGET_SEQ_LEN}" \ --gin_param="run.save_checkpoints_steps=200" \ --gin_param="run.batch_size=('tokens_per_batch', $FINE_TUNING_BATCH_SIZE)"
livedoor ニュースコーパスの文章分類と違うのは以下のポイントです。
--module_import
で先ほど作成したadd_snow_t15_23.py
をインポート。MIXTURE_NAME
はsnow_t15_23
を指定。- やさしい日本語コーパスは文長が短いので、シーケンス長は
{ 'inputs': 64, 'targets':64 }
とした。 - 収束が速かったので、総ステップ数を 524288(+2000)ステップとし、 200ステップ毎にチェックポイントを出力した。
ファインチューニング結果の検証
検証も同様です。ただしメモリエラーが発生した為、バッチサイズを小さく ('tokens_per_batch', 64)
に設定しました。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_snow_t15_23" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="eval.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'snow_t15_23'" \ --gin_param="run.dataset_split='validation'" \ --gin_param="run.batch_size=('tokens_per_batch', 64)" \ --gin_param="eval_checkpoint_step = 'all'" 2>&1 | tee eval.log
結果は以下のとおりです。ステップ 525488 が最良でした。
!cat eval.log | grep -e "^INFO.*bleu" # INFO:tensorflow:eval/snow_t15_23/bleu at step 524288: 7.466 # INFO:tensorflow:eval/snow_t15_23/bleu at step 524488: 76.173 # INFO:tensorflow:eval/snow_t15_23/bleu at step 524688: 78.056 # INFO:tensorflow:eval/snow_t15_23/bleu at step 524888: 78.673 # INFO:tensorflow:eval/snow_t15_23/bleu at step 525088: 78.848 # INFO:tensorflow:eval/snow_t15_23/bleu at step 525288: 78.818 # INFO:tensorflow:eval/snow_t15_23/bleu at step 525488: 78.894 # INFO:tensorflow:eval/snow_t15_23/bleu at step 525688: 78.823 # INFO:tensorflow:eval/snow_t15_23/bleu at step 525888: 78.724 # INFO:tensorflow:eval/snow_t15_23/bleu at step 526088: 78.501 # INFO:tensorflow:eval/snow_t15_23/bleu at step 526288: 78.249
テストデータでの精度検証
テストデータでの精度確認も同様です。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_snow_t15_23" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="eval.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="MIXTURE_NAME = 'snow_t15_23'" \ --gin_param="run.dataset_split='test'" \ --gin_param="run.batch_size=('tokens_per_batch', 64)" \ --gin_param="eval_checkpoint_step = 525488" 2>&1 | tee test.log
結果は以下のとおりです。BLEU スコアで 78.795 でした。
!cat test.log | grep -e "^INFO.*bleu" # INFO:tensorflow:eval/snow_t15_23/bleu at step 525488: 78.795
これだけだと「だから何だ」という話なので、原文そのままで計算した BLEU スコアの値と比較してみましょう。 今回のテストセットのデータをリストに読み込んで計算します。
exec(open("./add_snow_t15_23.py").read()) import t5 import tensorflow_datasets as tfds import tensorflow as tf snow_t15_23 = t5.data.TaskRegistry.get("snow_t15_23") ds = snow_t15_23.get_dataset(split="test", sequence_length={"inputs": 64, "targets": 64}) inputs = [] targets = [] for i, ex in enumerate(tfds.as_numpy(ds)): inputs.append(ex["inputs_plaintext"]) targets.append(ex["targets_plaintext"]) bleu(targets, inputs) # {'bleu': 66.03612333056641}
66.04 => 78.80 なので、学習結果がかなり「やさしい日本語」に近づいたことが確認できました。
次に文章平易化の指標である SARI の値も確認してみましょう。
SARI
SARI はこちらの論文30で提案された文章平易化の指標です。分かりやすく言うと以下の3つ評価の平均が SARI です。
- 入力文に含まれず参照文に含まれる単語がどの程度追加されたか。
- 入力文にも参照文にも含まれる単語がどの程度維持されたか。
- 入力文に含まれるが参照文に含まれない単語がどの程度削除されたか。
計算式で書くと以下のようになります。
先ほど単語と書きましたが正確には BLEU と同様に 1~4-GRAM を用います(数式中の k です)。また、過剰な削除が可読性を大きく損なうことを考慮して、削除(del)の場合だけ F 値ではなく適合率を使います。
padd|kepp|del, radd|keep は追加、維持、削除の適合率と再現率で、概念的には下図の濃緑+淡緑部分と濃緑部の面積比になります(実際は参照文が複数ある場合を考慮してもう少し細かい計算です)。
では SARI を計算してみましょう。ソースコードは github から入手できます。
!git clone https://github.com/cocoxu/simplification
こちらも日本語には対応していません。とりあえず、分かち書きの実装を sumeval から借りてきて以下のようにしてしまいました。
import os import sys sys.path.append(os.path.join(".", 'simplification')) from SARI import SARIsent from sumeval.metrics.lang.lang_ja import LangJA lang_ja = LangJA() class sari_string(str): def lower(self): return sari_string(super().lower()) def split(self, delim): return lang_ja.tokenize_with_preprocess(self) def sari(input, output, references): input = sari_string(input) output = sari_string(output) references = [sari_string(r) for r in references] return SARIsent(input, output, references)
SARI の計算には、入力文、出力文、参照文が必要です。まず、出力文と参照文をテストデータでの実行結果から拾ってきます。
!gsutil cp gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23/test_eval/snow_t15_23_525488_predictions . !gsutil cp gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23/test_eval/snow_t15_23_targets .
入力文はテストデータセットの TSV から切り出します。
!cat snow_t15_23_test.tsv | awk -F "\t" '{print $1}' > ./snow_t15_23_inputs
集めたデータをロードして、
def load(file): with open(file, "r") as f: lines = f.readlines() return [line.strip() for line in lines] inputs = load("snow_t15_23_inputs") outputs = load("snow_t15_23_525488_predictions") references = load("snow_t15_23_targets")
ようやく SARI の計算ができました。
num_examples = len(inputs) saris = [] for i, o, r in zip(inputs, outputs, references): saris.append(sari(i, o, [r])) print("SARI = %5.3f" % (sum(saris) / num_examples)) # SARI = 0.617
BLEU 同様に出力を原文そのままとした場合の値も計算してみましょう。
num_examples = len(inputs) saris = [] for i, r in zip(inputs, references): saris.append(sari(i, i, [r])) print("SARI = %5.3f" % (sum(saris) / num_examples)) # SARI = 0.257
0.257 => 0.617 なので SARI でも数値の向上が確認できました。
推論モードでの実行
推論モードで動かしてやさしい日本語変換の雰囲気を見てみましょう。以下のようなサンプルを用意して、
%%bash cat <<EOF > inputs.txt 彼らは高飛びしたらしい。 彼は分析のスペシャリストだ。 吾輩は猫である。 彼は今時のナウい若者です。 あなたをプロジェクトリーダーにアサインします。 EOF
これまでの要領で動かします。
!export PYTHONPATH=${PYTHONPATH}:. && \ \ FINE_TUNED_MODEL_DIR='gs://somewhere/t5/wikipedia_20190301.ja_v003/snow_t15_23' && \ OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \ \ echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\ echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\ \ t5_mesh_transformer \ --model_dir="$FINE_TUNED_MODEL_DIR" \ --module_import="add_snow_t15_23" \ --gin_file="$OPERATIVE_CONFIG" \ --gin_param="run.layout_rules=''" \ --gin_param="run.mesh_shape=''" \ --gin_file="infer.gin" \ --gin_file="beam_search.gin" \ --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \ --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \ --gin_param="run.batch_size=('tokens_per_batch', 64)" \ --gin_param="infer_checkpoint_step = 525488" \ --gin_param="input_filename = './inputs.txt'" \ --gin_param="output_filename = './ouputs.txt'" 2>&1 | tee infer.log
変換前と後を比べるとこんな感じになりました。「高跳び」、「スペシャリスト」、「吾輩」、「ナウい」、「アサイン」はファインチューニングの学習データに含まれていないことを確認しています。
「高跳び」、「吾輩」を間違っていますが、なかなか自然な文章に変換されていますね。ファインチューニングにない単語もそれなりに扱えているので、単純な単語=>単語の変換をしている訳ではなく、 Encoder で抽出した埋め込み表現にもとづき、 Decoder が使える単語の中から自然な並びを選択しているように感じられます。
やたらとカタカナを使ってくる人のメールにはこのモデルを適用してもいいかもしれません。
学習データ量と精度の関係
さて、それなりに雰囲気のあるモデルができることはわかりましたが、どの程度のデータ量がいるでしょうか。ファインチューニング時の学習データの件数を変えて試してみました。
図中の “original” はモデルの出力ではなく、原文そのままでの BLEU, SARI の値です。雰囲気からするとデータ量を増やすともう少し伸びるかもしれませんが、労力を考えると今回のデータでは2万件くらいが目安でしょうか。ただし今回のサンプルはかなり短いですし、30%は原文そのままの素通しなので、数人でちょっと(かなり?)頑張れば何とかなりそうですね。
4. おわりに
今回は T5 の論文で示された知見を紹介し、その実装である t5 を用いてテキスト生成の実験を行ってみました。 t5 はこなれてない部分があるものの、Mesh Tensorflow や gin による柔軟性と様々なタスクを “Text-to-Text” で解くというスタイルとも相まって使いやすいライブラリになっていると思います。実はまだまだ書き足りないので次回も t5 です。良さそうなモデルができたらやっぱりデプロイしてアプリから使いたいですよね。そのあたりの解説をしたいと思います。
-
https://github.com/tensorflow/mesh/blob/c81d299092f432a9be8bf6e6dea8674fed2d84de/mesh_tensorflow/transformer/transformer_layers.py#L223 ↩
-
https://github.com/tensorflow/mesh/blob/c81d299092f432a9be8bf6e6dea8674fed2d84de/mesh_tensorflow/transformer/attention.py#L68 ↩
-
論文の Table.2 を参照 ↩
-
論文の Table.4, 5, 6, 7 を参照 ↩
-
論文の 2.2 を参照 ↩
-
SuperGLUEには小説を対象とするタスク(MultiRC)、ニュース記事を対象とするタスク(ReCoRD)があり、SQuAD は Wikipedia から抽出されたデータセットです。 ↩
-
論文の Table.10 を参照 ↩
-
論文の Table.11 を参照 ↩
-
https://github.com/google-research/text-to-text-transfer-transformer ↩
-
でも “Text-to-Text” なら小文字化はしないほうがよかった? “iPhone” とか出力できなくなるし。というか、やるなら事前学習の時も小文字化しないと。あと全角半角とか。。。 ↩
-
もう少しわかりやすいとこに書いておいてほしいですね。 https://github.com/google-research/text-to-text-transfer-transformer/blob/835e82043b9beab809391fddc1f5d718f73a1280/t5/data/sentencepiece_vocabulary.py#L34 ↩
-
なぜこうしたのかは忘れてしまいましたが、Free TPU を使用しようとすると、こうしないと動かなかったような。。。 ↩
-
“2x2"が8コアというのはわかりにくいですが、ソースを確認するとそうなってます。。。https://github.com/tensorflow/mesh/blob/c81d299092f432a9be8bf6e6dea8674fed2d84de/mesh_tensorflow/transformer/utils.py#L204 ↩
-
Decoder における系列生成のシンプルな方法は最も確率が高いトークンを選ぶことですが、ビームサーチでは各ステップにおいて、出力先頭からの累積生成確率が上位K件の系列を考慮します。K件の系列全てについて次ステップにおける全語彙の選択確率を計算し、次ステップにおける累積生成確率が上位のK件の系列を選択することによって、ある特定ステップの計算結果が全体に与える影響を抑えることができます。ようは探索幅を各ステップの上位K件に限定した幅優先探索です。 ↩
-
二つの文字列がどのくらい異なるかの尺度で編集距離とも呼ばれます。片方の文字列に文字の追加、削除、置換の操作を何回行えば二つの文字列が一致するかを意味します。操作に隣接文字の交換を加えたダメラウ・レーベンシュタイン距離というのもあります。 ↩
-
https://cocoxu.github.io/publications/tacl2016-smt-simplification.pdf ↩