オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第17回 ByT5 と Charformer の検証
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2021年10月20日

トークナイザを使わない自然言語処理モデルである ByT5 と Charformer のご紹介です。従来の自然言語処理では多くの場合で文章を単語(あるいはサブワード)単位に分かち書きして処理しましたが、今回のモデルは直接、生のテキストを処理します。それでは実際に動かして単語(サブワード)ベースのモデルと比較してみましょう。

1. はじめに

今回は今年5月と6月に発表された ByT51 と Charformer2 の紹介をしたいと思います。一本の記事で 2 つのモデルを扱うのは、この連載では珍しいのですが、この二つはよく似ているというか、Charformer は 「ByT5 にもう一工夫加えたもの」くらいの認識なので、一度にさばいてしまいましょうということで。

さて、この二つのモデルの特徴ですが「分かち書きをしない」という点に尽きます。

今まで、この連載では BERT や T5 等の Transformer ベースのモデルを紹介してきましたが、どれも入力テキストを MeCab や Sentencepiece 等で有限かつ固定の語彙に含まれる単語(正確には更に細かいサブワード)単位のトークンに分割し、そのトークンの ID をモデルに投入するという流れでした。

この「有限かつ固定の語彙」ですが、この連載では公開されているコードのデフォルトの語彙数をそのまま使っており、その値は 32000 です。 この 32000 が認識できる単語(サブワード)の全てなのですが、日本語は使用する文字の種数が多い言語であり日常使用のレベルで軽く 2000 文字を超え、漢字検定1級レベルだと約 6000 文字になります3 。正直、32000 では足りません。

ですが、表示したい全ての文字を語彙に収録してしまうと、サブワードに使える空間を圧迫して性能が落ちそうです。かと言って語彙を 32000 から増やすと推論は難しく重くなっていきます。

仕方がないので、Sentencepiece の学習時に --vocab_size=32000 --character_coverage=0.9995 としてコーパス中に出現した文字の 0.0005 % は諦めてました。推論時にこの諦めた文字が入力されると “<UNK>” (unkown) トークンとして処理されることになります。

下流のタスクが文章分類だと入力文の一部が “<UNK>” になっても出力は分類されたクラスのインデックスになるので、さほど気にならないのですが、書き換え等のテキスト生成系のタスクだと出力に “⁇”(U+2047) 4が混じるので目立つんですよね。。。

そんな時に ByT5 の論文が出て、試してみようかとなった次第です。では ByT5 について見ていきましょう。

2. ByT5

ByT5 は論文には書いてありませんが “ByTe level T5” の略なんだと思います。 もう名前で何をやっているか察しがついてしまいそうですね。 以下の図を見てください。

byt5

左側は mT54 で T5 の多言語バージョンです。多言語のトークンを扱う為に語彙数が拡大5している以外は普通の T5 だったと思います。 英語に日本語が混在した文章が Sentencepiece で分かち書きされて、トークン ID の系列に変換されています。 Inputs の “<X>”, “<Y>” が伏字にしたところで、Targets は “<X>”, “<Y>” の穴埋めですね。

右側が ByT5 です。図中に “UTF-8 Encode” とあるように、テキストを UTF-8 で表現されたバイト列だと見立て、各バイトの値 (0 ~ 255) をトークンID の系列の代わりに使う訳です。基本的にはこれだけです。

出力時も通常の T5 と同様に decoder を一回実行する毎に 1 トークンずつ自己再帰的に生成します。ただし ByT5 のトークンはバイト単位なので、 生成されたバイト列が UTF-8 的に不正なものになるケースもあります。その場合は単純に不正バイトを無視して対応します。

通常の T5 では 32000 だった語彙数が 256 (0 ~ 255) になって、どんな言語のどんな文字の並びでも( UTF-8 で表現できるならば)扱うことができ、原理的に“<UNK>"が発生することがありません

入力をバイト列とすることには、もう一つ、ノイズに強いという利点があります。これはイメージで見てもらったほうが分かりやすいですね。

以下の例は "The sky is beautiful” を 1 文字だけ typo して “The sky is besutiful” になってしまった場合に、どれだけ元のトークン ID 系列から変わってしまうかと?という例です。

robust_to_noise

サブワードの方がトークンの単位が粗い(=トークンがよりリッチな情報を保持している)のですが、 1 文字の typo で “beautiful” が全く異なる 3 つのトークン ["_be", "su", "tiful"] に代わってしまいました。これらに元々の意味合いの面影は残ってなさそうです。 UTF-8 の方は typo した “s” のトークンは変わっていますが、それ以外は元のままなので、こちらの方が情報を維持出来てそうです。

とは言っても、これは英語の場合です。日本語だと UTF-8 では大体 1 文字 = 3 bytes なので 3 トークン入れ替わってしまいます。なのでノイズに対する頑健性という意味では日本語だとやや不利になりそうですね。

実装上の細かい話

あと、ここまで語彙数が256と書いてきましたが実装上は実はちょっと違います。 特殊トークンとして PAD=0, EOS=1, UNK=2 として、 3 トークン予約されており、実際のトークンID は UTF-8 で表現したバイト + 3 になります。 実装上の語彙数も正確には 256 + 3 の 259 になります。以下に該当部分のコードを引用します。

#Quoted from https://github.com/google/seqio/blob/31da3912ad190dc8686b5527fe0e2bff777a6855/seqio/vocabularies.py#L349-L435
349|class ByteVocabulary(Vocabulary):
350|
351|  """Byte-level vocabulary.
352|  Encode/decode text directly to 256 "byte IDs" using UTF-8 encoding. Three
353|  special IDs are reserved (0=padding, 1=EOS, 2=UNK), so our encoded byte IDs
354|  are +3 greater than UTF-8 byte values.
355|
356|  This is the vocabulary used by the ByT5 models:
358|  https://arxiv.org/abs/2105.13626
358|  """
359|
360|  def __init__(self, extra_ids: int = 0):
...|
369|    self._byte_size = 256
370|    # The special tokens: 0=PAD, 1=EOS,and 2=UNK
371|    self._num_special_tokens = 3
372|    super().__init__(extra_ids=extra_ids)
...|
382|  def _convert_strings_to_ids(self, s):
...|
390|    return list(s.encode("utf-8"))
...|
415|  def _base_vocab_size(self):
...|
421|    return self._num_special_tokens + self._byte_size
422|
423|  def _encode(self, s):
...|
434|    ids = self._convert_strings_to_ids(s)
435|    return [i + self._num_special_tokens for i in ids]

上図の左側を見たら “In Japan” が “73 110 32 74 97 112 97 110” になっていて、

[b for b in "In Japan".encode("utf-8")]
# [73, 110, 32, 74, 97, 112, 97, 110] 

なのですっかり騙されました(?)が、実際には + 3 された値で encoder に投入されます。

次に構造的な特徴にも触れておきましょう。

構造的な特徴

構造的な違いとして、T5 や mT5 では encoder と decoder の層数は同じだったのですが、 ByT5 は encoder を深く、 decoder を浅くした構造を採っています。 先ほどの図の右下にも “Heavy Encoder”, “Light Decoder” とありますね。

論文の Table. 1 に比較表があります。encoder と decoder の層数の合計は mT5 と ByT5 で同じですが、 ByT5 は encoder と decoder の層数の比が 3 : 1 になっています。

archtecture

encoder が深い理由ですが、トークンをバイト単位と細切れにすることで、サブワード単位のトークンが保持していたリッチな情報が失われています。この失われた情報を encoder を深くすることで頑張って取り戻そうとしているようです。

次に decoder が浅くなっている理由ですが、出力テキストを生成する際は 1 トークン出力する毎に 1 回 decoder を実行することになります。トークンの粒度がサブワードからバイトに細かくなったことで、同じ長さの文字列でも decoder の実行回数は増えることになり、その分キャパシティを絞ることができるようです。選択する語彙数も小さくなっていますしね6

