オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第21回 T5X と Prompt Tuning の検証
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2022年6月23日

今回は T5X と Prompt Tuning の検証をしてみました。T5X は JAX と Flax で実装された T5 の新世代実装です。 Prompt Tuning は近年流行している事前学習済みモデルとプロンプトで下流タスクを解く手法の一つです。 Prompt Tuning に関しては T5X で実装されたコードが公開されていたので、合わせて検証してみることにしました。

1. はじめに

今回は T5X1 と Prompt Tuning2 の検証とご紹介になります。

T5X は第7回で紹介した T53 の次世代実装になります。T5 は、Mesh Tensorflow4 を採用することで、 単一の TPU や GPU に全パラメータが格納できない大規模モデルを実現していますが、学習ループ周辺の実装は Tensorflow 1.x 系列の Estimator API を用いた、やや古びた構成になっていました。

「いつまで、これで引っ張るんだろう。そのうち Tensorflow 2.x 系に移行するんかな?」と思っていましたが、一気に JAX / Flax ベースの実装になりました。今回は実際に動かしてみて従来実装( T5 )と同じレベルの精度がでるか確認して見たいと思います。

そして、「T5X を動かしてみるだけでは寂しいな。。。」と思ってたところ Prompt Tuning の実装5が T5X ベースで公開されているのに気づきました。

Prompt Tuning は 2021 年あたりから流行している、プロンプトを用いて事前学習済みモデルを変更することなく下流タスクに適応させる手法の一つです。 せっかく T5X の動かし方を覚えるんですから、Prompt Tuning も実際に動かしてみて従来手法(ファインチューニング)に対してどの程度の精度が出せるか試してみることにしました。

まずは T5X についてもう少し見ていきましょう。

2. T5X

Google は 2019 年の T5 の発表以降、mT5, ByT5, ExT5 と T5 系の論文を発表してきており、これらは T53 をベースとした実装になっていました。

それが T5X として実装を新たに作り直した訳ですから、今後 Google が T5 系の論文を出してくるとしたら、その実装は T5X になるのでしょう(実際に Prompt Tuning は T5X ベースの実装ですし)。そういうことなら、ここで T5X の使い方を覚えておいて損はないと思います。

この記事を全部書いてしまった後で追記

まぁ、「今後、 T5 系の論文を出してくるとしたら」に関しては、 GLaM, LaMDA, Gopher, PaLM とかの流れを見てると「どうかな~?」という気もしてきましたが。。。

さて、前述のとおり T5X は JAX6 と Flax7 で実装された T5 の新世代実装です。まずは JAX と Flax がどんなものか、簡単に押さえておきましょう。

jax

JAX について Quick Start のページ8に端的な説明があります。

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

ようは「GPU や TPU 上で動かせる自動微分機能のある NumPy」ということです。ループ、条件分岐、再帰、クロージャ等を含んだ Python と NumPy のコードを自動微分でき、複数のアクセラレータ( GPU や TPU )の上で実行することが出来ます。

詳細はこの記事には書ききれないので、前述の Quick Start8 を試してみることをお勧めします。注視すべきポイントは以下になります。

  • JAX で扱う Python 関数は純粋関数でないといけない
    同じ入力には同じ返り値。副作用なし。
  • JAX device array (NumPy における ndarray 相当) は更新できない
    x[O,:] = 3.0 ではなく updated = x.at[0,:].set(3.0)
  • JAX では PRNG(疑似乱数発生器)の状態を明示的に管理しないといけない。:
    あれこれ書くより Quick Start を見てもらったほうが分かりやすいかと。
  • jit() :
    関数を XLA コンパイルして高速化する。
  • grad() :
    関数の導関数を得る。
  • vmap() :
    関数をベクトル化する。(データ1件を処理する関数をそのデータのバッチを処理する関数にできる。手でループを書くより高速に。)
  • pmap() :
    関数を XLA コンパイルして複数アクセラレータ上に複製、実行する。一般的には single-program multiple-data (SPMD) と呼ばれるものです。

T5 では Mesh Tensorflow が担っていた複数アクセラレータ対応を、 T5X では JAX 自体がカバーしてくれる形になります。これはコードの見通しが良くなりそうですね。

つづいて Flax です。

flax

Flax は JAX ベースのニューラルネットワークライブラリ(とエコシステム)で柔軟性を重視して設計されています。 ニューラルネットワークライブラリとして必要なものは現時点で揃っていて、 Hugging Face の Transformers でも少し前に Flax サポートが追加されています9

こちらも細かくは説明しないので、興味のある方は Flax の Guided Tour10 等を見て頂くとして、コードの雰囲気は以下のようになります。

# https://github.com/google/flax から引用してコメント部分を加筆
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])                             # 3 層 MLP の定義
batch = jnp.ones((32, 10))                            # ダミーの入力値
variables = model.init(jax.random.PRNGKey(0), batch)   # モデルのパラメータを初期化。batch は入力数の推論の為。
output = model.apply(variables, batch)                 # モデルのパラメータ(variables)と入力(batch)で出力を得る。
                                                       # output.shape は (32,4)

Tensorflow や PyTorch と比べると、モデルの定義とそのパラメータが明確に分離されているのが特徴的ですね。 複雑なモデルを実装する側の視点からだと、こちらの方が扱いやすいのでしょう。

さて、T5X に話を戻しましょう。T5X はモデルの実装ライブラリが JAX / Flax に置き代わっていることが大きな相違点ですが、 以下のような点を T5 から受け継いでいます。

  • 学習タスクを SeqIO11 の Task で表現する(個別のタスクの定義は t5.data からの流用)。
  • gin12 で部品の組み立てやパラメータ設定を行う。

これまで T5 を利用するときも、「タスクを定義し、gin で各種設定を記述して実行する」のがメインで、あまり Mesh Tensorflow で書かれたコードと格闘するようなことはありませんでしたが、この点は T5X になっても同じです。

タスクの定義の仕方は T5 と同じ(と言うか、そのもの)ですので、 T5 を触ったことがある人は (gin による設定の表記方法に違いはあるものの)あまり違和感なく T5X を動かせると思います。

それでは、実際に T5X を動かして見ましょう。

3. T5X でのファインチューニング

ここからは GPU を使って T5X で事前学習済みモデルをファインチューニングして行きます。

今回も記事内のコードスニペットは、特に断りがない場合は Google Colaboratory (以下、Colab)で動かす想定にしています。 新たに Colab のノートブックを開いて、アクセラレータに GPU を選んでください。

ちょっと脱線:事前学習はどうした?

「あれ、事前学習の仕方は説明してくれないの?」と思われた方がいらっしゃるかもしれませんね。一言でいうと Colab の TPU では動かせませんでした。 どうも Colab の TPU は古い世代のアーキテクチャであり、これが JAX の SMPD partitioner で対応していない13ということらしいです。。。 ただ、こちらの issue14 では T5X の開発者さんが問題を認識していて何とかしようとしている感があるので、この記事が公開されるころには動くようになっているかもです。

たぶんファインチューニングを動かせるようになっていれば、事前学習を動かすのもそんなに難しくないでしょう。

では話をファインチューニングに戻します。

セットアップ

T5X をインストールします。普通に入れると nightly やら github のリポ直撃やらでガンガン入り、動かす日によって問題が起こったりするので、 出来るだけバージョンを固定するようにしました。

!git clone --branch=main https://github.com/google-research/t5x
!cd t5x && git checkout df1ee3f3d5fd63ad0959e110cdfa46383356c99b
# ...
# HEAD is now at df1ee3f Migrate away from using private JAX test utils
!cd t5x && python -m pip install -e '.[gcp]'

私の場合、ある日突然エラーで動かなくなったので、flax を差し戻しました。

!pip uninstall -y flax
!pip install git+https://github.com/google/flax@9bd65b20752e7bfc172796e66948b6c216405b9b

日本語 で BLEU の計算をするので、エクステンションを追加します。

!pip install sacrebleu[ja]

念のため、ここで一旦再起動します。

再起動したら、まずは GCS の認証を通しておいて下さい。

import os
os.environ["USE_AUTH_EPHEM"] = '0'
from google.colab import auth
auth.authenticate_user()

環境変数 USE_AUTH_EPHEM を 0 に変更してます。この辺の話は後述しますが、GPU で実行するなら不要だったかもしれません。

言い訳

上記の手順でインストールした commit と実行時のログの日付、学習時と検証時のログの日付が前後してたりしますが、ご容赦ください。 一度、動かして記事を書いた後、後日再び動作チェックしたら、アレコレ動かなくなってたのでインストールするバージョン の組み合わせを試行錯誤して上記の手順にたどり着いた訳で。。。動作確認は少ないステップ数で実施したのですが、大丈夫だと思います(多分)。

学習データの準備