パラメータ数にも触れておきましょう。 ByT5 も Small, Base, Large, XL, XXL とお馴染みのフレーバーが用意されていますが、 Transformer の次元数はだいぶ変わっています。例えば Base で比べてみると、dmodel は 768 ⇒ 1536 に、dff は 2048 ⇒ 3968 になっています。

これは同じフレーバー同士で ByT5 のパラメータ数を mT5 に合わせて設定したことに起因します。 mT5 は通常の T5 に比べて語彙数が多い分、トークンの埋め込み表現に使用されるパラメータが多くなります。この埋め込み表現のパラメータが ByT5 では語彙数が 256 になることで大幅に削減されます。そこで浮いた分のパラメータを dmodel, dff といった Transformer 内部の次元数拡大に回した形になっています。

mT5 でトークンの埋め込み表現に使われていたパラメータの役割は、Transformer の入り口でトークン ID に対応した埋め込み表現ベクトルを拾い、出口で語彙選択の softmax をするだけだったのですが、 ByT5 で dmodel, dff に回されたパラメータはがっつり計算に使われますので処理としては重くなります。おまけにバイト単位にしたことでシーケンス長も長くなります。同じ長さの文字列を処理することを考えると ByT5 は mT5 よりも重くなっていると言えるでしょう。

どの程度重くなったか、論文の Table.10 に推論時間の比較が出ています。

inference_time

表の単位は秒で左の XNLI zeroshot は分類タスク、右の GEM-XSum は生成タスクになります。XNLI zeroshot の方は x1.1 ~ 2.0 で収まっていますが、 GEM-XSum の方は x1.5 ~ 7.0 と速度低下が大きくなっています。これは XNLI zeroshot は分類タスクなので “0”, “1”, “2” と言ったクラスの ID を 吐くだけ(decoder は1回実行でOK)なのに対し、 GEM-XSum は生成タスクである為、 decorder をぐるぐる回って 1 バイトずつ出力テキストを生成するからでしょう。

この GEM-XSum ですが英語のタスクになります。英語なら 1 文字を decoder 1 回で出力できますが、日本語は 1 文字で decoder を 3 回通す必要があります。 もちろん “beautiful” と “美しい” のようなケースもあるのですが、日本語がなんとなく不利な気がしますね。。。

ちょうど論文に言語別で 1 トークンが平均何バイトになるかという図がありました。

avg_byte_size_of_token_per_lang

英語で 4 バイト弱、日本語で約 5 バイトというところでしょうか。やはり日本語は少し不利なようです。 decoder が浅くなっているとはいえ、サブワードベースのモデルに比べて decoder を 5 倍回す必要があるのはちょっとツライ気がします。。。

と、思っていたところで Charformer を知りました。 こちらも T5 にバイト単位のシーケンスを入力する点は同じですが更に工夫がされているようです。 Charformer の論文ではバイト単位やサブワード単位の Transformer に対して 28-100% 高速と記載されています。ホントでしょうか。

次は Charformer を見てみましょう。

3. Charformer

ByT5 は「分かち書き」を廃して Transformer にバイト列をそのまま投入しました。Charformer は「分かち書き」に相当する処理をニューラルネットに組み込んで end-to-end で学習するというアプローチになります。下図の Gradient Based Subword Tokenization(GBST) が「分かち書き」に相当する部分です。

charformer_archtecture

それでは GBST の仕組みについて確認しておきましょう。

GBST (Gradient Based Subword Tokenization)

まず、 GBST への入力 X ∈ ℝL×d は入力するバイト列に対応する埋め込み表現のシーケンスを1次元畳み込みでスムージングしたものになります。 L がバイト長, d が各バイト値に対応する埋め込み表現の次元数です。

GBST

次に上図左側のように X を異なる長さ(上図では 1 ~ 4)の固定長ブロックに切り分けます。X の i バイト目を先頭とするブロックを Xi:i+b ∈ ℝb×d とすると b がブロック長で 0 ≦ b ≦ L - b です。青線で囲まれたのがブロックですね7

ブロック長 b における潜在的なサブワード表現のシーケンス Xb は以下のとおりです。

Xb

  • F はプーリング関数で F : ℝb×d → ℝd です。論文によると mean pooling を使用したそうです。
  • s はブロック形成時のストライドです。実際は s = b でブロックの重なりが無いように処理します(その代わり最初に1次元畳み込みしてブロック跨りの情報をカバーしています)。
  • ブロック Xi:i+b∈ℝb×d に F を適用したものを、ブロックサイズ b で, i ~ i+b バイト目における潜在的なサブワード表現 Xb,i∈ ℝdとします。

はやい話が上図左側の青枠を各々横(ブロック長)方向に平均(mean pooling)してくしゃっと潰して、青枠 1 つを d 次元ベクトルにします。

ここから上図右側の話です。最大ブロックサイズを M として b ∈ 1, …, M で Xb を計算し、 全ブロックサイズで Xb の長さが L になるようにアップサンプリングして揃えた上で、 Xb,i (アップサンプリング後の Xb の i 番目の d 次元ベクトルですね)のスコアを pb,i = FR(Xb,i) とします。

  • FR は Block Scoring Network で FR : ℝd → ℝ です。単純な線形写像が使われます。

関数 F によるダウンサンプリングで短くなった Xb をアップサンプリングで元の長さに戻しているので、 pb,i はちょうど元の X の i バイト目のスコアに相当します。

ここまできたら、入力シーケンスの i バイト目の各ブロックサイズにおける重み Pi を以下のように計算します。上図右側の青枠ですね8

Pi

Pi で i バイト目における各ブロックサイズの重みが得られたので、加重和をとって Xi^ とします。

X^i

最後にダウンサンプリング関数 FD を X^ に適用して長さを縮め、潜在的サブワード表現のシーケンス X˜ を得ます。

X~

  • FD : ℝL×d → ℝL/ds×d にも mean pooling が使われています。
  • ds はハイパーパラメータです。 論文では 2 ~ 4 あたりが使われています。

ようは上図左側の青枠をくしゃっと潰したら、ブロックサイズ間で長さ不揃いになるので、アップサンプリングで長さ L に揃え、i バイト目の表現をブロックサイズ跨り(上図右側では縦方向)の softmax と加重和で算出、出来上がった長さ L の潜在表現を ds 個単位で平均して潰し、 L / ds の長さに短くするだけです。

encoder に入力するシーケンス長は 日本語だと ds = 3 を使ったとしてもシーケンス長≒文字数になります。バイト列よりは短くなりましたが、サブワードベースのモデルよりは大分長くなってしまいますね。ところで decoder はどうなるんでしょう?

decoder 側の扱い

decorder 側は特に何も書いてないというか。

encoder は潜在的なサブワード表現で OK なんですが、 decoder はサブワードなり、文字なり、バイトなりの語彙から出力するトークンを選ばないといけません。選ぶには各語彙の埋め込み表現が必要な訳で、使えそうなのは GBST の入力に使った各バイト値に対応する埋め込み表現だけですね。

なので、 decoder 側は ByT5 と同じですかね。1 バイトずつの出力になるんでしょう。推論速度を考えると、ぐるぐる回る decoder の方が影響度大きそうなので、ByT5 に比べてそんなに速くなりそうな気がしないというか。。。

構造的な話

Charformer も Small, Base, Tall と複数のフレーバーで実験されています。論文の Table. 1 に比較表がありますね。

table1

この表を見ると、同じフレーバであればバイト単位モデルと遜色ない精度がでています。サブワード単位と比べてもさほど大きく見劣りはしないようですね。

Charformer の各フレーバーについて簡単に説明しておきましょう。

  • Small :
    T5 1.0 Small に準じた構成9。encoder, decoder は共に6層。
  • Base :
    T5 1.0 Base に準じた構成10。encoder, decoder は共に12層。
  • Tall :
    CharformerSmall の構成で encoder を 24 層、decoder を 6 層の設定にしたもの。また活性化関数が GEGLU に変更されてます11

あと、表中の “Byte-level T5” は ByT5 のことなのですが、ByT5Small, Base のように mT5 のパラメータ数に合わせる調整が入っていません。なので、 Byte-level T5Base と ByT5Base ではだいぶパラメータ数が違うことに注意してください。また、ByT5Small, Base は活性化関数に GEGLU を使った T5 1.1 をベースにしています。

あと、速度的なところを確認しておきましょう。

速度的な話

論文のアブストラクトでは Charformer はバイト単位やサブワード単位の Transformer に対して 28-100% 高速だと記載されています。 論文の Table. 5 に mT5 との速度比較がありました。

table5

CharformerTall を mT5Base と比べると 1.98 / 1.54 ≒ 1.28 なので 28% 高速ということみたいですね。バッチサイズが同じ、入力シーケンス長が 2 倍、それをダウンサンプリングで ½ にするので Transformer が処理するトークン数は同じになります。ですが、トークンの単位がバイトとサブワードなので注意が必要です。平均すると大体 1 サブワード ≒ 4.1 バイトだそうですので、処理した文章の量としては CharformerTall は mT5Base の半分くらいになります。

CharformerTall, Long PT はバッチサイズが mT5 の 2 倍ですから、おおよそ処理した文章の量を合わせた比較になるかと思います。mT5 に対して 1.01 / 1.54 ≒ 0.65 になっちゃいました。。。

でも最大で 100% 高速だという記述もあるので、もう少し見てみましょう。論文の Table. 6 にさらに比較があります。

table6

ds = 3 の CharformerTall と T5Base を比べると 20 / 9.3 ≒ 2.15 なので 115% 高速になるでしょうか。 CharformerTall のシーケンス長 1024 を 4.1 で割ると 249.76 トークン相当になります。T5Base のシーケンス長は 512 トークンなので、同じ文章を処理した場合で比較すると 20 * 249.76 / 512 / 9.3 ≒ 1.05 です。速く…なりましたか。

ですが、上記の速度は事前学習での処理ステップ/秒です。事前学習では Teacher forcing になるので 1 サンプルを処理するときに docoder を 1 回実行するだけですみます。推論時は decoder を自己再帰で 1 バイトずつぐるぐる回すのでサブワード単位よりも速いということはないでしょうね。

何かテンション下がり気味になってますが、 ByT5 より速く、それなりに精度が確保できて、"<UNK>“が出ないなら、まぁいいかという感じですね。

それでは実際に動かしてみましょう。

4. ByT5 の事前学習

いつものように、記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定にしています。ノートブックを開き、アクセラレータは TPU を選んで下さい。

環境のセットアップ

Tensorflow は 2.x 系を使います。

%tensorflow_version 2.x 

今回は以下のバージョンを使っています。

!pip install t5[gcp]==0.9.1 mesh-tensorflow==0.1.19 tensorflow-datasets==4.2.0

ByT5 のコードをクローンして、

!git clone https://github.com/google-research/byt5
!cd byt5 && git checkout 069d3cbc1cdfeb7e6
# ...
# HEAD is now at 069d3cb Add GEM-XSum tasks.

mT5 のコードに依存する部分があるので、 mT5 もクローンします。

!git clone https://github.com/google-research/multilingual-t5
!cd multilingual-t5 && git checkout 35f723155f
# ...
# HEAD is now at 35f7231 Add GEM-XSum, GLUE, SuperGLUE tasks, and task test.

ちなみにチェックアウトしているコミットは「これで試したら動きました」というだけで深い意味はありません。

次は事前学習タスクの登録です。

事前学習タスクの登録

とりあえず ByT5 と mT5 を Python のモジュール検索パスに追加して、

import sys
sys.path.append("./byt5")
import byt5.metrics as byt5_metrics   # おまじない
sys.path.append("./multilingual-t5")

妙な import が混じってますがw、「これがないとエラーになった。これ入れたらとりあえず動いた。」という話です。 ちゃんと調べれば原因はわかるんでしょうが、動いたからいいやで進めちゃいます。

ByT5 のコードに含まれるタスクを登録しておきましょう。

%run ./byt5/byt5/tasks.py

次は事前学習に使う MIXTURE を定義します。

ちなみに MIXTURE は複数の学習タスクに混合比を設定した定義体のことです。 T5 は全部 text to text の学習なのでマルチタスク学習が簡単ですね。

ByT5 の README.md12 では MIXTURE に byt5_mc4 を指定するように記述されていますが、コードを見ると byt5_mc4 は 多言語コーパスの mC413 を使った各言語毎の学習をマルチタスクにしたものになってますね14

ですが、 tensorflow datasets の mC4 は wikipedia のように公開された GCS を気軽につつける訳ではなく、それなりにセットアップが必要です。 しかも巨大なので学習も大変。README.md に v3-256 の TPU で 1000000 ステップとか書いてあります。。。。ムリ。

ですので、今回は以前に紹介した T5 とかと比較する意味もあり、日本語 Wikipedia で学習することにしました。 日本語 Wikipedia も ByT5 のコード中にタスクが定義されているので、それを拾って新たな MIXTURE として登録するだけで OK です。

import t5
t5.data.MixtureRegistry.add("byt5_wiki_ja", ["byt5_wiki.ja"], default_rate=DEFAULT_MIX_RATE)

話がちょっとそれますが、バイト単位でシーケンスを T5 に入出力するという ByT5 の核心の部分は以下になります。

# copied from https://github.com/google-research/byt5/blob/4066a7270d6cc1cbc04fe30dd7762275bacb22c3/byt5/tasks.py#L36-L39
36: DEFAULT_BYTE_OUTPUT_FEATURES = {
37:     "inputs": t5.data.Feature(vocabulary=t5.data.ByteVocabulary()),
38:     "targets": t5.data.Feature(vocabulary=t5.data.ByteVocabulary())
39: }

通常の T5 では SentencePieceVocabulary が使われる部分を ByteVocabulary に置き換えているだけですね。

では、本題に戻りましょう。

事前学習の実行

まず Colab の TPU のアドレスを確認します。

import os
import pprint
import json
import tensorflow.compat.v1 as tf

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

# TPU address is grpc://xx.xxx.xx.xxx:xxxx

学習のハイパーパラメータを定義していきましょう。

from t5.models.mesh_transformer_main import FLAGS
from t5.models.mesh_transformer_main import main
FLAGS.mark_as_parsed()

README.md にしたがって、以下の設定としたところ Colab の TPU (v2-8) では OOM エラーになってしまいました。

  • 入力シーケンス長 : 1024
  • 学習ステップ数 : 1000000
  • バッチサイズ : 220 トークン (1048576)

妥協して以下の設定にしています。

  • 入力シーケンス長 : 1024 (論文のまま)
  • 学習ステップ数 : 524288 (第7回の T5 の学習に合わせました)
  • バッチサイズ : 217 トークン (131072 v2-8 での学習時間も考慮し T5 の学習時の 2 倍にした) 

ですが、動かすとこれでも OOM になりました。でも、出来るだけバッチは大きくしたいです。。。 調べてみると gradients accumulation が実装されているから、それを使えば大丈夫そうです15

バッチをマイクロバッチに小分けにして複数回で実行し、複数回分の勾配を累積してパラメータ更新することで、 疑似的にTPU に収まらない大きさのバッチサイズを実行できるという機能ですね。

初めて使うのでちょっと試してみましょう。こんな感じの計算になるはずです。