ファインチューニングのタスクですが、この連載で何度か使った「やさしい日本語変換」を試しました。 加工済みのデータを GCS から取得します。加工の仕方は第14回を参照してください。

!gsutil cp gs://somewhere/snow_t15_23_*.tsv .
!wc -l snow*.tsv
#   7040 snow_t15_23_dev.tsv
#   7040 snow_t15_23_test.tsv
#  56317 snow_t15_23_train.tsv
#  70397 total

Tensorflow の事前学習済みモデルを T5X の形式の変換する。

事前に変換しなくても自動判定して変換してくれる仕様になっているようですが、変換の仕方を説明する意味でも事前に変換することにしました。

  • 肝心の事前学習済みモデルですが、これまでの連載の過程で作ったものがお手元にない方は第14回の 2 章を参考にして作って下さい(すみません)。
!gsutil ls gs://somewhere/t5/pre_trained_t5_1.1
# gs://somewhere/t5/pre_trained_t5_1.1/checkpoint
# gs://somewhere/t5/pre_trained_t5_1.1/graph.pbtxt
# gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288.data-00000-of-00002
# gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288.data-00001-of-00002
# gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288.index
# gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288.meta
# gs://somewhere/t5/pre_trained_t5_1.1/operative_config.gin

以下のようにしてコンバートします。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
CUDA_VISIBLE_DEVICES="" && \
python -m t5x.scripts.convert_tf_checkpoint \
 --gin_file="t5x/examples/t5/t5_1_1/base.gin" \
 --gin.convert_checkpoint.model=%MODEL \
 --gin.DROPOUT_RATE="0.0" \
 --gin.convert_checkpoint.tf_checkpoint_path=\"gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288\" \
 --gin.convert_checkpoint.output_dir=\"gs://somewhere/t5x/pre_trained_t5_1.1_native\" \
 --logtostderr

# Rewritten gin arg: --gin_bindings=convert_checkpoint.model = %MODEL
# Rewritten gin arg: --gin_bindings=DROPOUT_RATE = 0.0
# Rewritten gin arg: --gin_bindings=convert_checkpoint.tf_checkpoint_path = "gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288"
# Rewritten gin arg: --gin_bindings=convert_checkpoint.output_dir = "gs://somewhere/t5x/pre_trained_t5_1.1_native"
# I0318 06:43:10.417450 140574798473088 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/base.gin
# I0318 06:43:10.419274 140574798473088 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/examples/t5/t5_1_1/base.gin
# I0318 06:43:10.448343 140574798473088 gin_utils.py:63] Gin Configuration:
# from __gin__ import dynamic_registration
# ...
# I0318 06:44:44.707433 140574798473088 checkpoints.py:637] Saved checkpoint for step 524288 to gs://somewhere/t5x/pre_trained_t5_1.1_native/checkpoint_524288

タスクの登録

以下のようにして「やさしい日本語変換」タスクを登録する処理を準備します。過去の記事と微妙に違っていますが、ほとんど同じです。 gs://somewhere/t5/sentencepiece/sp.model は事前学習モデル済みモデルを作る時に使った Sentencepiece のモデルですね。

%%bash
cat << EOF > snow_task.py
import functools
import tensorflow as tf
from t5.evaluation import metrics
from t5.data import preprocessors
from seqio import vocabularies
from t5.data.utils import DEFAULT_EXTRA_IDS
from seqio import Feature
from t5.data.dataset_providers import TaskRegistry
from t5.data.dataset_providers import TextLineTask
from sacrebleu import corpus_bleu

def bleu(targets, predictions):
  predictions = [tf.compat.as_text(x) for x in predictions]
  if isinstance(targets[0], list):
    targets = [[tf.compat.as_text(x) for x in target] for target in targets]
  else:
    targets = [tf.compat.as_text(x) for x in targets]
    targets = [targets]

  bleu_score = corpus_bleu(predictions, targets,
                                     smooth_method="exp",
                                     smooth_value=0.0,
                                     force=False,
                                     lowercase=False,
                                     tokenize="ja-mecab",
                                     use_effective_order=False)
  return {"bleu": bleu_score.score}

task_name = "snow_t15_23"

tsv_path = {
    "train": "./snow_t15_23_train.tsv",
    "validation": "./snow_t15_23_dev.tsv",
    "test": "./snow_t15_23_test.tsv",
}

TaskRegistry.add(
    task_name,
    TextLineTask,
    split_to_filepattern=tsv_path,
    text_preprocessor=[
      functools.partial(
          preprocessors.parse_tsv,
          field_names=["inputs", "targets"]),
    ],
    output_features = Feature(vocabularies.SentencePieceVocabulary(
      "gs://somewhere/t5/sentencepiece/sp.model",
      DEFAULT_EXTRA_IDS)),
    metric_fns=[bleu])
EOF

gin ファイルの定義

以下のようにして gin ファイルを定義します。

%%bash
cat << EOF > t5x_t5_1.1_base_finetune_snow.gin
from __gin__ import dynamic_registration
import t5.data.mixtures
import __main__ as train_script
from t5x import utils
from t5x import models
from t5x import trainer
from t5x.examples.t5 import network
import seqio
import snow_task # 1.

include 't5x/examples/t5/t5_1_1/base.gin' # 2.
include 't5x/configs/runs/finetune.gin'   # 3.

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model" # 4.
seqio.SentencePieceVocabulary.extra_ids = 100 # 4.

# 32 x 64 = 2048
LOSS_NORMALIZING_FACTOR = 2048 # 5.

BATCH_SIZE = 32
MIXTURE_OR_TASK_NAME = "snow_t15_23"
TASK_FEATURE_LENGTHS = {'inputs': 64, 'targets': 64}
USE_CACHED_TASKS = False

train_script.train:
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  eval_period = 400

utils.SaveCheckpointConfig:
  period = 400

network.T5Config:
  dtype = 'float32' # 6.
EOF

これまでと一番違うのかココですね。ポイントを絞って説明します。

  1. : 前節で定義したタスクを T5X が見つけられるようにする為、 snow_task を import します。
  2. : T5 のモデルの構造を決める設定を include します。今回は事前学習済みモデルが T5 1.1 の Base なので t5x/examples/t5/t5_1_1/base.gin を読み込みます。
  3. : ファインチューニング時は t5x/configs/runs/finetune.gin を読み込みます。
  4. : タスクの登録のところと同じ Sentencepiece のモデルを設定します。二か所で同じこと書くのが微妙ですが書かないとエラーになったような。
  5. : LOSS_NORMALIZING_FACTOR は T5 の従来実装と挙動を合わせるために必要です。今回はファインチューニング時のバッチサイズ×出力シーケンス長を指定しました。

    • : base.gin では以下のように記述があります。
    # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
    # the loss normalizing factor should be set to pretraining batch_size *
    # target_token_length.
    LOSS_NORMALIZING_FACTOR = None
    
    • 上記のコメントには “should be set to pretraining batch_size * target_token_length.” とありますが、従来実装で相当するのはこの部分15で以下のような説明が記載されています。明示指定しなければ、ファインチューニング時のバッチサイズ×出力シーケンス長になるようだったので、それに合わせました16
      loss_denominator: an optional float.  The default behavior is to
        compute the mean loss across all tokens in the batch, making the
        denomiator the size of the targets tensor (omitting ensemble
        dimensions).
        Passing a float here provides an alternative denomiator.
        One use case is that when fine-tuning a model using a much smaller
        batch size than the original training batch, one might want to use the
        same denominator as was used for the pretraining.  This complication
        might be avoided by always using loss_denominator = 1.0.
    
  6. : これを書いておかないと bfloat16 になってしまったので追記してます。

ファインチューニングの実行

あとは、先ほどの gin ファイルといくつかのパラメータを指定して実行するだけですね。TRAIN_STEPS には事前学習のステップ数+ファインチューニングのステップ数の値を設定することに注意して下さい。

# TRAIN_STEPS = 526288 # 524288 + 2000
!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  MODEL_DIR="${BUCKET}/t5x/t5_1.1_base_snow_fp32" && \
  INITIAL_CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1_native/checkpoint_524288" && \
  TRAIN_STEPS=526288 && \
  T5X_DIR="./t5x" && \
  python ${T5X_DIR}/t5x/train.py \
    --gin_file="./t5x_t5_1.1_base_finetune_snow.gin" \
    --gin.MODEL_DIR=\"${MODEL_DIR}\" \
    --gin.INITIAL_CHECKPOINT_PATH=\"${INITIAL_CHECKPOINT_PATH}\" \
    --gin.TRAIN_STEPS=${TRAIN_STEPS} \

# Rewritten gin arg: --gin_bindings=MODEL_DIR = "gs://somewhere/t5x/t5_1.1_base_snow_fp32"
# Rewritten gin arg: --gin_bindings=INITIAL_CHECKPOINT_PATH = "gs://somewhere/t5x/pre_trained_t5_1.1_native/checkpoint_524288"
# Rewritten gin arg: --gin_bindings=TRAIN_STEPS = 526288
# I0323 05:23:47.033981 140516635555712 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/base.gin
# I0323 05:23:47.034569 140516635555712 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/examples/t5/t5_1_1/base.gin
# I0323 05:23:47.043861 140516635555712 resource_reader.py:50] system_path_file_exists:t5x/configs/runs/finetune.gin
# I0323 05:23:47.044457 140516635555712 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/configs/runs/finetune.gin
# I0323 05:23:47.072142 140516635555712 gin_utils.py:63] Gin Configuration:
...
# /usr/local/lib/python3.7/dist-packages/jax/experimental/pjit.py:183: UserWarning: pjit is an experimental feature and probably has bugs!
# ...
# I0323 06:20:52.515951 140516635555712 utils.py:740] Inference of batch [7008 7009 7010 7011 7012 7013 7014 7015 7016 7017 7018 7019 7020 7021
 7022 7023 7024 7025 7026 7027 7028 7029 7030 7031 7032 7033 7034 7035
 7036 7037 7038 7039] done.
# I0323 06:20:52.522318 140516635555712 utils.py:755] Inference of all batches done.
# I0323 06:20:52.581886 140516635555712 evaluation.py:477] Time waiting for previous metrics run: 0.000025 secs.
# I0323 06:20:52.582392 140494084175616 evaluation.py:525] Computing metrics for snow_t15_23
# I0323 06:20:52.588871 140496691066624 logging_writer.py:48] [526288] collection=train timing/evaluate_seconds=78.500227
# I0323 06:20:54.758864 140494084175616 loggers.py:89] snow_t15_23/bleu at step 526288: 79.906
# I0323 06:20:55.093552 140494084175616 loggers.py:354] Appending metrics to gs://somewhere/t5x/t5_1.1_base_snow_fp32/inference_eval/snow_t15_23-metrics.jsonl
# I0323 06:20:56.074098 140494084175616 loggers.py:382] Writing inferences to gs://somewhere/t5x/t5_1.1_base_snow_fp32/inference_eval/snow_t15_23-526288.jsonl
# I0323 06:20:58.451827 140494084175616 loggers.py:415] Writing completed in 2.377776 seconds (0.841122 examples/sec).
# I0323 06:20:58.452686 140494084175616 evaluation.py:483] Time computing metrics: 5.870297 secs.
# I0323 06:20:58.454619 140516635555712 train.py:559] Finished.    

とりあえず動きましたが、 pjit is an experimental feature and probably has bugs! と警告がでていますね。この辺りは JAX / Flax の発展と共に成熟してきくことを期待しましょう。

学習曲線は以下のようになりました。

learning_curve

こんどはテストデータを使って検証してみましょう。

検証の実行

以下のようにして gin ファイルを定義します。

%%bash
cat << EOF > t5x_t5_1.1_base_eval_snow.gin
from __gin__ import dynamic_registration
import __main__ as eval_script
import t5.data.mixtures
from t5x import utils
from t5x import partitioning
from t5x import models
from t5x.examples.t5 import network
import seqio
import snow_task

include 't5x/examples/t5/t5_1_1/base.gin'
include 't5x/configs/runs/eval.gin'

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model"
seqio.SentencePieceVocabulary.extra_ids = 100

DROPOUT_RATE = 0.0

MIXTURE_OR_TASK_NAME = "snow_t15_23"

network.T5Config:
  dtype = 'float32'
EOF

include するのが eval.gin に変わった以外は特に説明することもなさそうですね。スプリットに test を使う指定は eval.gin に記述されています。以下のようにして実行します。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  EVAL_OUTPUT_DIR="${BUCKET}/t5x/t5_1.1_base_snow_fp32/eval" && \
  CHECKPOINT_PATH="${BUCKET}/t5x/t5_1.1_base_snow_fp32/checkpoint_526288" && \
  T5X_DIR="./t5x" && \
  python ${T5X_DIR}/t5x/eval.py \
    --gin_file="./t5x_t5_1.1_base_eval_snow.gin" \
    --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
    --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \

# Rewritten gin arg: --gin_bindings=EVAL_OUTPUT_DIR = "gs://somewhere/t5x/t5_1.1_base_snow_fp32/eval2"
# Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = "gs://somewhere/t5x/t5_1.1_base_snow_fp32/checkpoint_526288"
# I0324 00:59:15.374249 140200483284864 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/base.gin
# I0324 00:59:15.374909 140200483284864 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/examples/t5/t5_1_1/base.gin
# I0324 00:59:15.383445 140200483284864 resource_reader.py:50] system_path_file_exists:t5x/configs/runs/eval.gin
# I0324 00:59:15.384062 140200483284864 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/configs/runs/eval.gin
# I0324 00:59:15.399366 140200483284864 gin_utils.py:63] Gin Configuration:
# from __gin__ import dynamic_registration
# ...
# 7036 7037 7038 7039] done.
# I0324 01:04:23.973330 140200483284864 utils.py:755] Inference of all batches done.
# I0324 01:04:24.034694 140195098597120 evaluation.py:525] Computing metrics for snow_t15_23
# I0324 01:04:26.957364 140195098597120 loggers.py:89] snow_t15_23/bleu at step 526288: 79.899
# I0324 01:04:28.332274 140195098597120 loggers.py:354] Appending metrics to gs://somewhere/t5x/t5_1.1_base_snow_fp32/eval2/inference_eval/snow_t15_23-metrics.jsonl
# I0324 01:04:29.085039 140195098597120 loggers.py:382] Writing inferences to gs://somewhere/t5x/t5_1.1_base_snow_fp32/eval2/inference_eval/snow_t15_23-526288.jsonl
# I0324 01:04:31.947225 140195098597120 loggers.py:415] Writing completed in 2.862218 seconds (0.698759 examples/sec).
# I0324 01:04:31.947912 140195098597120 evaluation.py:483] Time computing metrics: 7.913265 secs.
# I0324 01:04:31.951103 140200483284864 eval.py:177] Finished.

BLEU スコアで 79.899 がでました。

T5 のバージョン、事前学習のデータセット、ファインチューニング時のバッチサイズが違うので直接比較できませんが、第7回では 78.795 だったのでとりあえず期待したように動いてはいるようですね。

推論もしてみましょう。

推論の実行

gin ファイルの定義です。

%%bash
cat << EOF > t5x_t5_1.1_base_infer_snow.gin
from __gin__ import dynamic_registration
import __main__ as infer_script
from t5x import utils
from t5x import partitioning
from t5x import models
from t5x.examples.t5 import network
import seqio
import snow_task

include 't5x/examples/t5/t5_1_1/base.gin'
include 't5x/configs/runs/infer.gin'

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model"
seqio.SentencePieceVocabulary.extra_ids = 100

DROPOUT_RATE = 0.0

MIXTURE_OR_TASK_NAME = "snow_t15_23"
TASK_FEATURE_LENGTHS = {'inputs': 64, 'targets': 64}

infer_script.infer:
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

network.T5Config:
  dtype = 'float32'
EOF

こんどは infer.gin を読み込んで、以下のように実行します。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  INFER_OUTPUT_DIR="${BUCKET}/t5x/t5_1.1_base_snow_fp32/infer2" && \
  CHECKPOINT_PATH="${BUCKET}/t5x/t5_1.1_base_snow_fp32/checkpoint_526288" && \
  T5X_DIR="./t5x" && \
  python ${T5X_DIR}/t5x/infer.py \
    --gin_file="./t5x_t5_1.1_base_infer_snow.gin" \
    --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
    --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