tokens_per_microbatch_per_replica = 4096                                 # 1 TPUコアが担当するマイクロバッチのトークン数を 4096 とすると。。。
tokens_per_batch = 131072                                                # バッチサイズが 131072 で
input_seq_len = 1024                                                     # 入力シーケンスが 1024 で
num_tpu_cores = 8                                                        # v2-8 は 8コア だから

batch_size = tokens_per_batch // input_seq_len                           
batch_per_replica = batch_size // num_tpu_cores                           
microbatch_size = tokens_per_microbatch_per_replica // input_seq_len     
num_microbatches = batch_per_replica // microbatch_size                  
print("batch_size = {}".format(batch_size))                             
print("batch_per_replica = {}".format(batch_per_replica))                
print("microbatch_size = {}".format(microbatch_size))                    
print("num_microbatches = {}".format(num_microbatches))                  

# batch_size = 128                                                         # バッチサイズは 128 (131072/1024)
# batch_per_replica = 16                                                   # 1 TPUコア辺りのバッチサイズは 16 になる
# microbatch_size = 4                                                      # 1 TPUコアが担当するマイクロバッチサイズは 4 (4096/1024)
# num_microbatches = 4                                                     # なので集約するマイクロバッチの数は 4 (16 / 4)

確認してみましょう。関数呼び出しに必要な変数を定義して、実行してみます。

from mesh_tensorflow.transformer.utils import serialize_num_microbatches, tpu_mesh_shape
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

batch_dim = mtf.Dimension(name='batch', size=131072 // 1024)
sequence_length = {'inputs': 1024, 'targets': 189} 
mesh_shape = tpu_mesh_shape(tpu_topology="v2-8", model_parallelism=1, ensemble_parallelism=None)
layout_rules = mtf.LayoutRules({('heads', 'model'), ('ensemble', 'ensemble'), ('batch', 'batch'), ('experts', 'batch'), ('vocab', 'model'), ('d_ff', 'model')})


serialize_num_microbatches(batch_dim = batch_dim, 
                           sequence_length = sequence_length, 
                           mesh_shape = mesh_shape,
                           layout_rules = layout_rules,
                           tokens_per_microbatch_per_replica=4096)

# INFO:tensorflow:serialize_num_microbatches: tokens_per_microbatch_per_replica=4096 batch_dim=Dimension(name='batch', size=128) sequence_length={'inputs': 1024, 'targets': 189} batch_per_replica=16 num_microbatches=4
# 4

試算した通り 4 になりました。大きなバッチを 4 回に分割して実行するんですね。これなら動きそうです。結局、こんな感じの設定になりました。

FLAGS.tpu = TPU_ADDRESS
FLAGS.model_dir = 'gs://somewhere/byt5/wiki_ja/byt5.base'
tf.flags.FLAGS.gin_file=[
  "dataset.gin",                        
  "./byt5/byt5/gin/models/byt5.base.gin", 
  "learning_rate_schedules/rsqrt_no_ramp_down.gin"
]
tf.flags.FLAGS.gin_param=[
  "utils.tpu_mesh_shape.model_parallelism = 1",
  "utils.tpu_mesh_shape.tpu_topology = 'v2-8'",
  "utils.run.sequence_length = {'inputs': 1024, 'targets': 189}",
  "run.batch_size = ('tokens_per_batch', 131072)",
  "run.train_steps = 524288",
  "run.save_checkpoints_steps = 1000",
  "serialize_num_microbatches.tokens_per_microbatch_per_replica = 4096", 
  "MIXTURE_NAME = 'byt5_wiki_ja'"
]
  • ByT5Baseを作ることにしたので、byt5.base.gin を指定します。
  • gradients accumulation は serialize_num_microbatches.tokens_per_microbatch_per_replica で指定します。

さて、ようやく実行です。

tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)
main([])

パラメータ数は以下のとおり、

INFO:tensorflow:Trainable Variables            count: 197     Total size: 581653248        Total slice_size: 581653248      
INFO:tensorflow:All Variables                  count: 212     Total size: 582416640        Total slice_size: 582416640      

実行速度はこんな感じです。

INFO:tensorflow:global_step/sec: 0.262902
INFO:tensorflow:examples/sec: 33.6515

前回と同様の手順で検証してみると以下のような曲線になりました。

nlppl_curve

ではファインチューニングをしてみましょう。

5. ByT5 のファインチューニング

ファインチューニングの仕方も通常の T5 と大きな違いはありません。

新しくノートブックを開き、アクセラレータに GPU を選んだら、前章の環境のセットアップに従ってセットアップして下さい。

ファインチューニングのデータは前回も使った、やさしい日本語データセットです。加工の仕方は第14回第7回を参考にして下さい。

加工済みの学習データはこんな感じでしたね。日本語の短い文章を2000語程度の限定された語彙で言い換えるタスクになっています。

!wc -l *.tsv
#   7040 snow_t15_23_dev.tsv
#   7040 snow_t15_23_test.tsv
#  56317 snow_t15_23_train.tsv
#  70397 total

!head -5 snow_t15_23_train.tsv
# では今晩またね、さようなら。    では今日の夜にまたね、さようなら。
# 夜のこんな時間に電話をかけるものではない。   夜のこんな時間に電話をかけるものではない。
# 愛とは夢にまで彼女を見ることだ。  愛とは夢にまで彼女を見ることだ。
# そのリンゴの木はよく実がなる。 そのりんごの木はよく実がなる。
# この馬は手に負えない。 この馬は扱うことができない。

ファインチューニングに使うタスクを定義します。

%%bash
cat <<EOF > add_byte_snow.py
import functools
import tensorflow as tf
from t5.evaluation import metrics
from t5.data import preprocessors
from seqio import vocabularies
from seqio import Feature
from t5.data.dataset_providers import TaskRegistry
from t5.data.dataset_providers import TextLineTask
from sacrebleu import corpus_bleu
import t5.data

class BatchedByteVocabulary(t5.data.ByteVocabulary):
  def batch_decode(self, batched_ids):
    if len(batched_ids.shape) > 1:
      batched_strings = [self.decode(batched_ids[i]) for i in range(len(batched_ids))]
      return  tf.constant(batched_strings)
    else:
      return tf.constant(self.decode(batched_ids))

  def _decode_tf(self, ids):
    return tf.py_function(func=self.batch_decode, inp=[ids], Tout=tf.string)

BYTE_OUTPUT_FEATURES = {
    "inputs": t5.data.Feature(vocabulary=BatchedByteVocabulary()),
    "targets": t5.data.Feature(vocabulary=BatchedByteVocabulary())
}

def bleu(targets, predictions):
  predictions = [tf.compat.as_text(x) for x in predictions]
  if isinstance(targets[0], list):
    targets = [[tf.compat.as_text(x) for x in target] for target in targets]
  else:
    targets = [tf.compat.as_text(x) for x in targets]
    targets = [targets]

  bleu_score = corpus_bleu(predictions, targets,
                                     smooth_method="exp",
                                     smooth_value=0.0,
                                     force=False,
                                     lowercase=False,
                                     tokenize="ja-mecab",
                                     use_effective_order=False)
  return {"bleu": bleu_score.score}

task_name = "snow_t15_23"

tsv_path = {
    "train": "./snow_t15_23_train.tsv",
    "validation": "./snow_t15_23_dev.tsv",
    "test": "./snow_t15_23_test.tsv",
}

TaskRegistry.add(
    task_name,
    TextLineTask,
    split_to_filepattern=tsv_path,
    text_preprocessor=[
      functools.partial(
          preprocessors.parse_tsv,
          field_names=["inputs", "targets"]),
    ],
    output_features = BYTE_OUTPUT_FEATURES,
    metric_fns=[bleu])
EOF

ポイントは BYTE_OUTPUT_FEATURES の定義で自前の BatchedByteVocabulary を使っているところです。