# Rewritten gin arg: --gin_bindings=INFER_OUTPUT_DIR = "gs://somewhere/t5x/t5_1.1_base_snow_fp32/infer2"
# Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = "gs://somewhere/t5x/t5_1.1_base_snow_fp32/checkpoint_526288"
# I0324 00:27:33.787390 139703101564800 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/base.gin
# I0324 00:27:33.787983 139703101564800 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/examples/t5/t5_1_1/base.gin
# I0324 00:27:33.796423 139703101564800 resource_reader.py:50] system_path_file_exists:t5x/configs/runs/infer.gin
# I0324 00:27:33.797021 139703101564800 resource_reader.py:37] gin-config opened resource file:/content/t5x/t5x/configs/runs/infer.gin
# I0324 00:27:33.811678 139703101564800 gin_utils.py:63] Gin Configuration:
# from __gin__ import dynamic_registration
# ...
# I0324 00:35:16.663071 139703101564800 utils.py:740] Inference of batch [608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639] done.
# I0324 00:35:16.665302 139703101564800 utils.py:755] Inference of all batches done.
# I0324 00:35:16.673525 139703101564800 infer.py:469] chunk completed in 29.875761 seconds (21.422049 examples/sec).
# I0324 00:35:19.715288 139703101564800 infer.py:487] Checkpoint written to temporary location in 3.041550 seconds.
# I0324 00:35:19.715629 139696855222016 infer.py:421] Writing chunk 2 results to gs://somewhere/t5x/t5_1.1_base_snow_fp32/infer2/tmp-snow_t15_23-00000-of-00001/snow_t15_23-predict.jsonl-00000-of-00001-chunk00002
# I0324 00:35:23.881667 139696855222016 infer.py:427] Writing completed in 4.165986 seconds (153.625103 examples/sec).
# I0324 00:35:25.081416 139703101564800 infer.py:502] Finished inference for task 'snow_t15_23'.
# I0324 00:35:25.081631 139703101564800 infer.py:504] Waiting for chunk writes to complete.
# I0324 00:35:25.082005 139703101564800 infer.py:508] Merging chunk results.
# I0324 00:35:25.657400 139703101564800 infer.py:526] Results written to gs://somewhere/t5x/t5_1.1_base_snow_fp32/infer2/snow_t15_23-predict.jsonl-00000-of-00001.
# I0324 00:35:25.657627 139703101564800 infer.py:527] Deleting temporary files.
# I0324 00:35:27.043891 139703101564800 infer.py:536] DONE

推論結果を確認してみましょう。

!gsutil cp gs://somewhere/t5x/t5_1.1_base_snow_fp32/infer2/snow_t15_23-predict.jsonl-00000-of-00001 .
import json
with open("snow_t15_23-predict.jsonl-00000-of-00001", "r") as f:
  lines = f.readlines()
  lines = [line.strip() for line in lines]
  examples = [json.loads(line) for line in lines]

for example in examples[:100]:
  print("input: {} => predict: {}".format(example["inputs"]["inputs_pretokenized"], example["prediction"]))

# input: まあ当分はそれで間に合うだろう。 => predict: まあしばらくはそれで間に合うだろう。
# input: 私たちは死に直面した。 => predict: 私たちは死に直接当たることになった。
...
# input: その料理人は彼の信じられないほどの食欲に驚いた。 => predict: その料理人は彼の信じられないほどの食べたい気持ちに驚いた。

大丈夫そうですね。それでは次にエクスポートを。。。といきたいところなのですが、まだ実装されてないようです。 jax2tf17 というのがありますが、こちらも experimental のようですね。こちらの issue18 によると

We have a more direct way to convert from T5X to SavedModel if that’s what you’re interested in, but we haven’t prioritized open sourcing it thus far.

とのことなので、待っていたら公開してくれるかもしれません。前述の issue では Hugging Face Transformers に変換するスクリプトの話題もでているので興味のある方は試してみて下さい。

それでは、ここからは Prompt Tuning のお話です。

4. Prompt Tuning

Prompt Tuning の前に、まず prompt について少し説明します。

Prompt とは

BERT や T5 では事前学習済みモデルを起点として、下流タスク向けにモデルのパラメータを更新(Fine Tuning)してきましたが、より大規模なモデルである GPT-3 ではモデルのパラメータを固定し、推論時に出力を条件付けるテキスト(prompt)を与える few-shot learning という手法が用いられています。

以下は Few-Shot Learning と Fine Tuning を比較したイメージになります。タスクの定義(task description)の後に推論のサンプル(example)をいくつ入れるかで、Zero-shot, One-shot, Few-Shot のバリエーションが示されています。

few-shot

GPT-3 は現在 API として利用できるようになっており19、以下は Zero-shot で GPT-3 に翻訳をさせる例です20

prompt