なんでこんなクラスを作ることになっているかというと、検証時に Vocaburarydecode_tf() にトークンID系列のバッチを投入するんですが、ByteVocabularydecode_tf() がバッチに対応できておらず、エラーになってしまって。。。16

事前学習済みのチェックポイントをコピーして、

!gsutil cp gs://somewhere/byt5/wiki_ja/byt5.base/checkpoint gs://somewhere/byt5/wiki_ja/byt5.base/snow/
!gsutil gs://somewhere/byt5/wiki_ja/byt5.base/ope* gs://somewhere/byt5/wiki_ja/byt5.base/snow/
!gsutil gs://somewhere/byt5/wiki_ja/byt5.base/model.ckpt-524288* gs://somewhere/byt5/wiki_ja/byt5.base/snow/

!gsutil ls gs://somewhere/byt5/wiki_ja/byt5.base/snow/
# checkpoint                   model.ckpt-524288.index
# model.ckpt-524288.data-00000-of-00002  model.ckpt-524288.meta
# model.ckpt-524288.data-00001-of-00002  operative_config.gin

以下のようにして実行しました。

!export PYTHONPATH=${PYTHONPATH}:.:./byt5:./multilingual-t5 && \
  \
  OPERATIVE_CONFIG='gs://somewhere/byt5/wiki_ja/byt5.base/snow/operative_config.gin' && \
  FINE_TUNED_MODEL_DIR='gs://somewhere/byt5/wiki_ja/byt5.base/snow' && \
  FINE_TUNING_BATCH_SIZE=`expr 120 \* 8` && \
  PRE_TRAINGING_STEPS=524288 && \
  FINE_TUNING_STEPS=`expr $PRE_TRAINGING_STEPS + 10000` && \
  INPUT_SEQ_LEN=120 &&\
  TARGET_SEQ_LEN=120 &&\
  \
  echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\
  echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\
  echo "FINE_TUNING_BATCH_SIZE=$FINE_TUNING_BATCH_SIZE" &&\
  echo "PRE_TRAINGING_STEPS=$PRE_TRAINGING_STEPS" &&\
  echo "FINE_TUNING_STEPS=$FINE_TUNING_STEPS" && \
  echo "INPUT_SEQ_LEN=$INPUT_SEQ_LEN" && \
  echo "TARGET_SEQ_LEN=$TARGET_SEQ_LEN" && \
  \
  t5_mesh_transformer \
  --model_dir="$FINE_TUNED_MODEL_DIR" \
  --module_import="add_byte_snow" \
  --module_import="byt5.tasks" \
  --gin_file="dataset.gin" \
  --gin_file="$OPERATIVE_CONFIG" \
  --gin_param="run.layout_rules=''" \
  --gin_param="run.mesh_shape=''" \
  --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \
  --gin_param="MIXTURE_NAME = 'snow_t15_23'" \
  --gin_file="learning_rate_schedules/constant_0_001.gin" \
  --gin_param="run.train_steps=$FINE_TUNING_STEPS" \
  --gin_param="run.sequence_length = {'inputs': $INPUT_SEQ_LEN, 'targets': $TARGET_SEQ_LEN}" \
  --gin_param="run.save_checkpoints_steps=1000" \
  --gin_param='dropout_rate=0.1' \
  --gin_param="run.batch_size=('tokens_per_batch', $FINE_TUNING_BATCH_SIZE)"  \
  --gin_param="Bitransformer.decode.max_decode_length = 120" 

上記の内容は ByT5 の README.md12 にある記載と以下の部分を変えています。

  • gin_filelearning_rate_schedules/constant_0_001.gin を追加。
    事前学習の続きをする感じなので、これがないと学習レートが小さくなりすぎる気がしたので。
  • Bitransformer.decode.max_decode_length = 120--gin_param で設定。
    書いてあるとおりに --eval_gin_param を指定したら「そんなん知らん」的なエラーになったので。

eval.gin を使った検証もちゃんと動きます。

ただ、綺麗に動いた風のコードサンプルになっていますが最近の Colab は GPU が混んでるみたいで、 以下のコードは 1 回で全部回りきらない可能性が高いです。動いた分だけログから拾うなり、GCP でインスタンス立てて課金するなり、 Colab Pro (日本でも使えるようになりました!)にするなりして下さい。

!export PYTHONPATH=${PYTHONPATH}:.:./byt5:./multilingual-t5 && \
  \
  FINE_TUNED_MODEL_DIR='gs://somewhere/byt5/wiki_ja/byt5.base/snow' && \
  OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \
  \
  echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\
  echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\
  \
  t5_mesh_transformer \
  --model_dir="$FINE_TUNED_MODEL_DIR" \
  --module_import="add_byte_snow" \
  --gin_file="$OPERATIVE_CONFIG" \
  --gin_param="run.layout_rules=''" \
  --gin_param="run.mesh_shape=''" \
  --gin_file="eval.gin" \
  --gin_file="beam_search.gin" \
  --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \
  --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \
  --gin_param="MIXTURE_NAME = 'snow_t15_23'" \
  --gin_param="run.dataset_split='validation'" \
  --gin_param="run.batch_size=('tokens_per_batch', 960)" \
  --gin_param="eval_checkpoint_step ='all'" \
  --gin_param="Bitransformer.decode.max_decode_length = 120" 2>&1 | tee eval.log

結果は以下のようになりました。

INFO:tensorflow:eval/snow_t15_23/bleu at step 524288: 2.581
INFO:tensorflow:eval/snow_t15_23/bleu at step 525288: 77.042
INFO:tensorflow:eval/snow_t15_23/bleu at step 526288: 78.483
INFO:tensorflow:eval/snow_t15_23/bleu at step 527288: 79.688
INFO:tensorflow:eval/snow_t15_23/bleu at step 528288: 79.951
INFO:tensorflow:eval/snow_t15_23/bleu at step 529288: 80.024
INFO:tensorflow:eval/snow_t15_23/bleu at step 530288: 80.111
INFO:tensorflow:eval/snow_t15_23/bleu at step 531288: 80.300
INFO:tensorflow:eval/snow_t15_23/bleu at step 532288: 80.214
INFO:tensorflow:eval/snow_t15_23/bleu at step 533288: 80.527
INFO:tensorflow:eval/snow_t15_23/bleu at step 534288: 80.485

最良のチェックポイントを使ってテストデータでのスコアを確認します。

!export PYTHONPATH=${PYTHONPATH}:. && \
  \
  FINE_TUNED_MODEL_DIR='gs://somewhere/byt5/wiki_ja/byt5.base/snow' && \
  OPERATIVE_CONFIG=$FINE_TUNED_MODEL_DIR'/operative_config.gin' && \
  \
  echo "OPERATIVE_CONFIG=$OPERATIVE_CONFIG" &&\
  echo "FINE_TUNED_MODEL_DIR=$FINE_TUNED_MODEL_DIR" &&\
  \
  t5_mesh_transformer \
  --model_dir="$FINE_TUNED_MODEL_DIR" \
  --module_import="add_byte_snow" \
  --gin_file="$OPERATIVE_CONFIG" \
  --gin_param="run.layout_rules=''" \
  --gin_param="run.mesh_shape=''" \
  --gin_file="eval.gin" \
  --gin_file="beam_search.gin" \
  --gin_param="utils.get_variable_dtype.slice_dtype='float32'" \
  --gin_param="utils.get_variable_dtype.activation_dtype='float32'" \
  --gin_param="MIXTURE_NAME = 'snow_t15_23'" \
  --gin_param="run.dataset_split='test'" \
  --gin_param="run.batch_size=('tokens_per_batch', 960)" \
  --gin_param="eval_checkpoint_step = 533288" \
  --gin_param="Bitransformer.decode.max_decode_length = 120" \
  > test.log 2>&1

結果は以下のとおりです。

!cat test.log | grep -e "^INFO.*bleu"
# INFO:tensorflow:eval/snow_t15_23/bleu at step 533288: 80.584