GPT-3 は出だしのテキスト(prompt)を与えられると、その続きの文章を生成するように学習されたモデルですから、prompt(ここでは“Translate this … 1.")を工夫することによって、モデルのパラメータを更新することなく意図したタスクを実行することが出来ます。

近年は事前学習済みモデルが巨大化の一途をたどっているので、Fine Tuning にもそれなりの計算リソースが必要ですし、fine tuning をしてタスク単位にモデルのコピーを作るとデプロイ時の必要メモリ量に悩まされることにもなります。モデルのパラメータを固定して推論時の prompt で対応できれば、これらの悩みを回避できるので、昨年あたりから注目を集めていた訳ですね。

Prompt Design

さて、意図したとおりの推論を行わせるには、どのような prompt を入力するかが重要になってきます。これが Prompt Design で前述の GPT-3 の API ではこちら21にそのノウハウが記載されています。基本は "Show and Tell” で AI にして欲しい内容を記載(task description)し、その実例を見せる(example)ことになります。

ただ Prompt Design は人手の介在が必要で、最適な記述をするには試行錯誤が必要になります。また prompt を構成する単語(もしくはサブワード)はモデルが認識する有限の語彙集合から選択しなければなりません。

ここで、ようやく Prompt Tuning の話になります。

Prompt Tuning の仕組み

Prompt Tuning はアイデアとしては簡単です。以下は T5 に “I like fruits.” を入力してポジ/ネガ判定するタスクを Fine Tuning と Prompt Tuning で比較したイメージです。暖色系はパラメータが更新可能であること、寒色系はパラメータ固定であることを示します。

prompt_tuning

“I like fruits.” の入力が [“I”, “like”, “fruits”, “.”] にトークナイズされ、各トークンがモデルのパラメータの一部である埋め込み表現に置き換えられるところまでは同じです。

Fine Tuning では、暖色で示したトークンの埋め込み表現と Transformer のモデル全体が更新対象となります。

ところが Prompt Tuning では、寒色系で示したようにモデル全体のパラメータを固定します。その代わり、入力テキストの埋め込み表現の系列の先頭に prompt として固定長22の埋め込み表現の系列を結合します。暖色で示したとおり、この prompt 部分は更新可能なパラメータになっており、誤差逆伝播で更新されます。

このようにして、prompt の最適な埋め込み表現を固定の語彙(に対応する埋め込み表現)に縛られることなく、自由かつ自動的に決定することができます。モデルは固定なので別タスクを解くときは prompt だけ学習しなおして差し替えればOKですよね。

Prompt Tuning についてもう少しだけ補足しますね。

Language Model Adaptation

Prompt Designの節で prompt を工夫して GPT-3 に「その後に自然に続く文章を生成」させる旨の説明をしました。 GPT-3 は事前学習段階からそういう学習をしているので良いのですが、 T5 の事前学習は Span Corruption と呼ばれる文章の穴埋め問題です。ですので、事前学習が終わった直後の T5 は「その後に続く自然な文章を生成」したことが一度もありません。いくら prompt を最適化すると言っても事前学習済みモデルがコレでは厳しそうです。というか厳しかったそうです。

pretrain_task

上図は事前学習タスク別にモデルサイズを変化させながら SuperGLUE のスコアをプロットしたものです。Span Corruption で学習した T5 の素の事前学習済みモデルは、両端(Small と XXL)はともかく真ん中 3 つ(Base, Large, XL)はかなり厳しい結果(青線)になっているのが見て取れます。

これに対応する為、論文では Span Corruption での事前学習後に継続して Language Model(入力テキストの続きをモデルに出力させる)タスクである程度(論文では10万ステップ)学習させる Language Model Adaptation を行っています。上図からは、かなり改善(緑線)していることが見て取れます。

prompt の初期化

prompt は学習パラメータですので何かしらで初期化しないといけません。論文によると事前学習済みモデルに含まれるトークン(特に出現頻度が高いもの)の埋め込み表現からサンプリングする(橙線)と一様分布で初期化したもの(青線)よりも良い結果が出たようです23

prompt_initialization

Fine Tuning との比較

最後に Prompt Tuning と Fine Tuning の比較も紹介しておきます。

prompt_vs_finetuning

Prompt Tuning(緑線)に対して Fine Tuning(赤線はタスク毎に、橙線はマルチタスクで学習)が優勢ですが、モデルサイズが大きくなるにつれて差が小さくなり、XXL(右端)では追いついていますね。 Prompt Tuning を使いたい局面は、モデルサイズが大きくて Fine Tuning でモデルのコピーが出来るのがツラいケースなので、これは喜ばしい傾向だと言えます。

本記事では Base サイズ(左から二番目)で実験するので、ちょうど一番差が大きくなっている所です。少し悲しいですが、下流タスクによるかも知れませんし、どの程度の精度劣化になるのか見てみることにしましょう。

ただ Base サイズなら普通に Fine Tuning して精度を確保した上で、そこを起点に各ユーザ向けに Prompt Tuning することで、基本は同じタスクだけれどユーザ毎に少しづつ違うモデルを提供するというようなことが出来るかもしれませんね。

では、 Prompt Tuning を動かしてみましょう。

5. Prompt Tuning の実験

まず前章で記述したように Language Model Adaptation を行う必要があります。

LM Adaptation の実行

3章に記載したとおり、 T5X を Colab の TPU で動かすことが出来ていないので、 T5 を使います。 新たに Colab のノートブックを開いて、アクセラレータに TPU を選んでください。

セットアップ

必要なライブラリをセットアップします。

!pip install t5[gcp]==0.9.3 tensorflow==2.8.*
!apt-get install -y libmagic1
!pip install xtract
!pip install tensorflow-datasets==4.2.0
!pip install numpy --upgrade --ignore-installed
# WARNING: The following packages were previously imported in this runtime:
#  [numpy]
# You must restart the runtime in order to use newly installed versions.

tensorflow==2.8.* の縛りを入れているのは、コレが無いと tensorflow が 2.9.1 に差し替えられてしまい、 後続の auth.authenticate_user() のところで tensorflow_gcs がらみのエラーになったので、その回避策です。

インポート済みの numpy が更新された旨の警告がでたので、ここでランタイムを再起動することにしました。

Prefix LM タスクを登録する

再起動後に GCS にアクセスする認証を通します。2022年3月末頃に auth.authenticate_user() の挙動が変わり、 USE_AUTH_EPHEM0 にしておかないと Colab の TPU から GCS にアクセスできないようになってしまったので、 その対応を入れています24

import os
os.environ["USE_AUTH_EPHEM"] = '0'
from google.colab import auth
auth.authenticate_user()

必要なモジュールやクラスをインポートして、

import functools
from t5.data import preprocessors
from t5.data.dataset_providers import TfdsTask
from t5.data.utils import DEFAULT_EXTRA_IDS
import seqio
from seqio import Feature
from seqio import TaskRegistry
from seqio import vocabularies

次にタスクを登録します。

タスクの内容ですが、論文には「 LM Adaptation を 100K ステップ実行した」とだけ記載されていて詳細が分かりませんでした。 仕方がないので、公開されている LM Adaptation 済みのチェックポイントの内容をチェックすることにします。

!gsutil cp gs://t5-data/pretrained_models/t5.1.1.lm100k.base/operative_config.gin .

まずは実際に動かしたタスクを確認します。

!cat operative_config.gin | grep -e "^MIXTURE_NAME" 
# MIXTURE_NAME = 'c4_v220_prefix_lm'

ここ25に定義がありました。これをベースに読み込むコーパスと Sentencepiece のモデルを修正して以下のようにしました。

task_name = "wiki_ja_prefix_lm"
SPM_PATH = "gs://somewhere/t5/sentencepiece/sp.model"

TaskRegistry.add(
    task_name,
    source=seqio.TfdsDataSource(tfds_name="wikipedia/20190301.ja:1.0.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.prefix_lm,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features = {
        "inputs": Feature(vocabularies.SentencePieceVocabulary(
           SPM_PATH, DEFAULT_EXTRA_IDS)),
        "targets": Feature(vocabularies.SentencePieceVocabulary(
           SPM_PATH, DEFAULT_EXTRA_IDS))},
    metric_fns=[])
お断り

上記の例は(簡単に試せるように)Wikipedia 日本語版を使った例で記載していますが、以降の処理結果は筆者が自前で作った日本語データセットを使って実験したものになります。

次に T5 のソースコードの中に prefix_lm.gin26 という気になる名前にファイルがあるので、そちらを確認しました。

# cited from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/gin/objectives/prefix_lm.gin#L16-L19
16: preprocessors.denoise.noise_mask_fn = @preprocessors.random_prefix_noise_mask
17: preprocessors.denoise.noise_density = 0.5
18: preprocessors.denoise.inputs_fn = @preprocessors.drop_nonnoise_tokens
19: preprocessors.denoise.targets_fn = @preprocessors.drop_noise_tokens

noise_density0.5 を設定しています。先ほどの operative_config.gin をチェックすると、

!cat operative_config.gin | grep noise_density
# noise_density = 0.15
# random_spans_helper.noise_density = %noise_density

こちらは 0.15 なので話が合いませんね。先程のタスクの定義の中で出てきた prefix_lm() をチェックしてみます。

# cited from https://github.com/google-research/text-to-text-transfer-transformer/blob/1084db9477e443e3783c05119da4741500a0d4ff/t5/data/preprocessors.py#L2000-L2015
2000: def prefix_lm(dataset, sequence_length, output_features):
2001:   """Prefix language modeling objective used in Raffel et al. 2019."""
2002:   ds = dataset
2003:   ds = select_random_chunk(ds, output_features=output_features,
2004:                            feature_key='targets', max_length=65536)
2005:   ds = split_tokens_to_inputs_length(ds, output_features=output_features,
2006:                                      sequence_length=sequence_length)
2007:   ds = denoise(
2008:       ds,
2009:       output_features,
2010:       inputs_fn=drop_nonnoise_tokens,
2011:       targets_fn=drop_noise_tokens,
2012:       noise_density=0.5,
2013:       noise_mask_fn=random_prefix_noise_mask,
2014:   )
2015:   return ds

2012 行目を見ると、ベタ書きで 0.5 を設定していますね。ということは先程の prefix_lm.gin は別物なのでしょう。今回は使わないことにします。 operative_config.ginnoise_density = 0.15 なのは Span Corruption で事前学習したときの名残なのではないかと思います。

系列長も気になったので確認します。

!cat operative_config.gin | grep run.sequence_length
# run.sequence_length = {'inputs': 1024, 'targets': 256}

Span Corruption では {'inputs': 512, 'targets': 114} だったので、入・出力とも系列長を伸ばしたようですね。 それとバッチサイズも見ておきましょう。

!cat operative_config.gin | grep batch_size
# run.batch_size = ('tokens_per_batch', 1048576)
# tpu_estimator_model_fn.outer_batch_size = 1

これで大体把握できた気がします。

TPU アドレスの確認

TPU のアドレスを確認しておきます。

import os
import pprint
import json
import tensorflow.compat.v1 as tf

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)
# TPU address is grpc://10.70.100.178:8470

事前学習済みモデルのコピー

T5 作成した Span Corruption による事前学習済みモデルが gs://somewhere/t5/pre_trained_t5_1.1 にあるものとして記述します。

  • 肝心の事前学習済みモデルですが、これまでの連載の過程で作ったものがお手元にない方は第14回の 2 章を参考にして作って下さい(すみません)。

ミスって事前学習済みモデルを消したりすると厄介なので、念のため別フォルダにコピーしておきましょう。

!gsutil cp gs://somewhere/t5/pre_trained_t5_1.1/operative_config.gin gs://somewhere/t5/pre_trained_t5_1.1-LM/
!gsutil cp gs://somewhere/t5/pre_trained_t5_1.1/graph.pbtxt gs://somewhere/t5/pre_trained_t5_1.1-LM/
!gsutil cp gs://somewhere/t5/pre_trained_t5_1.1/checkpoint gs://somewhere/t5/pre_trained_t5_1.1-LM/
!gsutil cp gs://somewhere/t5/pre_trained_t5_1.1/model.ckpt-524288* gs://somewhere/t5/pre_trained_t5_1.1-LM/

学習の実行

準備が出来たので以下のようにして実行します。train_steps はステップ 524288 を起点に+100K ステップとし、 sequence_lengthbatch_size には先程確認した値を設定しています。

from t5.models.mesh_transformer_main import FLAGS
from t5.models.mesh_transformer_main import main
import os

FLAGS.mark_as_parsed()

FLAGS.tpu = TPU_ADDRESS
FLAGS.model_dir = 'gs://somewhere/t5/pre_trained_t5_1.1-LM'
tf.flags.FLAGS.gin_file=[os.path.join(FLAGS.model_dir, "operative_config.gin")]
tf.flags.FLAGS.gin_param=[
  "utils.tpu_mesh_shape.model_parallelism = 1",
  "utils.tpu_mesh_shape.tpu_topology = 'v2-8'",
  "utils.run.sequence_length = {'inputs': 1024, 'targets': 256}",
  "run.batch_size = ('tokens_per_batch', 1048576)",
  "run.train_steps = 624288",
  "run.save_checkpoints_steps = 500",
  "MIXTURE_NAME = 'wiki_ja_prefix_lm'"
]

tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)
main([])

# WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/compat/v2_compat.py:107: disable_resource_variables (from # tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
# Instructions for updating:
# non-resource variables are not supported in the long term
# ERROR:root:Path not found: gs://somewhere/t5/pre_trained_t5_1.1-LM/operative_config.gin
# 絶対ダメそうなエラー(↑)ですが、大丈夫っぽいです。
# INFO:tensorflow:model_type=bitransformer
# INFO:tensorflow:mode=train
# INFO:tensorflow:sequence_length={'inputs': 1024, 'targets': 256}
# INFO:tensorflow:batch_size=1024
# INFO:tensorflow:train_steps=624288
# 本当はなぜか二行ずつ同じ行がでる(↑)のですが、一行にしてます。
# ...

上記の設定だと、とても Colab のランタイムの寿命には収まらないので、Cloud TPU を使うなり、気長に繰り返すなり、学習量やバッチサイズを妥協するなりして下さい。。。論文2の図 3.(d) からすると、ステップ数を 1/10 の 10K にしても得られる効果の 9 割程度は確保できるようです。

あと、結構な量の checkpoint ができるので、適宜手動で削除するなどして下さい

では、いよいよ Prompt Tuning です!

Prompt Tuning の学習

まずは、環境のセットアップからです。

セットアップ

新しくノートブックを開いてもらって、アクセラレータに GPU を選択してください。

まずは prompt-tuning のコードをインストールします。

!git clone --branch=main https://github.com/google-research/prompt-tuning
!cd prompt-tuning && git checkout 279e53e88c6268fd2dcf903cf6f3e949b5fff119
!cd prompt-tuning && python -m pip install . --use-deprecated=legacy-resolver
# ...
# ERROR: pip's legacy dependency resolver does not consider dependency conflicts when selecting packages. This behaviour is the source of the following dependency conflicts.
# tensorflow-text 2.9.0 requires tensorflow<2.10,>=2.9.0; platform_machine != "arm64" or platform_system != "Darwin", but you'll have tensorflow 2.8.2+zzzcolab20220527125636 which is incompatible.
# ...

なにやら tensorflow-text のバージョンが Colab の tensorflow と合わないようなので、差し替えます。

!pip install tensorflow-text==2.8.2

T5X の commit を T5X のファインチューニングで使用したものに入れ替えて、

!pip uninstall -y t5x
!pip install git+https://github.com/google-research/t5x@df1ee3f3d5fd63ad0959e110cdfa46383356c99b

T5X のセットアップのところで説明した理由で flax も入れ替えてます。

!pip uninstall -y flax
!pip install git+https://github.com/google/flax@9bd65b20752e7bfc172796e66948b6c216405b9b

このままだと、検証時に AttributeError: 'Evaluator' object has no attribute 'model_feature_shapes' というエラーになったので、 SeqIO も入れ替えました。

!pip uninstall -y seqio
!pip install git+https://github.com/google/seqio@9748501b00d707a482fdd07ea5bda1caa46fca8c

日本語の BLEU スコアの計算のエクステンションを追加します。

!pip install sacrebleu[ja]

commit を指定して差し替えているので、バージョンを確認しても仕方ない気もしますが、こんな感じです。

!pip list | grep -e t5x -e jax -e flax -e prompt-tuning
# flax                          0.4.0
# flaxformer                    0.4.2
# jax                           0.3.8
# jaxlib                        0.3.7+cuda11.cudnn805
# prompt-tuning                 0.1.0
# t5x                           0.0.0

ここでランタイムを再起動しました。

再起動後に再び認証を通しておきます。

import os
os.environ["USE_AUTH_EPHEM"] = '0'
from google.colab import auth
auth.authenticate_user()

タスクは T5X に揃えて「やさしい日本語変換」とし、精度劣化の程度を確認することにします。 学習データを取得し直しておきましょう。

!gsutil cp gs://somewhere/work/snow_t15_23_*.tsv .

LM Adaptation したモデルを T5X 形式にコンバート

LM Adaptation 済みのモデルを T5X の検証のときと同じ要領で変換します。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
CUDA_VISIBLE_DEVICES="" && \
python -m t5x.scripts.convert_tf_checkpoint \
 --gin_file="t5x/examples/t5/t5_1_1/base.gin" \
 --gin.convert_checkpoint.model=%MODEL \
 --gin.DROPOUT_RATE="0.0" \
 --gin.convert_checkpoint.tf_checkpoint_path=\"gs://somewhere/t5/pre_trained_t5_1.1-LM/model.ckpt-624288\" \
 --gin.convert_checkpoint.output_dir=\"gs://somewhere/t5x/pre_trained_t5_1.1-LM_native" \
 --logtostderr

タスクの登録

タスクの登録は T5X の時と同じですので、ここを実行して snow_task.py を生成しておいて下さい。

gin ファイルの定義

gin ファイルの定義は以下のようになります。

%%bash
cat << EOF > t5x_t5_1.1_base_prompt_snow.gin
from __gin__ import dynamic_registration
import t5.data.mixtures
import __main__ as train_script
from t5x import utils
from t5x import models
from t5x import trainer
import seqio
import snow_task # 1.

include 'prompt_tuning/configs/runs/prompt_finetune.gin' # 2.

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model"
seqio.SentencePieceVocabulary.extra_ids = 100

ACTIVATION_DTYPE = 'float32'

# 32 x 64 = 2048
LOSS_NORMALIZING_FACTOR = 2048
BATCH_SIZE = 32 # 3.
MIXTURE_OR_TASK_NAME = "snow_t15_23"
TASK_FEATURE_LENGTHS = {'inputs': 64, 'targets': 64}
USE_CACHED_TASKS = False

train_script.train:
  eval_period = 400

utils.SaveCheckpointConfig:
  period = 400

EOF

上記の設定でいくつかポイントになるところを説明します。

  1. : タスクを import します。snow_task は T5X のファインチューニングで使った物と同じです。
  2. : Prompt Tuning で学習するための設定です。この中で T5X のファインチューニングで使った finetune.gin を include しています。
  3. : バッチサイズは論文2の“3. Results"の記述に合わせました。

学習の実行

以下のようにして実行です。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  MODEL_DIR="${BUCKET}/t5x/t5_1.1_base_snow_prompt" && \
  INITIAL_CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288" && \
  TRAIN_STEPS=628288 && \
  T5X_DIR=`python -m prompt_tuning.scripts.find_module t5x`/.. && \
  FLAXFORMER_DIR=`python -m prompt_tuning.scripts.find_module flaxformer`/.. && \
  PROMPT_DIR=`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.. && \
  python ${T5X_DIR}/t5x/train.py \
    --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
    --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \
    --gin_file="prompt_tuning/configs/prompts/from_sampled_vocab.gin" \
    --gin_file="./t5x_t5_1.1_base_prompt_snow.gin" \
    --gin.MODEL_DIR=\"${MODEL_DIR}\" \
    --gin.INITIAL_CHECKPOINT_PATH=\"${INITIAL_CHECKPOINT_PATH}\" \
    --gin.TRAIN_STEPS=${TRAIN_STEPS} \

ここも少し補足します。

  • TRAIN_STEPS=628288:
    LM Adaptation 終了時点の 624288 を起点に 4000 ステップ追加で学習します。論文2の"3. Results"では 30000 ステップでしたが、ファインチューニングしたときは 2000 で結果がでたので、これくらいで十分だと思います。
  • --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin":
    T5 1.1 Base の構造の定義と入力に prompt を結合する設定です。T5 1.1 の構造の設定に関しては T5X の方は t5x/examples/t5/t5_1_1/base.gin を使用していましたが、Prompt Tuning では prompt_tuning/configs/architectures/prompt_encoder_t5_1_1_flaxformer.gin と flaxformer27 をベースにしたものになっています。内容的にはほぼ等価なのでしょうが、なぜ出所が異なるのかまでは確認できていません。
  • --gin_file="prompt_tuning/configs/prompts/from_sampled_vocab.gin":
    prompt の初期化の設定です。タスクは分類問題ではないので、こちらで説明したように使用頻度の高い語彙の埋め込み表現からサンプリングします。

ところが、もう少しで終了というところで Segmentation fault を起こして落ちてしまいました。。。

I0426 04:24:59.985513 140372512053120 train.py:550] END Train loop.
I0426 04:25:00.985516 140367055963904 logging_writer.py:48] [627888] collection=train accuracy=0.738969, cross_ent_loss=0.196093, cross_ent_loss_per_all_target_tokens=0.000096, learning_rate=0.300001, learning_rate/current=0.30000001192092896, loss=0.201769, loss_per_all_target_tokens=0.000099, loss_per_nonpadding_target_token=0.000757, nonpadding_fraction=0.130114, timing/seconds=810.619023, timing/seqs=12800, timing/seqs_per_second=15.790402, timing/seqs_per_second_per_core=15.790402, timing/steps_per_second=0.493450, timing/target_tokens_per_second=1010.585709, timing/target_tokens_per_second_per_core=1010.585709, z_loss=0.005675, z_loss_per_all_target_tokens=0.000003
I0426 04:25:00.985788 140372512053120 train.py:565] Saving checkpoint.
I0426 04:25:00.989233 140372512053120 utils.py:107] Saving Numpy checkpoints for step 627888 to gs://somewhere/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_627888.tmp-1650947100
I0426 04:25:03.601961 140372512053120 utils.py:138] Saved Numpy Arrays for step 627888 to gs://somewhere/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_627888
I0426 04:25:03.684093 140372512053120 checkpoints.py:631] Saving checkpoint for step 627888 to gs://somewhere/t5x/t5_1.1_base_snow_prompt/checkpoint_627888.tmp-1650947103
Fatal Python error: Segmentation fault