あら、前回 の Switch Transformer を使った結果(79.014)よりも良い結果になってしましました。。。 ただ、ByT5 は GEGLU ありの T5 1.1 がベースなので、そのあたりの影響もあるかもしれません。17

6. Charformer の事前学習

ここからは Charformer の事前学習です。CharformerTall を作って、実際にサブワード単位のモデルと比べどの程度の推論速度になるのか見てみましょう。

環境のセットアップ

新しくノートブックを開き、アクセラレータは TPU を選んで下さい。

Tensorflow は 2.x 系を使います。

%tensorflow_version 2.x 

出来るだけ条件を合わせたかったので ByT5 と同じバージョンを使いました。

!pip install t5[gcp]==0.9.1 mesh-tensorflow==0.1.19 tensorflow-datasets==4.2.0

ここからは依存するリポジトリをクローンします。 まずは ByT5 です。 Charformer の事前学習には ByT5 の学習タスクを流用しました。

!git clone https://github.com/google-research/byt5
!cd byt5 && git checkout 069d3cbc1cdfeb7e6
# ...
# HEAD is now at 069d3cb Add GEM-XSum tasks.

mT5 は ByT5 からの依存があるのでクローンしてます。

!git clone https://github.com/google-research/multilingual-t5
!cd multilingual-t5 && git checkout 35f723155f
# ...
# HEAD is now at 35f7231 Add GEM-XSum, GLUE, SuperGLUE tasks, and task test.

Charformer のコードは google-research リポジトリの 1 ディレクトリになっていますね。

!git clone https://github.com/google-research/google-research
!cd google-research &&  git checkout 111c4ff4a
# ...
# HEAD is now at 111c4ff4 Show error category breakdowns by system.

Charformer を動かすには mesh-tensorflow のコードに修正が必要なので、その作業をします。

ソースコードの修正

ここからは Charformer の README.md18 に従ってソースコードの書き換えを行います。

まずは Step 1 です。不要部分を切り取って

!cat /usr/local/lib/python3.7/dist-packages/mesh_tensorflow/transformer/transformer.py | sed '697,742d' > transformer.py

修正コードを入れ込みます。

%%bash
cat << 'EOS' | sed -i '696r /dev/stdin' transformer.py
               token_dropout_rate=0.0,
               gradient_subwords=None,
               gradient_subword_layer=None):

    self.gradient_subwords = gradient_subwords
    self.num_gsw_layers = 1 # FIXME

    if self.gradient_subwords:
      tf.logging.info("Using gradient subwords..")
      self.grad_layer = [gradient_subword_layer()] * self.num_gsw_layers

EOS

Step 1 のコード追記箇所(698から706行目)はこんな感じです。ちょうど Unitransformer__init__ の先頭部分ですね。 self.num_gsw_layers は謎変数なのですが、とりあえず 1 にしました19

!cat -n transformer.py | sed -n 697,708p 
#   697                token_dropout_rate=0.0,
#   698                gradient_subwords=None,
#   699                gradient_subword_layer=None):
#   700 
#   701     self.gradient_subwords = gradient_subwords
#   702     self.num_gsw_layers = 1 # FIXME
#   703 
#   704     if self.gradient_subwords:
#   705       tf.logging.info("Using gradient subwords..")
#   706       self.grad_layer = [gradient_subword_layer()] * self.num_gsw_layers
#   707 
#   708     self.layer_stack = layer_stack

次は Step 2 の修正です。

%%bash
cat << 'EOS' | sed -i '857r /dev/stdin' transformer.py

    if self.gradient_subwords and self.grad_layer:
      tf.logging.info("Using Charformer before computing layer stack.")
      # tensor should be batch x char_length x dim]
      for grad_layer in self.grad_layer:
        x, context = grad_layer.call(context, x)

EOS

追記箇所(859から863行目)は以下のとおりです。ちょうど Unitransformer のスタックの入り口部分で、 GBST が 設定されていれば、 GBST を通す形になっています。

!cat -n transformer.py | sed -n 857,865p 
#   857       x += pos_emb
#   858 
#   859     if self.gradient_subwords and self.grad_layer:
#   860       tf.logging.info("Using Charformer before computing layer stack.")
#   861       # tensor should be batch x char_length x dim]
#   862       for grad_layer in self.grad_layer:
#   863         x, context = grad_layer.call(context, x)
#   864 
#   865     x = self.layer_stack.call(context, x)

Colab の環境は使い捨てなので、少し行儀が悪いですが /usr/local/lib/ の下を書き換えちゃいましょう。

!cp transformer.py /usr/local/lib/python3.7/dist-packages/mesh_tensorflow/transformer/transformer.py 

次は CharformerTall の gin ファイルを用意していきます。

gin ファイルの作成

CharformerTall の gin ファイルは公開されたコードに含まれていないので、論文の記述に従って自作しました。 まず、 2.2 Transformer Stack に以下の記述があります。

Consequently, we explore a scaling variant of CHARFORMER that puts more parameters at the encoder at the expense of the decoder while preferring a tall narrow model over a larger wide model. Specifically, we re-configure a base model to become a small model with an expanded 24 layers in the encoder.The resulting CHARFORMER Tall has 134M parameters, which is about 67% the parameter footprint of the standard base T5 model (200M parameters) [Raffel et al., 2020]. Moreover, this particular CHARFORMER model is approximately 50-100% faster than the T5 base model (see §4.1).4 For the tall variant, we also used the GLU variant described in [Shazeer, 2020] which is commonly referred to as the V1.1 variant in the T5 library.

また、 7.1 Hyperparameters には以下のように記載されています。

Monolingual English Datasets Our small model follows the T5 small model size with 6 encoder layers and 6 decoder layers, hidden size d_model of 512, 8 heads, d_kv of 32 and d_ff of 2048. This corresponds to bi_v1_small.gin in the T5 codebase. The base model (corresponding to bi_v1.gin) has 12 encoder layers, 12 decoder layers, d_model of 768, d_ff of 3072 and 12 heads. The tall model has 24 encoder layers and 6 decoder layers, while the remainder of its hyperparameters remain identical to the small model.

For the tall models, the optimal downsampling rate was often 3.

つまり、

  • T5Small をベースにして、
  • 活性化関数に GEGLU を使い、
  • encoder は 24 層、 decoder は 6 層として、
  • ds は 3 が具合良かった。

ということのようです。 CharformerBase の gin ファイルがソースコードに含まれているの20で、それを参考に以下のようにしました。 とりあえず gin ファイルから参照するパッケージをインポートしておいて、

import sys
sys.path.append("./google-research")
sys.path.append("./google-research/charformer/lib")
import charformer.lib.charformer_layers

設定を gin ファイルに出力します。

%%bash
cat << EOF > cf_v2_d3_cv_tall.gin
import charformer.lib.charformer_layers
include 'models/bi_v1_small.gin'

DenseReluDense.activation = ["gelu", "linear"]
dropout_rate = 0.0
Unitransformer.shared_embedding_and_softmax_weights = False

GradientSubwordLayerV2.key_value_size = %d_kv
GradientSubwordLayerV2.num_heads = %num_heads
GradientSubwordLayerV2.dropout_rate = %dropout_rate
GradientSubwordLayerV2.downsample_query = 3.0
GradientSubwordLayerV2.radius = 8
GradientSubwordLayerV2.low_rank_features = 32
GradientSubwordLayerV2.project_kv = False
GradientSubwordLayerV2.use_ffn = False
GradientSubwordLayerV2.local_gate = False
GradientSubwordLayerV2.num_memory_slots = 0
GradientSubwordLayerV2.local_attention = False
GradientSubwordLayerV2.consider_chars_as_blocks = True
GradientSubwordLayerV2.conv_type = "conv1d"

encoder/Unitransformer.gradient_subwords = True

make_layer_stack.layer_stack_cls=@charformer_layers.CharformerLayerStack

encoder/Unitransformer.gradient_subword_layer = @charformer_layers.GradientSubwordLayerV2

encoder/transformer.make_layer_stack.num_layers = 24
decoder/transformer.make_layer_stack.num_layers = 6

mesh_train_dataset_fn.pack = False
EOF

T5 1.1 では GEGLU を使うときに d_ff のサイズを縮小するんですが、出来上がったモデルのパラメータ数を見る限り CharformerTall は T5Small の構成のままで(d_ff を変更せずに)、活性化関数を GEGLU にして層数を encoder : decoder = 24 : 6 にしたのではないかと思います(例によって確信はないのですが)。

ただ、このまま動かすと CharformerLayerStack が初期化されずにエラーになるので以下の設定を追加しています。 CharformerLayerStack は基本的に mesh_tensorflow.transformer.transformer.LayerStack のコピーなので、mesh_tensorflow/transformer/gin/defaults.gin から該当の sublayers の設定をコピーしました。

%%bash
cat << EOF > set_stack_sublayers.gin
import mesh_tensorflow.transformer.transformer
import charformer.lib.charformer_layers

charformer_layers.CharformerLayerStack.sublayers_initial = [
    @transformer.sublayer_dropout,
]
charformer_layers.CharformerLayerStack.sublayers_per_layer = [
    @transformer.sublayer_rms_norm,
    @transformer.sublayer_call_layer,
    @transformer.sublayer_dropout,
    @transformer.sublayer_residual,
]
charformer_layers.CharformerLayerStack.sublayers_final = [
    @transformer.sublayer_rms_norm,
    @transformer.sublayer_dropout,
]

EOF

事前学習の実行

あとは ByT5 と同じです。タスクの定義を読み込んで、MIXTURE を登録して、TPU のアドレスを確認して、

import sys
sys.path.append("./byt5")
import byt5.metrics as byt5_metrics   # おまじない
sys.path.append("./multilingual-t5")

%run ./byt5/byt5/tasks.py

import t5
t5.data.MixtureRegistry.add("byt5_wiki_ja", ["byt5_wiki.ja"], default_rate=DEFAULT_MIX_RATE)

import os
import pprint
import json
import tensorflow.compat.v1 as tf

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)
# TPU address is grpc://xx.xx.xx.xx:xxxx

論文の 3.1 Setup の記述に合わせ、以下の設定とします。

  • 入力シーケンス長 : 1024
  • 学習ステップ数 : 1048576
  • バッチサイズ : 216 トークン(= 65536 = シーケンス長 1024 で 64 サンプル)
from t5.models.mesh_transformer_main import FLAGS
from t5.models.mesh_transformer_main import main

FLAGS.mark_as_parsed()

FLAGS.tpu = TPU_ADDRESS
FLAGS.model_dir = 'gs://somewhere/charformer/wiki_ja/charformer.tall'
tf.flags.FLAGS.gin_file=[
  "dataset.gin",                        
  "./cf_v2_d3_cv_tall.gin", 
  "./set_stack_sublayers.gin",
  "learning_rate_schedules/rsqrt_no_ramp_down.gin"
]
tf.flags.FLAGS.gin_param=[
  "utils.tpu_mesh_shape.model_parallelism = 1",
  "utils.tpu_mesh_shape.tpu_topology = 'v2-8'",
  "utils.run.sequence_length = {'inputs': 1024, 'targets': 189}",
  "run.batch_size = ('tokens_per_batch', 65536)",
  "run.train_steps = 1048576",
  "run.save_checkpoints_steps = 1000",
  "MIXTURE_NAME = 'byt5_wiki_ja'"
]

伏字にする平均スパン長は 20 bytes との記述もありますが、ByT5 のタスクの定義と合致するので大丈夫そうですね。

!cat byt5/byt5/tasks.py | grep ^MEAN_NOISE_SPAN_LENGTH
# MEAN_NOISE_SPAN_LENGTH = 20

それでは実行します。

tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)
main([])

パラメータ数は以下のとおり、確かに 134M になってるので大丈夫じゃないかな。

INFO:tensorflow:Trainable Variables            count: 240     Total size: 133860352        Total slice_size: 133860352      
INFO:tensorflow:All Variables                  count: 253     Total size: 134280960        Total slice_size: 134280960      

実行速度はこんな感じです。ただバッチサイズは ByT5 の半分であることに注意してください。

INFO:tensorflow:global_step/sec: 1.00239
INFO:tensorflow:examples/sec: 64.1531

examples/sec で比較すると ByT5Base に対して 64.1531 / 33.6515 ≒ 1.91 で decoder が1回であれば倍近い速度にはなっていますね。

検証してみると以下のような曲線になりました。

nlppl_curve_charformer

GBST が入っているせいか学習の出だしでかなり苦しんでますね。無料の TPU じゃなかったら多分諦めてました。 その後もやや不安定な感じで運が悪いとオレンジで示した線のようになったりします。。。巻き戻ってやり直しました。

スコアも ByT5Base よりかなり悪いですが、これはパラメータ数がだいぶ違うので、比べたら可哀そうですかね。 対数で 1.5 だと perplexity に直すと e1.5 = 4.48 くらいですか。。。ちなみに ByT5 のほうは e0.42 = 1.52 です。

さて、ファインチューニングしてどの程度の精度と推論速度になるでしょうか。。。

7. Charformer のファインチューニング

新しくノートブックを開き、アクセラレータに GPU を選んだら、前章の環境のセットアップに従ってセットアップして下さい。あとは、ほとんど ByT5 の繰り返しなので省略です!

修正するのは環境変数 PYTHONPATH と事前学習済みモデルを配置するディレクトリ(FINE_TUNED_MODEL_DIROPERATIVE_CONFIG) がらみ、事前学習のステップ数(PRE_TRAINGING_STEPS)だけですね。出だしが以下のようになります。

!export PYTHONPATH=${PYTHONPATH}:.:./byt5:./multilingual-t5:./google-research:./google-research/charformer/lib && \
  \
  OPERATIVE_CONFIG='gs://somewhere/charformer/wiki_ja/charformer.tall/snow/operative_config.gin' && \
  FINE_TUNED_MODEL_DIR='gs://somewhere/charformer/wiki_ja/charformer.tall/snow' && \
  FINE_TUNING_BATCH_SIZE=`expr 120 \* 8` && \
  PRE_TRAINGING_STEPS=1048576 && \
  FINE_TUNING_STEPS=`expr $PRE_TRAINGING_STEPS + 10000` && \
  INPUT_SEQ_LEN=120 &&\
  TARGET_SEQ_LEN=120 &&\
  ...

結果は以下のとおりです。。。。全然ダメでした。。。

!cat test.log | grep -e "^INFO.*bleu"
# INFO:tensorflow:eval/snow_t15_23/bleu at step 1057576: 0.117

学習ロスをプロットするとこんな感じです。。。

fintune_charformer

うーん、今回の設定だと「こんなもん」なのか、どこかで間違ったのか判別つきませんが、比較にならないので Charformer は忘れちゃいましょう。

比較対象が予定より減ってしまいましたが、T5 と精度と速度を比べてみます。

8. 精度と速度の比較

本記事で学習した ByT5Base と手元にあった T5Base 1.1 で比較してみました。 事前学習とファインチューニングの学習データは同じで活性化関数は全て GEGLU です。各モデルのサイズ感を改めて示すと、以下のとおりバラバラですね。。。

model_and_params

それでは、精度の比較です。

bleu_t5_byt5

ほぼ同じですね。バイト単位の処理で不利にになった分をパラメータ数の増加で補った感じですね。Transformer のサイズ感を揃えていれば T5Base 1.1 が逆転しそうな気がします。