Thread 0x00007fa9caab8700 (most recent call first):
  File "/usr/lib/python3.7/concurrent/futures/thread.py", line 78 in _worker
  File "/usr/lib/python3.7/threading.py", line 870 in run
  File "/usr/lib/python3.7/threading.py", line 926 in _bootstrap_inner
  File "/usr/lib/python3.7/threading.py", line 890 in _bootstrap

どうも 627888 のチェックポイントを書き出すところで落ちたようなので、その手前から再開します。 ちなみに検証や推論のときも何度か落ちたのですが、ランタイムを再起動して再実行したらなんとかなりました。

先程から以下のように変更し、途中から再開します。

  • --gin_file="prompt_tuning/configs/prompts/from_file.gin":
    from_sampled_vocab.gin ではなく、from_file.gin を指定して、ファイルに保存された numpy 配列で初期化を行います。
  • --gin.PROMPT_FILE=\"${PROMPT_FILE}:
    from_file.gin が読み込む numpy 配列のファイルです。今回は直前のチェックポイント(627488)で生成されたファイル(encoder.prompt.prompt.prompt)を指定しました。
!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  MODEL_DIR="${BUCKET}/t5x/t5_1.1_base_snow_prompt" && \
  INITIAL_CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288" && \
  TRAIN_STEPS=628288 && \
  PROMPT_FILE="${BUCKET}/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_627488/encoder.prompt.prompt.prompt" && \
  T5X_DIR=`python -m prompt_tuning.scripts.find_module t5x`/.. && \
  FLAXFORMER_DIR=`python -m prompt_tuning.scripts.find_module flaxformer`/.. && \
  PROMPT_DIR=`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.. && \
  python ${T5X_DIR}/t5x/train.py \
    --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
    --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \
    --gin_file="prompt_tuning/configs/prompts/from_file.gin" \
    --gin_file="./t5x_t5_1.1_base_prompt_snow.gin" \
    --gin.MODEL_DIR=\"${MODEL_DIR}\" \
    --gin.INITIAL_CHECKPOINT_PATH=\"${INITIAL_CHECKPOINT_PATH}\" \
    --gin.TRAIN_STEPS=${TRAIN_STEPS} \
    --gin.PROMPT_FILE=\"${PROMPT_FILE}\" \

# ...
# I0427 05:52:10.227691 140377481688960 utils.py:768] Inference of all batches done.
# I0427 05:52:17.289286 140377481688960 evaluation.py:477] Time waiting for previous metrics run: 0.000031 secs.
# I0427 05:52:17.289782 140377481688960 train.py:587] Finished.
# I0427 05:52:17.290004 140372021327616 logging_writer.py:48] [628288] collection=train timing/evaluate_seconds=458.502212
# I0427 05:52:17.290227 140372208850688 evaluation.py:525] Computing metrics for snow_t15_23
# I0427 05:52:20.294728 140372208850688 loggers.py:89] snow_t15_23/bleu at step 628288: 68.345
# I0427 05:52:20.679590 140372208850688 loggers.py:354] Appending metrics to gs://somewhere/t5x/t5_1.1_base_snow_prompt/inference_eval/snow_t15_23-metrics.jsonl
# I0427 05:52:21.481836 140372208850688 loggers.py:382] Writing inferences to gs://somewhere/t5x/t5_1.1_base_snow_prompt/inference_eval/snow_t15_23-624388.jsonl
# I0427 05:52:24.464202 140372208850688 loggers.py:415] Writing completed in 2.982369 seconds (0.670608 examples/sec).
# I0427 05:52:24.464763 140372208850688 evaluation.py:483] Time computing metrics: 7.174546 secs.   

どうやら、一応最後まで動いたようです。学習曲線を見てみました。

学習曲線

学習曲線は以下のようになりました。

prompt_learning_curve

ではテストデータでの精度を確認してみましょう。

検証の実行

学習した prompt を使ってテストデータでの精度を見てみましょう。gin ファイルは以下のようになります。

%%bash
cat << EOF > t5x_t5_1.1_base_eval_snow_prompt.gin
from __gin__ import dynamic_registration
import __main__ as eval_script
import t5.data.mixtures
from t5x import utils
from t5x import partitioning
from t5x import models
import seqio
import snow_task

include 'prompt_tuning/configs/runs/prompt_eval.gin' 

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model"
seqio.SentencePieceVocabulary.extra_ids = 100

ACTIVATION_DTYPE = 'float32'

DROPOUT_RATE = 0.0

MIXTURE_OR_TASK_NAME = "snow_t15_23"
EOF

include するファイルが prompt_eval.gin になる以外は特に説明いらなさそうですね。

先程の学習曲線ではステップ 654288 が最良だったので、以下のようにして実行です。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  EVAL_OUTPUT_DIR="${BUCKET}/t5x/t5_1.1_base_snow_prompt/eval" && \
  CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288" && \
  PROMPT_FILE="${BUCKET}/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_654288/encoder.prompt.prompt.prompt" && \
  T5X_DIR=`python -m prompt_tuning.scripts.find_module t5x`/.. && \
  FLAXFORMER_DIR=`python -m prompt_tuning.scripts.find_module flaxformer`/.. && \
  PROMPT_DIR=`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.. && \
  python ${T5X_DIR}/t5x/eval.py \
    --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
    --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \
    --gin_file="./t5x_t5_1.1_base_eval_snow_prompt.gin" \
    --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
    --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
    --gin.PROMPT_FILE=\"${PROMPT_FILE}\" \

# I0324 02:00:24.871591 140214831568768 utils.py:159] NumExpr defaulting to 2 threads.
# Rewritten gin arg: --gin_bindings=EVAL_OUTPUT_DIR = "gs://somewhere/t5x/t5_1.1_base_snow_prompt/eval"
# Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = "gs://somewhere/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288"
# Rewritten gin arg: --gin_bindings=PROMPT_FILE = "gs://somewhere/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_654288/encoder.prompt.prompt.prompt"
# ...
# I0426 07:14:03.032807 140710294812544 utils.py:768] Inference of all batches done.
# I0426 07:14:13.796017 140704905778944 evaluation.py:525] Computing metrics for snow_t15_23
# I0426 07:14:16.424289 140704905778944 loggers.py:89] snow_t15_23/bleu at step 624288: 68.184
# I0426 07:14:17.841377 140704905778944 loggers.py:354] Appending metrics to gs://somewhere/t5x/t5_1.1_base_snow_prompt/eval/inference_eval/snow_t15_23-metrics.jsonl
# I0426 07:14:18.503051 140704905778944 loggers.py:382] Writing inferences to gs://somewhere/t5x/t5_1.1_base_snow_prompt/eval/inference_eval/snow_t15_23-624288.jsonl
# I0426 07:14:21.328506 140704905778944 loggers.py:415] Writing completed in 2.825474 seconds (0.707846 examples/sec).
# I0426 07:14:21.328948 140704905778944 evaluation.py:483] Time computing metrics: 7.532975 secs.
# I0426 07:14:21.332164 140710294812544 eval.py:177] Finished.

検証時のポイントは以下になります。

  • CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288"
    検証時に使うチェックポイントが LM Adaptation 終了時点のものであることに注意して下さい。Prompt Tuning ではモデルを更新しないので当然と言えば当然ですね。
  • PROMPT_FILE="${BUCKET}/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_654288/encoder.prompt.prompt.prompt"
    学習した Prompt は numpy のアレイを保存したものになっています。これを読み込んで入力の先頭に結合する訳ですね。

BLEU スコアは 68.184 になってしまいました。

fine_tuning_vs_prompt_tuning

Fine Tuning との比較 で示したとおり、Base サイズだと Fine Tuning には及ばない結果となりました。 ただ、LM Adaptation なしの事前学習済みモデルで試したときは BLEU スコアが 0.0 だったので、これでもだいぶマシにはなりました。

では、最後に実際に推論させてみましょう。

推論の実行

gin ファイルは以下のようになります。

%%bash
cat << EOF > t5x_t5_1.1_base_infer_snow_prompt.gin
from __gin__ import dynamic_registration
import __main__ as infer_script
from t5x import utils
from t5x import partitioning
from t5x import models
import seqio
import snow_task

include 'prompt_tuning/configs/runs/prompt_infer.gin'

seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://somewhere/t5/sentencepiece/sp.model"
seqio.SentencePieceVocabulary.extra_ids = 100