Clean Data が本記事中で使ったテストデータでのスコア、Noisy Data が自前のロジックで Clean Data に対して、

  • 1 サンプルに 1 個程度の誤字脱字が入る
  • 誤字脱字は誤変換(10%)、脱字(30%)、ランダムノイズ(10%)、"てにをは"系(50%)の混合

という具合でノイズを加えたデータになっています。

ですが、両者が同じようにスコアを落としてしまい、 ByT5 がノイズに強い感じが全然しない結果になってしまいました。 よく考えてみれば、今回のタスクは変換対象と認識している語句以外は左から右にパススルーな感じになると思うので、 「そりゃそうか。。。」という気もします。文章分類とかの方がよかったかもしれませんね。

続いて速度の比較です。

speed_t5_byt5

T5Base 1.1 の方が 3 倍ちょっと高速でした。ByT5Base はパラメータが増えているので、もっと遅くなるかと思いましたが意外とイケる感じでしょうか。 多言語の文章に対応できて「あの字がでない。。。」と悩まなくて良いので状況次第では ByT5 をチョイスしても良いかもしれませんね。

9. おまけ

ByT5 ですが UTF-8 だと日本語 1 文字出力するのに decoder を 3 周するのがツライところです。「じゃぁ、 UTF-16 とかでエンコードしたらいいんじゃね?」と思いついたので、やってみました。

import tensorflow.compat.v2 as tf
from t5.data import ByteVocabulary

class Utf16Vocabulary(ByteVocabulary):

  def __init__(self, extra_ids: int = 21):
    super().__init__(extra_ids=extra_ids)

  def _convert_strings_to_ids(self, s):
    return list(s.encode("utf-16be"))

  def _convert_ids_to_strings(self, ids):
    return bytes(ids).decode("utf-16be", errors="ignore")

  def _encode_tf(self, s):
    tf_ids = tf.io.decode_raw(
               tf.strings.unicode_transcode(s, input_encoding='UTF-8', 
               output_encoding='UTF-16-BE'), tf.uint8) + self._num_special_tokens
    return tf.dtypes.cast(tf_ids, tf.int32)

  def batch_decode(self, batched_ids):
    if len(batched_ids.shape) > 1:
      batched_strings = [self.decode(batched_ids[i]) for i in range(len(batched_ids))]
      return  tf.constant(batched_strings)
    else:
      return tf.constant(self.decode(batched_ids))

  def _decode_tf(self, ids):
    return tf.py_function(func=self.batch_decode, inp=[ids], Tout=tf.string)

エンコーディングを UTF-8 から UTF-16-BE に替えた以外の変更点としては、extra_ids として 21 個分を確保したところでしょうか。 extra_ids は事前学習の時の伏字(論文の図にでてくる"<X>""<Y>")に対応するものですが、 UTF-8 の場合は使用するビットパターンの末端の方に空き地があるのでそこを利用しているようです。ですが、UTF-16-BE では以下のように FF までガッツリ使われるので、それ用のスペースを確保する訳ですね。

[b for b in "寿、姿、響".encode("utf-16be")]
[91, 255, 48, 1, 89, 255, 48, 1, 151, 255]

とりあえず、本記事で紹介した TPU での事前学習 & 検証、GPU でのファインチューニング & 検証が動作するところまでは確認しました。 事前学習は ByteVocabularyUtf16Vocabulary に置き換えた以外は同じセッティングで実行してます。 予定の全ステップ(=524288)回した訳ではないのですが、133100 ステップ時点で Negative log perplexity = -0.63 でした。 ByT5 は 131400 ステップで Negative log perplexity = -0.51 だったので数値的には劣りますが語彙数が増えているのでその影響かもしれません。

133100 ステップを起点に入出力シーケンス長を 80 にした以外は同様の手順でファインチューニングし、精度評価と速度計測をしてみると、 BLEU = 78.673 で 2.904 examples/sec でした。

精度は事前学習を途中で打ち切ったので何とも言えないですが、速度的には 1.7 倍になりました(1.5 倍になる予定だったんだけど)。 日本語中心で処理するなら、選択肢に入れてもいいかもしれませんね。

10. おわりに

最近は BERT 系の日本語モデルも りんな さんの RoBERTa21 や Megagon Labs さんの ELECTRA22 など CC-100 や mC4 など Wikipedia よりも 大規模な日本語コーパスで学習した事前学習済みモデルが公開されるようになってきました。

ついでにいうと Megagon Labs さんから mC4 で学習した T5Base 1.1 モデル23も公開されてますね!

次回の話です。この連載、最近は第9回の Sentence BERT の記事がアクセス数多いと社内の人から聞きました。 この記事を執筆した当時は良い日本語データセットを見つけらなかったので画像のキャプションから強引に学習データを作ったのですが、 現在では京都大学の黒橋・褚・村脇研究室から SNLI を日本語化した JSNLI データセット24が公開されています。

Sentence BERT のコードも執筆時からだいぶ更新が入っているようなので、次回は JSNLI データセットを用いて Sentence BERT 改め Sentence Transformer をいろんな事前学習済みモデルで作り比べてみようかと思います。


  1. https://arxiv.org/abs/2105.13626 

  2. https://arxiv.org/abs/2106.12672 

  3. https://www.kanken.or.jp/kanken/outline/degree.html 

  4. ちなみに、U+2047 は Sentencepiece のデフォルトで --unk_surface=<STR> で変更することができます。 

  5. mT5 のコードで使用されている Sentencepiece のモデル(gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model)を確認したところ、語彙数は 250100 でした。 

  6. 詳しくは 論文の 6.2 Encoder/Decoder Balance に記述があります。 

  7. 図中の X の表記は Xi ですね。 X の i バイト目の埋め込み表現です。 

  8. 図中の pn:m は pb,i の n ≦ i ≦ m という表記ですね。b は略記されています。 

  9. https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/gin/models/bi_v1_small.gin 論文 7.1 の文中には d_kv = 32 とありますが、bi_v1_small.gin に準じた構成だと d_kv = 64 になるはずなんですが。。。  

  10. https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/gin/models/bi_v1.gin 

  11. T5 1.1 で GEGLU を使うときは d_ff のサイズを絞るのですが、ChaformerTall は(パラメータ数の記述から推測になりますが) bi_v1_small.gin に定義された d_ff = 2048 のままで GEGLU を使っているようです。  

  12. https://github.com/google-research/byt5/blob/master/README.md 

  13. https://www.tensorflow.org/datasets/catalog/c4#c4multilingual 

  14. https://github.com/google-research/byt5/blob/4066a7270d6cc1cbc04fe30dd7762275bacb22c3/byt5/tasks.py#L66-L86 

  15. https://github.com/google-research/text-to-text-transfer-transformer/issues/687  

  16. mesh-tensorflow 0.1.19 で遭遇しましたが、この記事の公開時点のバージョンでは治ってるかもしれません。 

  17. 前回の Switch Transformer は GEGLU なしの T5 1.0 相当のモデルの FFN を並列化したものでしたから。いろいろ試してると GEGLU 効くんですよね。 

  18. https://github.com/google-research/google-research/blob/master/charformer/README.md 

  19. GBST 層を何層重ねるかという設定です。重ねると ds 分の 1 の長さにどんどん短くなるので 1 に設定しておくのが無難かと。 

  20. https://github.com/google-research/google-research/blob/master/charformer/configs/cf_v2_d3_cv_base.gin 

  21. https://huggingface.co/rinna/japanese-roberta-base 

  22. https://huggingface.co/megagonlabs/transformers-ud-japanese-electra-base-discriminator 

  23. https://huggingface.co/megagonlabs/t5-base-japanese-web 

  24. https://nlp.ist.i.kyoto-u.ac.jp/index.php?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88 私が Sentence BERT の記事を公開した 2 週間後くらいに公開されてたらしいです。。