ACTIVATION_DTYPE = 'float32'

DROPOUT_RATE = 0.0

MIXTURE_OR_TASK_NAME = "snow_t15_23"
TASK_FEATURE_LENGTHS = {'inputs': 64, 'targets': 64}

infer_script.infer:
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'
EOF

include するファイルが prompt_infer.gin になった以外は検証と大きな違いはありません。

utils.RestoreCheckpointConfig の設定はこれを追記しておかないと dtypebloat16 になってしまったので、 念のため追記しました。

以下のようにして実行します。

!export PYTHONPATH=${PYTHONPATH}:.:./t5x && \
  BUCKET="gs://somewhere" && \
  INFER_OUTPUT_DIR="${BUCKET}/t5x/t5_1.1_base_snow_prompt/infer" && \
  CHECKPOINT_PATH="${BUCKET}/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288" && \
  PROMPT_FILE="${BUCKET}/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_654288/encoder.prompt.prompt.prompt" && \
  T5X_DIR=`python -m prompt_tuning.scripts.find_module t5x`/.. && \
  FLAXFORMER_DIR=`python -m prompt_tuning.scripts.find_module flaxformer`/.. && \
  PROMPT_DIR=`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.. && \
  python ${T5X_DIR}/t5x/infer.py \
    --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
    --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \
    --gin_file="./t5x_t5_1.1_base_infer_snow_prompt.gin" \
    --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
    --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
    --gin.PROMPT_FILE=\"${PROMPT_FILE}\" \
# I0426 07:25:58.851496 139880186759040 utils.py:159] NumExpr defaulting to 2 threads.
# Rewritten gin arg: --gin_bindings=INFER_OUTPUT_DIR = "gs://somewhere/t5x/t5_1.1_base_snow_prompt/infer"
# Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = "gs://somewhere/t5x/pre_trained_t5_1.1-LM_native/checkpoint_624288"
# Rewritten gin arg: --gin_bindings=PROMPT_FILE = "gs://somewhere/t5x/t5_1.1_base_snow_prompt/numpy_checkpoints/checkpoint_654288/encoder.prompt.prompt.prompt"    
# ...
# I0426 07:37:03.782565 139659424847744 utils.py:748] Inference of batch [608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
 626 627 628 629 630 631 632 633 634 635 636 637 638 639] done.
# I0426 07:37:03.785068 139659424847744 utils.py:768] Inference of all batches done.
# I0426 07:37:07.081478 139659424847744 infer.py:529] chunk completed in 48.218109 seconds (13.273022 examples/sec).
# I0426 07:37:09.977994 139659424847744 infer.py:545] Checkpoint written to temporary location in 2.896274 seconds.
# I0426 07:37:09.978313 139653178451712 infer.py:476] Writing chunk 2 results to gs://somewhere/t5x/t5_1.1_base_snow_prompt/infer/tmp-snow_t15_23-00000-of-00001/snow_t15_23-predict.jsonl-00000-of-00001-chunk00002
# I0426 07:37:14.300440 139653178451712 infer.py:482] Writing completed in 4.322091 seconds (148.076463 examples/sec).
# I0426 07:37:15.242524 139659424847744 infer.py:560] Finished inference for task 'snow_t15_23'.
# I0426 07:37:15.242832 139659424847744 infer.py:562] Waiting for chunk writes to complete.
# I0426 07:37:15.245240 139659424847744 infer.py:162] Merging chunk results.
# I0426 07:37:15.889412 139659424847744 infer.py:181] Results written to gs://somewhere/t5x/t5_1.1_base_snow_prompt/infer/snow_t15_23-predict.jsonl-00000-of-00001.
# I0426 07:37:15.889755 139659424847744 infer.py:568] Deleting temporary files.
# I0426 07:37:17.067812 139659424847744 infer.py:577] DONE

実行の仕方は eval.pyinfer.py に変わるだけで、特に追加で説明することはなさそうですね。

以下のようにして推論結果を確認します。

!gsutil cp gs://somewhere/t5x/t5_1.1_base_snow_prompt/infer/snow_t15_23-predict.jsonl-00000-of-00001 .
import json
with open("snow_t15_23-predict.jsonl-00000-of-00001", "r") as f:
  lines = f.readlines()
  lines = [line.strip() for line in lines]
  examples = [json.loads(line) for line in lines]

for example in examples[:100]:
  print("input: {} => predict: {}".format(example["inputs"]["inputs_pretokenized"], example["prediction"]))

# input: まあ当分はそれで間に合うだろう。 => predict: まあ当分はそれで間に合うだろう。
# input: 私たちは死に直面した。 => predict: 私たちは死に直面した。
# input: あなたはなぜ働いているの。 => predict: あなたはなぜ働いているの。
# ...

あれ?素通しで何も変わっていない?

そういえば、このデータセットは完全素通しでも BLEU = 66.04 になるんでした。。。でも、68.184 が出たので完全素通しになった訳でもなさそうです。

入出力が異なるサンプルだけ見てみましょう。

for example in examples:
  if example["inputs"]["inputs_pretokenized"] != example["prediction"]:
    print("input: {} => predict: {}".format(example["inputs"]["inputs_pretokenized"], example["prediction"]))

# input: 彼は自分の誤りにきづいていないようだ。 => predict: 彼は自分の間違いにきづいていないようだ。
# input: 君は馬鹿に違いない。 => predict: あなたは馬鹿に違いない。
# input: その服は君に似合っている。 => predict: その服はあなたに似合っている
# ...

微妙な感じですが、一応頑張って変換してはいるみたいですね。

15. おわりに

今回は、T5X を使って Prompt Tuning に挑戦してみました。今後の T5 系の論文は T5X で実装が公開されることになる気がしているので、 ここで使い方を覚えておくとよいのではないかと思います。次回は教師なしの文章ベクトル化を試してみようかと思っています。 文章ベクトル化は Sentence BERT を使っているのですが、教師なしで近い精度が出せたら特定領域への対応等で使い道がないのもかと。


  1. https://github.com/google-research/t5x 

  2. https://aclanthology.org/2021.emnlp-main.243/ 

  3. https://github.com/google-research/text-to-text-transfer-transformer : この連載では、これまでモデルとしての T5 を "T5"、その実装を "t5” と区別して表記していましたが、今回は両方(モデルと実装)を “T5” として表記してます。T5X も (モデルとしての)T5 の実装なので片方だけ小文字表記なのが気持ち悪くて。かと言って “t5x” と小文字表記にするのも何か違う気がして。。。 

  4. https://github.com/tensorflow/mesh 

  5. https://github.com/google-research/prompt-tuning 

  6. https://github.com/google/jax 

  7. https://github.com/google/flax 

  8. https://jax.readthedocs.io/en/latest/notebooks/quickstart.html 

  9. https://huggingface.co/docs/transformers/index#supported-frameworks 

  10. https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html 

  11. https://github.com/google/seqio SeqIO は系列データを処理するライブラリで t5.data から切り出されたものになります。tf.data.Dataset を使っていはいますが、出力を numpy に変換できるので、 PyTorch や JAX と組み合わせて使うことも可能です。  

  12. https://github.com/google/gin-config 

  13. https://github.com/google/jax/issues/8300 

  14. https://github.com/google-research/t5x/issues/214#issuecomment-1028135286 

  15. https://github.com/tensorflow/mesh/blob/6b31c0fc9daf185aae2422976487f8db08fc7369/mesh_tensorflow/transformer/transformer.py#L725 

  16. 実際に事前学習時のバッチサイズとシーケンス長を指定してみるとスコアが下がりました。 

  17. https://github.com/google/jax/tree/main/jax/experimental/jax2tf 

  18. https://github.com/google-research/t5x/issues/198 

  19. https://openai.com/api/ 

  20. 上下の図で “prompt” の範囲が違っていますが、基本的には下図のイメージだと考えてよいと思います。 

  21. https://beta.openai.com/docs/guides/completion/prompt-design 

  22. 図中では長さ = 3 ですが、論文2の 3.2 によると 20 程度が良いようです。 

  23. 緑線は分類問題を解く際の prompt にラベル文字列(ポジ/ネガ分類なら “positive” や “negative")に相当するトークンの埋め込み表現を使ったもの。ラベル文字列が複数トークンになるときはその平均。prompt の長さに対して、ラベル文字列の種類が少なく prompt が余ったときは橙線と同じ初期化で埋めてます。 

  24. この記事が公開されるころには、もう修正されているかもしれませんが。。。  

  25. https://github.com/google-research/text-to-text-transfer-transformer/blob/98964752c9756478203b876255620241f3e2b502/t5/data/tasks.py#L84-L99 

  26. https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/gin/objectives/prefix_lm.gin 

  27. https://github.com/google/flaxformer flaxformer は JAX/Flax で記述された Transformer のライブラリです。