オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第27回 DeepSpeed-Chat による RLHF の紹介
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2023年6月29日

今回は DeepSpeed-Chat による RLHF のご紹介です。正直、データセットや計算資源の都合もあり、とりあえず動かしてみました!的な話にはなりますが、RLHF の効果が実際に確認できるか見てみたいと思います。

1. はじめに

今回は DeepSpeed-Chat1 を使って RLHF を試してみたいと思います。RLHF は Reinforcement Learning from Human Feedback の略で文字通り「人からのフィードバックを用いた強化学習」ということですね。OpenAI が InstructGPT(ChatGPT の元になったモデル)2 で使ったことで注目された手法になります。

LLM がらみで何か記事にしたいと思いつつ、日々新たな LLM が発表されている昨今に、隔月&内容が実時間から月単位で遅れ気味wの本連載です。 「どうしたもんかな。。。」と悩みに悩んでこのネタになりました。

最初は使う道具も RL4LMs3 を想定していたんですよ。ですが改めて勉強すると、強化学習を行うとき、モデルが複数同時に動く訳です。 ところが筆者が確認した時点では RL4LMs にはスケーラビリティ的な仕組みがあまり入っておらず、transformers の naive な Pipeline Parallel に対応してるぐらいなんです。複数 GPU 載せれば動きそうではあるんですけど、過去にこの Pipeline Parallel 使ったときは GPU に結構な空きが出てしまってツラかった思い出が。。。

それで、「うーむ。。。」と考え込んでいたら DeepSpeed-Chat が出てきて飛びついたわけです。 さて、DeepSpeed-Chat の機能的なことは後述するとして、まずは学習の全体像を見てみましょう。

2. DeepSpeed-Chat による学習の全体像

RLHF を用いて ChatGPT 的なモデルを作る全工程を一枚絵で示すと以下のような感じになります。

overview

学習は大きく 3 ステップになっており、ステップ 3. が RLHF による学習です。各ステップの内容は以下のとおりです。

Step 1. SFT (Supervised Fine Tuning)

Cross Entropy 損失を用いたテキスト生成モデルの普通のファインチューニングです。これはもう説明不要ですね。

Step 2. 報酬モデルの学習

ステップ 1. で学習した SFT モデルが生成するテキストの良し悪しを判定する報酬モデルの学習です。

まずは報酬モデルの学習データを用意する必要があります。プロンプトを SFT モデルに入力してテキストを生成し、 それに対して人手でラベルを付けていきます。これが RLHF の Human Feedback に相当する作業ですね。

なのですが、生成されたテキストに対して直接、「こっちは何点、あっちは何点。。。」と点数を付けるのではなく、 同一プロンプトでの二つの生成テキストに対して、「この二つの比較なら、こっちの方が良い」 というラベル付けをします。

これは、大量のテキストに対して(同じ基準で)点数をつけるのは人間にとっても難しいタスクであり、 二つのテキストの比較をさせたほうが良質のデータセットになるからのようです。 確かに自分でラベル付け作業をすることを考えると、直接点数を付けるのは途中で「あぁ、だんだん分からなくなってきた。。」 みたいなことになりそうですね。M-1 グランプリの採点も毎年少なからず物議を醸すところがありますし。

少し脱線

実際にどういう学習をして、どういう計算で報酬を導出するのか、実装を確認してみましょう。

以下がソースコードの該当箇所です。ざっと説明しておくと、

  • chosen_ids : 二つの生成テキストを比較して良いとされた方のトークン ID 系列です。
  • rejected_ids : 同、悪いとされた方のトークン ID 系列です。
  • chosen_rewards : 二つの生成テキストを比較して良いとされた方の出力特徴量系列に Linear のヘッドを適用したスカラー値(報酬)の系列です。
  • rejected_rewards : 同、悪いとされた方のスカラー(報酬)系列です。
  • c_ind : OPT モデルは例外になりますが、二つの生成テキストを比較して良いとされた方の最初の PAD トークンのインデックスです。
  • r_ind : 同、悪いとされた方の最初の PAD トークンのインデックスです。
  • divergence_ind : chosen_idsrejected_ids が異なる場合の最初の相違点のインデックスです。
# https://github.com/microsoft/DeepSpeedExamples/blob/8f8099a813f3b223d5df39e0c15c748de4eb1669/applications/DeepSpeed-Chat/training/utils/model/reward_model.py#L65-L103 から抜粋

        chosen_ids = input_ids[:bs]  # bs x seq x 1
        rejected_ids = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Compute pairwise loss. Only backprop on the different tokens before padding
        loss = 0
        for i in range(bs):
            chosen_id = chosen_ids[i]
            rejected_id = rejected_ids[i]
            chosen_reward = chosen_rewards[i]
            rejected_reward = rejected_rewards[i]

            c_inds = (chosen_id == self.PAD_ID).nonzero()
            c_ind = c_inds[self.num_padding_at_beginning].item() if len(
                c_inds
            ) > self.num_padding_at_beginning else seq_len  # OPT model pads the first token, so we need to use the second padding token as the end of the sequence
            check_divergence = (chosen_id != rejected_id).nonzero()

            if len(check_divergence) == 0:
                end_ind = rejected_reward.size(-1)
                divergence_ind = end_ind - 1
                r_ind = c_ind
            else:
                # Check if there is any padding otherwise take length of sequence
                r_inds = (rejected_id == self.PAD_ID).nonzero()
                r_ind = r_inds[self.num_padding_at_beginning].item(
                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
                end_ind = max(c_ind, r_ind)
                divergence_ind = check_divergence[0]
            assert divergence_ind > 0
            c_truncated_reward = chosen_reward[divergence_ind:end_ind]
            r_truncated_reward = rejected_reward[divergence_ind:end_ind]
            chosen_mean_scores.append(
                chosen_reward[c_ind - 1])  #use the end score for reference
            rejected_mean_scores.append(rejected_reward[r_ind - 1])

            loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
                                                    r_truncated_reward).mean()

ロスはこの行ですね。良い方と悪い方の相違部分の報酬の差分が大きくなるような学習になってます。

            loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
                                                    r_truncated_reward).mean()

推論時における実際の報酬の値としては、以下のように PAD でない有効トークン部分の末端の値を使っています。 入力されたプロンプトから生成したテキストまで全てのトークンを踏まえて導出した値なので感覚的にも良さそうな気がします。

# https://github.com/microsoft/DeepSpeedExamples/blob/8f8099a813f3b223d5df39e0c15c748de4eb1669/applications/DeepSpeed-Chat/training/utils/model/reward_model.py#L148-L152
                c_inds = (input_id[prompt_length:] == self.PAD_ID).nonzero()
                # here we only use the answer part of the sequence so we do not need to care about the padding at the beginning
                c_ind = c_inds[0].item() + prompt_length if len(
                    c_inds) > 0 else seq_len
                chosen_end_scores.append(value[c_ind - 1])

最後の工程が RLHF です。

Step 3. RLHF

Step 3. では Step 1. で学習した SFT モデルを Actor、 Step 2. で学習した報酬モデルを Critic として強化学習により出力テキストの品質を高めていきます。

いきなり Actor, Critic という単語が出てきましたが、ここで東大松尾研から公開されているスライド4で強化学習の体系を確認してみましょう。

reinforce_learning_overview

ここ数年の AI に TV ゲームをプレイさせる取り組みは図の右上、水色部分の DQN から広がる体系になります。

図中段の左端に Actor Critic が出てきますね。 そこから一段下におりて右に進んだ橙色ところにある PPO5 が DeepSpeed-Chat で採用されている手法になります。

強化学習を基礎から説明すると終わらないので、以降は必要な補足をしつつ PPO を中心に説明します (思い切って端折るので、理解出来ている人からすると色々と言いたくなるかもしれません)。

3. PPO

強化学習というと AI に TV ゲームを操作させたり、将棋や囲碁を差したりするイメージが強いかもしれませんね。 まずは強化学習の枠組みにテキスト生成がどう繋がるのか押さえておきましょう。

テキスト生成と強化学習

まず強化学習について少しだけ復習です。AI が TV ゲームをプレイするのを強化学習の枠組みで表現すると以下のようになります。

reinforce_learning

Agent がプレイヤーである AI です。Environment はプレイするゲームで、プレイ中のゲームの画面が状態( state ) です。 Agent は画面を見て行動( action )し、その行動に応じた報酬( reward )を得ます。

Agent に如何に上手くゲームをプレイさせて高得点を採るか、というのが強化学習の目標です。

上記は TV ゲームの話ですが、自己回帰によるテキスト生成に置き換えると、

  • Agent : テキスト生成モデル
  • state : モデルに入力されるプロンプト
  • action : 語彙集合から選択した出力トークン 1 個

になります。ゲームの操作に応じて画面が更新されるように、出力されたトークンをプロンプトの末尾に連結して次の状態を作り、 上の図をグルグル回りながら文章を生成する感じです。

また、強化学習では Agent の状況に応じて行動を選択する意思決定の戦略を方策と呼びます。 テキスト生成では、decoder の出力特徴量を softmax して出力トークンの確率分布を出すところが方策に相当すると思えば良いでしょう。

これでテキスト生成と強化学習が繋がりました。方策が出てきたので、もう少しだけ基本の話をします。 強化学習には大きく分けて二つの手法があります。方策ベースと値ベースです。

  • 方策ベース:
    Agent が与えられた state で、どの action を選択するかを学習します。
  • 値ベース:
    Agent に、どの state がより価値があるかを学習させた上で、より価値がある action を選択させます。また state の価値を推定する関数を価値関数と呼びます。

言葉だと分かりにくいですね。以下はHugging Face の強化学習コース6に出てくるイメージ図です。 右側の数値が各 state (コマ)の価値関数の値になります。

policy_base_value_base

さて基本の話はこれくらいにして PPO に戻ります。

PPO はそれ以前に存在した手法を組み合わせて改善を加えた手法なので、次は PPO に至るまでの流れを見ていきます。

reinforce_to_ppo

先程の図を PPO から元を辿っていくと方策勾配法にたどり着きます。さらによく見ると TRPO のところで枝割れして Actor Critic にたどり着きました。この辺りを押さえておくのが良さそうです。

先に方策ベースの手法である方策勾配法について少し見てみましょう。

方策勾配法

方策勾配法は方策をパラメータθで表された確率的な関数 π で表現し、目的関数をその勾配を使って最大化することで、 最良の方策を求めます。目的関数は以下のようになります。

policy_gradient

トークンのシーケンス中の位置を t とすれば、st がプロンプト、at が出力トークンで、πθ(at |st) はその確率ですね。 A^t は利得で位置 t 時点における出力の具合の良さだと思って下さい(後述します)。

出力の確率とその時の具合の良さを掛け合わせた期待値を最大化する訳ですから、感覚的にもそれで良さそうな気がしてきました。

この方法の泣き所は、上記の目的関数を計算する元ネタ(言語モデルが生成したトークン列と確率)を収集する必要がある為に遅く、 その報酬の分散も大きくなることです7

報酬の分散を抑える為のシンプルな方法は大量に元ネタを収集して平均することですが、学習速度への影響を考えると限度があります。 別なアプローチで報酬の分散を抑える対策の一つが Actor Critic になるので、次はそちらを押さえましょう。

Actor Critic

Actor Critic は方策ベースと値ベースを組み合わせた手法で、以下の二つの役割が登場します。

  • Actor : どのような行動をするかを制御する。つまり方策のことですね。πθ(a|s) です。
  • Critic : Actor の振る舞いを評価します。価値関数 V(s) とします。

Actor が方策ベース、Critic が価値ベースです。 Actor の振る舞いを Critic が評価し、その評価を元に Actor を更新。また Critic もより良い評価を出来るよう更新します。

Actor の振る舞いを報酬で直接評価するのではなく Critic を間に入れることで、報酬の分散の影響を抑えることができ学習の安定性が向上します。

Actor の基本的な更新は前述した LPG(θ) ですが、先程、後述するとした利得には V(s) を直接使わず以下の式を使います。

advantage

上段の式は A^t は位置 t からシーケンス長 T の末端までの δ に割引率 γ とハイパーパラメータ λ を掛けた加重和になってます。

では下段の δt の意味を考えてみましょう。rt + γV(st+1) と V(st) で分けて考えます。

  • rt + γV(st+1) :
    状態 stで action at を選択して得られた報酬 rt と、その結果生じた状態 st+1 の価値である V(st+1) (と割引率γの積)の和です。つまり状態 st で at を選んだ場合の最終的な全報酬の見込みです。
  • V(st) :
    状態 st における最終的な全報酬の見込みです。

この二つの差分を採るので δt状態 st において at を選ぶ価値が、状態 st での全 action の平均的な価値に対して、どれほど優れているか(or 劣っているか)を示すことになります。

A^t が正であれば、より at を選択する確率が高くなるように更新し、負の場合はその逆になります。

さらに学習を安定させる工夫として PPO ではクリップ付き代理目的関数が導入されています。

クリップ付き代理目的関数

まず PPO の前身にあたる TRPO8 では目的関数が変更されていて、以下の計算式を最大化しています。 πθ が更新後、πθold が更新前の方策です。

so

上段が TRPO が最大化する代理目的関数で、意味合い的には θ の変化による利得の更新の期待値です。 下段は θ の更新が過大にならないよう、KLダイバージェンスを使った制約を掛けています。

ただ、計算速度や複雑性の観点で TRPO は課題を抱えており、より簡便な方法を求めた結果が PPO になります。 PPO で導入されたクリップ付き代理目的関数 では KLダイバージェンスを使った制約の部分がクリッピングによるシンプルな方法に置き換えられています。

cso

ここで rt(θ) は以下なので rt(θ)A^t は TRPO の代理目的関数と同じですね。

ratio

結局は rt(θ)A^t を素の場合と 1-ε ~ 1+ε でクリップした場合で比べた時の最小値になります。

ここで私は「えぇっと…、正直よくわかりませんが?」となったのですが、有り難いことに表にしてまとめてくれた人がいました。 (※ pt(θ) は rt(θ) で読み替えて下さい。)

ppo_table

要するに以下の条件に合致した時のみ rt(θ)A^t の勾配で方策を更新するということですね。

  • 方策の更新量が適正範囲(1-ε ~ 1+ε)に収まる場合(1,2行目) :
    更新量が適正範囲なので更新して問題なし。
  • 更新量が小さいが利得が正値(3行目) :
    利得が正なのは質の良い文章を生成した場合なので、量が小さいとしても更新しておこう。
  • 更新量が大きいが利得が負値(6行目) :
    利得が負なのは質の悪い文章を生成した場合なので、大き目でも更新して素早く修正してしまおう。

ちなみに 4, 5行目でクリップが適用された場合に方策が更新されないのは (1±ε)A^t が θ を含まないので勾配が 0 になるからですね。

さて、かなりの駆け足で斜め読みでしたが、根底にある考え方は押さえられたかと思います。それでは実際に DeepSpeed-Chat を動かしてみましょう。

4. DeepSpeed-Chat

DeepSpeed9 は大規模な深層学習モデルを効果的かつ効率的に学習・推論する為のライブラリ群です。 そして DeepSpeed-Chat は DeepSpeed で実装された ChatGPT ライクなモデルを学習する為のフレームワークという位置づけになります。

2章の学習の全体像で示したように、RLHF の学習は複数のモデルを同時に扱います。 今回は小さなモデルで実験するので使わなかったものが多いのですが、DeepSpeed-Chat では大規模なモデルを扱うための機能として、 以下が利用できます。

  • ZeRO Optimization :
    DeepSpeed で実装されている ZeRO10 による最適化です。DeepSpeed-Chat では最適化のステージを 0~3 で選択するのみで、細かいコンフィグレーションはコード内に記述されたテンプレートから生成されます。
  • FP16 Training :
    FP16 との混合精度学習です。上述したコード内のテンプレートに、設定が埋め込まれています。
  • Gradient Checkpointing :
    勾配計算時に必要となるフォーワードパスの計算結果を破棄し、再計算で置き換えることで速度と引き換えにメモリ消費を抑えます。
  • Gradient Accumulation :
    小さいバッチサイズでの勾配を累積し、論理的に大きなバッチサイズでの学習を可能とします。
  • LoRA11 :
    言語モデルのパラメータを固定し、モデル中の全結合層の学習をその階数分解行列を最適化に置き換えることで更新対象のパラメータ数を劇的に低減する手法です。
  • Reference model offloading :
    Reference model モデルを CPU にオフロードすることで速度低下を抑えつつ、より大きなバッチサイズを使用可能とします。
  • DeepSpeed Hybrid Engine :
    DeepSpeed が提供する学習用エンジンと推論用エンジンを自動的に切り替え、Critic に評価させるテキストを効率的に生成することができます。

実用になりそうな数十億パラメータ以上のモデルで RLHF を動かすには、この辺りの機能を全部入りで使うことになりますが、 これらが動作実績のあるスクリプトとして、まとまっているのが嬉しいところですね。

次に実際に学習を動かすにはデータセットが必要なので、その準備をしていきます。

5. データセットの準備

今回も Google Colab で作業を開始したのですが、学習をする段階でいろいろあって途中から GCP 上に立てた VM で作業をすることになってしまいました。

とりあえずデータセットの準備作業は Colab で行ったので、その想定でコードを記載していきます。

さて、データセットにはチャットボット風の会話形式のものを使いたかったのですが、日本語で良さそうなものを見つけられなかったので、 第23回で使った3行要約データセットを使いました。要約データセットはモデルへの入力が長くなるのでイマイチ不安ですが、 大丈夫でしょうか。。。

まずは第23回で加工済みのデータを GCS から取得します12

from google.colab import auth
auth.authenticate_user()
!gsutil cp -r gs://somewhere/brio/raw .

こんな感じのファイルが用意できている前提です。 それぞれのファイルの内容は *.source は入力文、*.target は要約文、*.out は T5 で生成した要約文(1入力に対して16個)になります。

!ls raw
# test.out     test.target  train.source    val.out     val.target
# test.source  train.out      train.target  val.source

とりあえず、ファイルをロードして dict にします。

NUM_GENERATED = 16

def read_lines(fname):
  with open(fname, "r") as f:
    lines = f.readlines()
    lines = [line.strip() for line in lines]
  return lines

def load_dataset(split):
  source = read_lines(f"raw/{split}.source")
  target = read_lines(f"raw/{split}.target")
  flatten_out = read_lines(f"raw/{split}.out")
  out = []
  buf = []
  for i, o in enumerate(flatten_out):
    buf.append(o)
    if (i+1)%NUM_GENERATED == 0:
      out.append(buf)
      buf = []
  assert len(source) == len(target) == len(out)
  return {"source": source, "target": target, "out": out}

dataset = {
  "train": load_dataset("train"),
  "val"  : load_dataset("val"),
  "test" : load_dataset("test")
}

本当は全データで試したいところですが、 入力文が長いと強化学習のところでリソース的に大変になりそうな気がしたので、 入力長でデータを分割し、短いものだけ使うことにしました。

入力長を計算するのにトークナイザが必要なので transformers と sentencepiece を入れます。

!pip install transformers==4.30.2 sentencepiece==0.1.99      

トークナイザをロードして分割用の関数を定義して、

from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = 'rinna/japanese-gpt2-medium'
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
tokenizer.do_lower_case = True 

SEQ_LEN_SHORT = 278
SEQ_LEN_LONG = 489

def split_by_input_length(split, thresh_short, thresh_long):
  short_examples = []
  long_examples = []
  too_long_examples = []
  for i, example in enumerate(zip(split["source"], split["target"], split["out"],)):
    if i % 5000 == 0:
      print(f"processing {i}...")
    input = tokenizer(example[0], return_tensors='np')
    input_len = input.input_ids.shape[1]
    if input_len <= thresh_short:
      short_examples.append(example)
    elif input_len <= thresh_long:
      long_examples.append(example)
    else:
      too_long_examples.append(example)
  print(f"short: {len(short_examples)}, long: {len(long_examples)}, too_long {len(too_long_examples)}")
  return {"short": short_examples, "long": long_examples, "too_long": too_long_examples}

import pickle
def dump(fname, obj):
  with open(fname, "wb") as f:
    pickle.dump(obj, f)

以下のようにして各スプリットを入力長に対する閾値で分割します。

for split in ["train", "val", "test"]:
  print(f"Processing {split}...")
  splitted_by_length = split_by_input_length(dataset[split], SEQ_LEN_SHORT, SEQ_LEN_LONG)
  for key in splitted_by_length.keys():
    dump(f"./{split}_{key}.pkl", splitted_by_length[key])

# Processing train...
# processing 0...
# processing 5000...
# ...
# processing 95000...
# short: 39603, long: 29355, too_long 29441
# Processing val...
# processing 0...
# short: 239, long: 211, too_long 195
# Processing test...
# processing 0...
# short: 259, long: 201, too_long 195

結果は GCS に退避しておきましょう。以下のようなファイルが出来上がってれば OK です。

!gsutil cp *.pkl gs://somewhere/RLHF/summarize_data/
!gsutil ls gs://somewhere/RLHF/summarize_data/
gs://somewhere/RLHF/summarize_data/test_long.pkl
gs://somewhere/RLHF/summarize_data/test_short.pkl
gs://somewhere/RLHF/summarize_data/test_too_long.pkl
gs://somewhere/RLHF/summarize_data/train_long.pkl
gs://somewhere/RLHF/summarize_data/train_short.pkl
gs://somewhere/RLHF/summarize_data/train_too_long.pkl
gs://somewhere/RLHF/summarize_data/val_long.pkl
gs://somewhere/RLHF/summarize_data/val_short.pkl
gs://somewhere/RLHF/summarize_data/val_too_long.pkl

それでは学習を始めていきます。今回は学習を GCP 上の VM で実行するので、まずは環境をセットアップしましょう。

6. 環境のセットアップ

ここからは GCP に VM を立てて動かしました。VM の構成は以下のとおりです。

  • ディスク イメージ : Debian 10 based Deep Learning VM with M108
  • vCPU x 2 (Step 3 のみ x 4)
  • Mem 64 GB
  • GPU : Tesla T4 x 1 (Step 3 のみ x 2)
  • NVIDIA Driver : 525.105.17

VM を起動したら以下のようにコンテナを起動します。

$ docker run --name deepspeedchat -it --gpus all --shm-size=1g \
    -v /work:/work pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel bash
root@8f58510dddce:/workspace#

コンテナに入ったら以下のようにセットアップします。

# cd /work
# pip install deepspeed==0.9.2
# git clone https://github.com/microsoft/DeepSpeedExamples.git
# cd DeepSpeedExamples 
# git checkout 8f8099a813f3b223d5df39e0c15c748de4eb1669
# cd ./applications/DeepSpeed-Chat
# pip install -r requirements.txt
# cd /work

今回は rinna/japanese-gpt2-medium で実験することにしたのですが、 学習過程(特に Step-2 と Step-3)でやたらと loss の OVERFLOW が発生して学習が停止してしまう事態が発生しました。 確認できた動作実績は最小でも 1.3B パラメータなので、選択したモデルがかなり小さいことも影響したかもしれません。

悩んだ末に今回は fp16 を諦めて fp32 で学習することにしました。オプションでの切替はできないので直接コードを修正してます。

sed -i -e "38s/True/False/" -e "71s/True/False/" ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/ds_utils.py
cat ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/ds_utils.py | awk 'NR>=37 && NR<=38 || NR>=70 && NR<=71 {print NR"|"$0}'
37|        "fp16": {
38|            "enabled": False,
70|        "fp16": {
71|            "enabled": False

さらにコードを修正して今回使うデータセットを追加します。GCS から前章で準備した pickle を取得して、

# gsutil -m cp gs://somewhere/RLHF/summarize_data/* .
# ls
DeepSpeedExamples  test_short.pkl     train_long.pkl   train_too_long.pkl  val_short.pkl
test_long.pkl      test_too_long.pkl  train_short.pkl  val_long.pkl        val_too_long.pkl

今回加工した要約データを使用するデータセットクラスを追記します。

cat << EOF >> ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py

import pickle
import numpy as np
from datasets import Dataset
class BaseThreeLinesSummaryDataset(PromptRawDataset):

    def __init__(self, output_path, seed, local_rank, dataset_name):
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        self.dataset_name = "short_three_lines_summary"
        self.dataset_name_clean = "short_three_lines_summary"
        self.raw_datasets = {
            "train": self.load_pickle(self.get_train_file(), seed),
            "val": self.load_pickle(self.get_eval_file(), seed)
        }

    def get_train_file(self):
        return None

    def get_eval_file(self):
        return None

    def load_pickle(self, file_name, seed):
        with open(file_name, "rb") as f:
            rows = pickle.load(f)
        examples = []
        RANK_THRESHOLD = 8
        rng= np.random.RandomState(seed)
        for row in rows:
            generated_ind = rng.randint(RANK_THRESHOLD, len(row[2]), 1)[0]
            examples.append({
                "article" : row[0],
                "reference" : row[1],
                "genareted" : row[2][generated_ind]
            })
        return Dataset.from_list(examples)

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["val"]

    def get_prompt(self, sample):
        return " human: " + sample['article'] + " assistant:"

    def get_chosen(self, sample):
        return " " + sample['reference']

    def get_rejected(self, sample):
        return " " + sample['genareted']

    def get_prompt_and_chosen(self, sample):
        return " human: " + sample['article'] + " assistant: " + sample[
            'reference']

    def get_prompt_and_rejected(self, sample):
        return " human: " + sample['article'] + " assistant: " + sample[
            'genareted']

class ShortThreeLinesSummaryDataset(BaseThreeLinesSummaryDataset):
    def get_train_file(self):
      return "./train_short.pkl"
    def get_eval_file(self):
      return "./val_short.pkl"
EOF

補足 : データセットの形式について

DeepSpeed-Chat で使うデータセットは 1 件のサンプルが以下の属性で構成されています。

  • prompt :
    モデルへの入力。今回は要約する原文のテキストです。
  • chosen :
    prompt から出力すべきテキストの高品質な例です。今回は3行要約データセットのラベル、つまり人手で作った参照要約です。
  • rejected :
    prompt から出力すべきテキストの低品質な例です。今回は T5 で作ったモデルが生成した要約文です。前章で加工した pickle には ROUGE スコアの高い順に16件格納されているので、ROUGE スコアの下半分からランダムにサンプリングしています。

学習時は各 Step でそれぞれ以下のようにサンプルの属性を使います。

  • Step 1 : prompt から chosen を生成するよう学習。
  • Step 2 : prompt に対して rejected より chosen がより高いスコアになるように学習。
  • Step 3 : prompt を入力して生成されたテキストが報酬モデルで高く評価されるように学習。

実際のデータを見てみると以下のような感じになります。

# 以下の出力例は実際は一行です。
import sys
sys.path.append("./DeepSpeedExamples/applications/DeepSpeed-Chat/training")

from utils.data.raw_datasets import ShortThreeLinesSummaryDataset
short_dataset = ShortThreeLinesSummaryDataset("temp", 1234, 0, "dummy")
split = short_dataset.get_train_data()

print(short_dataset.get_prompt(split[0]))
# human: 女優の1995年に放送されたマクドナルドのCMで一躍脚光を浴びた遠藤。"エンクミ"の愛称で親しまれ、
# 広末涼子や内田有紀らと共にショートカットのアイドルとして一世を風靡した。CMに出演した17歳当時、学校」
# ...省略...
# フからのアドバイスで別れることに。事情を聞いた彼は素直に受け入れ、「邪魔になっちゃうのかな」と語って
# いたという。 assistant:

print(short_dataset.get_chosen(split[0]))
# 遠藤久美子が9日、日テレの番組で高校時代に受けたいじめを告白した。トイレで水をかけられたり、
# 平均台の上に正座させられたりしたという。在学中にCMで脚光を浴びた遠藤は、校内でも注目されいじめの対象になった。

print(short_dataset.get_rejected(split[0]))
# TBS「しゃべくり007」で遠藤憲一が、17歳当時について語った。いじめの対象になったが、平均台の上で正座させら
# れたと告白。スタジオからは「無理じゃない? と思って」と笑いを誘った。。

本筋に戻ってコードの修正の続きです。

上記のデータセットをロードするための条件を追加します。

# sed -i -e '63a\ \ \ \ elif dataset_name == "3line_summaries":' \
     -e '63a\ \ \ \ \ \ \ \ return raw_datasets.ShortThreeLinesSummaryDataset(output_path, seed, local_rank, dataset_name)' \
  ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/data/data_utils.py
# cat ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/data/data_utils.py | awk 'NR>=61 && NR<=66{print NR"|"$0}'
61|    elif "lmqg/qag_jaquad" in dataset_name:
62|        return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank,
63|                                                 dataset_name)
64|    elif dataset_name == "3line_summaries":
65|        return raw_datasets.ShortThreeLinesSummaryDataset(output_path, seed, local_rank, dataset_name)
66|    else:

コードを見ると “<|endoftext|>” を生成文字列の末尾につけていますが、トークン数がもったいないので不要だと判断して削除しました。

# sed -i -e '253s/<|endoftext|>//' ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/data/data_utils.py
# cat ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/data/data_utils.py | awk 'NR>=252 && NR<=254{print NR"|"$0}'
252|                          max_seq_len,
253|                          end_of_conversation_token="",
254|                          sft_only_data_path=[]):

前述のとおり rinna/japanese-gpt2-medium を使うことにしたので、各 Step のトークナイザのロード箇所で、 Fast Tokenizer のスイッチをオフ、do_lower_case を True にします。

Step 1 のコード修正です。

# sed -i -e '205s/True/False/' \
        -e '206a\ \ \ \ tokenizer.do_lower_case = True' \
        DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
# cat  DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py | awk 'NR>=205 && NR<=207{print NR"|"$0}'
205|    tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=False)
206|    tokenizer.pad_token = tokenizer.eos_token
207|    tokenizer.do_lower_case = True

続いて Step 2。

# sed -i -e '204s/True/False/' \
        -e '205a\ \ \ \ tokenizer.do_lower_case = True' \
      DeepSpeedExamples/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
# cat  DeepSpeedExamples/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py | awk 'NR>=204 && NR<=206{print NR"|"$0}'
204|    tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=False)
205|    tokenizer.pad_token = tokenizer.eos_token
206|    tokenizer.do_lower_case = True

最後に Step 3 です。

# sed -i -e '383s/True/False/' \
        -e '384a\ \ \ \ tokenizer.do_lower_case = True' \
      DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
# cat  DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py  | awk 'NR>=382 && NR<=385{print NR"|"$0}'
382|    tokenizer = load_hf_tokenizer(args.actor_model_name_or_path,
383|                                  fast_tokenizer=False)
384|    tokenizer.pad_token = tokenizer.eos_token
385|    tokenizer.do_lower_case = True

また、Step 3 の強化学習では、ZeRO の offload 機能を使いました。今回使用した DeepSpeed-Chat のコードでは assert で対応するオプションが使えないようになっていたので、コードを修正して外しました13

# cat DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py | awk 'NR>=366 && NR<=370{print NR"|"$0}'
366|    args.global_rank = torch.distributed.get_rank()
367|
368|    assert not args.offload, "zero-offload is not currently supported but coming soon!"
369|
370|    unsupervised_training_enabled = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name

# sed -i -e '368,369d' DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
# cat DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py | awk 'NR>=366 && NR<=368{print NR"|"$0}'
366|    args.global_rank = torch.distributed.get_rank()
367|
368|    unsupervised_training_enabled = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name

それでは実際に学習を動かしてみましょう。

7. 学習の実行

ここから、各 Step を実行していきます。DeepSpeed-Chat では 3 つの Step を一気に動かすこともできますが、 今回は 1 Step ずつ実行しました。

Step 1 : SFT (Supervised Fine Tuning) の学習

以下のようにして実行しています。

# mkdir -p ./step-1
# deepspeed --num_gpus=1 ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py \
   --data_path 3line_summaries \
   --data_split 10,0,0 \
   --data_output_path "./cache" \
   --max_seq_len 376 \
   --model_name_or_path "rinna/japanese-gpt2-medium" \
   --per_device_train_batch_size 4 \
   --per_device_eval_batch_size 4 \
   --gradient_accumulation_steps 4 \
   --weight_decay 0.0 \
   --learning_rate 5e-5 \
   --num_train_epochs 5 \
   --zero_stage 0 \
   --seed 1234 \
   --deepspeed \
   --output_dir ./step-1 \
   &> ./step-1/training.log 

オプションについて少し補足です。

  • –data_split :
    “10,0,0” は学習データである 3line_summaries を Step-1, 2, 3 にどのように配分するかというパラメータです。"4,3,3" のようにして各 Step で学習データが Step 間で被らないようにもできますが、今回は “10,0,0” のように各 Step に全量のデータを入れるようにしています。
  • –zero_stage :
    Step-1 は 1 GPU で動かしたので 0 を設定して ZeRO を無効化しています。

使わなかった(使えなかった)オプションもあります。

  • –lora_module_name, –lora_dim :
    LoRA はモデルが GPT-NeoX であれば、--lora_module_name "gpt_neox.layers." --lora_dim 128 のようにして利用できるのですが、DeepSpeed-Chat の LoRA は簡易実装で GPT2 には対応してませんでした。14

結果は以下のようになりました。Perplexity だとイマイチよくわかりませんね。 少々回し過ぎな気もしますが、ドキュメントには過学習気味にするのが良さげだと書いてあるので、これで良しとしましょう。15

# cat ./step-1/training.log | grep ppl:
ppl: 657.7637329101562
ppl: 5.457597732543945
ppl: 5.448544979095459
ppl: 5.490699768066406
ppl: 5.513986110687256
ppl: 5.5415472984313965

Step 2 : RW モデル の学習

以下のようにして実行しています。

# mkdir -p ./step-2
# deepspeed --num_gpus=1 ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py \
   --data_path 3line_summaries \
   --data_split 0,10,0 \
   --data_output_path "./cache" \
   --max_seq_len 376 \
   --model_name_or_path "rinna/japanese-gpt2-medium" \
   --num_padding_at_beginning 0 \
   --per_device_train_batch_size 2 \
   --per_device_eval_batch_size 2 \
   --gradient_accumulation_steps 32 \
   --num_train_epochs 5  \
   --weight_decay 0.0 \
   --learning_rate 5e-5 \
   --disable_dropout \
   --zero_stage 0 \
   --seed 1234 \
   --deepspeed \
   --output_dir ./step-2 \
   &> ./step-2/training.log

Step 2 のオプションは Step 1 とだいたい同じなので、特に説明しなくても大丈夫そうですね。

結果は以下のようになりました。

# cat step-2/training.log | grep chosen_last_scores
chosen_last_scores (higher is better) : -0.08306388556957245, acc (higher is better) : 0.5849999785423279
chosen_last_scores (higher is better) : -3.165792465209961, acc (higher is better) : 1.0
chosen_last_scores (higher is better) : 1.6424806118011475, acc (higher is better) : 0.9950000047683716
chosen_last_scores (higher is better) : 5.951107501983643, acc (higher is better) : 0.9950000047683716
chosen_last_scores (higher is better) : 7.464086055755615, acc (higher is better) : 1.0
chosen_last_scores (higher is better) : 7.978110313415527, acc (higher is better) : 0.9950000047683716

パッと見た感じは悪く無さそうです。

acc は chosen と rejected のスコアを比較して chosen が上位だった割合だと思います。 99.5 % なのですが、これは正解ラベルの参照要約と T5 で生成した要約を分類させた結果です。

ですが、Step 3 の強化学習では入力される系列は全てモデルが生成したものになります。 報酬モデルからすると rejected に相当する品質の入力ばかり採点させられるわけで、果たして優劣をちゃんと判断できるでしょうか。 不安しかありません。

次、行ってみましょう。

Step 3 : RLHF の学習

Step 3 では学習に使うモデルの数が増えるので、ここからは VM の構成を変更して vCPU x 4, Tesla T4 x 2 の構成にしました。

以下のようにして実行しています。

# mkdir -p ./step-3
# deepspeed --num_gpus 2 ./DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py \
   --data_path 3line_summaries \
   --data_split 0,0,10 \
   --data_output_path "./cache" \
   --actor_model_name_or_path ./step-1 \
   --critic_model_name_or_path ./step-2 \
   --num_train_epochs 1 \
   --per_device_train_batch_size 16 \
   --per_device_mini_train_batch_size 16 \
   --generation_batch_numbers 1 \
   --gradient_accumulation_steps 1 \
   --ppo_epochs 1 \
   --max_prompt_seq_len 296 \
   --max_answer_seq_len 88 \
   --actor_zero_stage 2 \
   --critic_zero_stage 2 \
   --offload \
   --num_padding_at_beginning 0 \
   --actor_weight_decay 0.0 \
   --critic_weight_decay 0.0 \
   --actor_gradient_checkpointing \
   --critic_gradient_checkpointing \
   --inference_tp_size 1 \
   --offload_reference_model \
   --seed 1234 \
   --deepspeed \
   --output_dir ./step-3 \
   &> ./step-3/training.log &

使用したオプションについて補足です。

  • –num_gpus :
    Tesla T4 を 2 個使うので 2 を設定してます。
  • –per_device_train_batch_size, –per_device_mini_train_batch_size :
    ドキュメント16によると同じサイズにしたほうが良いようだったので、それに従いました。
  • –num_train_epochs, –ppo_epochs :
    こちらも共に 1 にしたと記述してあったので16、それに従いました。
  • –actor_zero_stage, –critic_zero_stage :
    共に 2 を設定して ZeRO の Stage 216を使用しています。
  • –offload :
    ZeRO Offload17 を有効にし、オプティマイザの計算とメモリを CPU 側に移してます。
  • –offload_reference_model :
    リファレンスモデルを CPU にオフロード出来るようにします。無視できる範囲の速度低下と引き換えにバッチサイズの拡大が可能です。

使わなかったオプションもあります。

  • –enable_hybrid_engine :
    DeepSpeed Hybrid Engine を有効にします。ってコレが目玉なんじゃないの?という気もしますが、よくわからないエラーになってしまったので泣く泣く諦めました。18
  • –enable_ema :
    Actor のパラメータの移動平均を収集する新たなモデルを追加します。精度が向上するようなので機会があれば使ってみたいです。

学習結果を GCS に退避しておきます。

gsutil -m cp -r ./step-* gs://somewhere/RLHF/gpt2-medium/

さて、不安9割、夢1割でここまで走ってきました。結果はどうなったでしょうか。。。

8. 結果発表

Step 3 では学習のログに以下のようにロスと報酬スコアが出力されているので、値を拾ってプロットしてみました。

epoch: 0|step: 1182|ppo_ep: 1|act_loss: -1.8171650171279907|cri_loss: 17.55929183959961|unsuper_loss: 0.0
average reward score: -1.4255101680755615

まずは average reward score のプロットです。

average_reward_score

だいぶ不安定な感じですが、目を細めてみると気持ち上がってる気がします。そう思いたいです。

スコアがほぼマイナス値なのは前述のとおり、データセットで rejected に相当する品質のテキストばかり入力されてるからでしょうね。

次にロスの確認です。

actor_critic_loss

えーっと、Actor のロスは下がってる感じですか、一応。スパイクが出てますが、この時に fp16 だとオーバフローしちゃうんでしょうか。 大きいモデルで試した時にどうなるかは気になるところです。収まらなければ、A100 あたりを用意して bfloat16 ですかね。

Critic のロスが。。。上っているのはヤバい感じですが、それと連動して Actor ロスと報酬スコアは改善してるような。。。 これ Critic の学習レートの調整してあげると、もう少し何とかなるんでしょうか? 残念ながら、今回はハイパーパラメータを追い込むところまで試せませんでした。

検証データによる評価

次に検証データにおける報酬スコアと ROUGE スコアを確認していきましょう。 ここからは Colab で動かします。アクセラレータは GPU にして下さい。

セットアップ

Colab のランタイム上に学習に使った環境を再現します。 必要なものをインストールしてデータと学習済みモデルを GCS から取得します。

DeepSpeed-Chat のコード修正は繰り返しになるので省略しますが 6 章の要領で変更して下さい。

!pip install deepspeed==0.9.2
!git clone https://github.com/microsoft/DeepSpeedExamples.git
!cd DeepSpeedExamples && git checkout 8f8099a813f3b223d5df39e0c15c748de4eb1669
!cd DeepSpeedExamples/applications/DeepSpeed-Chat && pip install -r requirements.txt

from google.colab import auth
auth.authenticate_user()

!gsutil -m cp gs://somewhere/RLHF/summarize_data/* .
!gsutil -m cp -r gs://somewhere/RLHF/gpt2-medium/step-* .

ROUGE スコアの計算をするのに以下を追加でインストールしました。

!git clone https://github.com/neulab/compare-mt.git
!cd compare-mt && git checkout b6d8f79d02043243c3d8aa58373a0f4c55e17a69
!cd ./compare-mt && pip install -r requirements.txt
!cd ./compare-mt && python setup.py install
!apt-get install mecab mecab-ipadic-utf8 libmecab-dev
!pip install mecab-python3 fugashi ipadic
!pip install ginza ja_ginza==5.1.0

データとモデルのロード、評価用関数

まずは検証データセットをロードして、

import sys
sys.path.append("./DeepSpeedExamples/applications/DeepSpeed-Chat/training")

from utils.data.raw_datasets import ShortThreeLinesSummaryDataset
short_dataset = ShortThreeLinesSummaryDataset("temp", 1234, 0, "dummy")
eval_split = short_dataset.get_eval_data()
len(eval_split)
# 239

学習済みのモデルもロードします。

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium",
                                              fast_tokenizer=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.do_lower_case = True

def load_actor_model(model_path):
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return model

sft_model = load_actor_model("./step-1")
actor_model = load_actor_model("./step-3/actor")

要約生成関数とデコード関数です。1件ずつ処理するので遅いです。面倒だったので手を抜いてしまいました。。。

def summarize(model, tokenizer, prompt, max_answer_seq_len=88):
  features = tokenizer(prompt, return_tensors="pt")
  features = {k:v.to(device) for k, v in features.items()}
  input_ids = features["input_ids"]
  mask = features["attention_mask"]
  max_min_length = max_answer_seq_len + input_ids.shape[1]
  with torch.no_grad():
    output_ids = model.generate(input_ids, attention_mask=mask,
                           max_length=max_min_length)
  return output_ids

def decode(tokenizer, output_ids):
  decodeds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
  decodeds = [decoded[decoded.index("assistant:")+len("assistant:"):] for decoded in decodeds]
  return [decoded.replace("<|endoftext|>", "") for decoded in decodeds]

Critic のロードと評価用の関数です。

from utils.model.model_utils import create_critic_model
from utils.utils import to_device
from step2_reward_model_finetuning.rw_eval import prepare_singlesample
import numpy as np

num_padding_at_beginning = 0

critic_model = create_critic_model("./step-2", tokenizer, None,
                                num_padding_at_beginning, True)
critic_model.to(device)
critic_model.eval()

def calc_reward(model, query, answer):
    device = torch.device("cuda")

    prompt = "human: " + query
    my_ans = "assistant: "+ answer

    batch = prepare_singlesample(prompt,
                                 my_ans,
                                 tokenizer,
                                 max_seq_len=512,
                                 end_of_conversation_token="")
    batch = to_device(batch, device)

    with torch.no_grad():
        outputs = model.forward_value(
            **batch, prompt_length=max(2, num_padding_at_beginning)
        )  # we just need to skip the number of padding tokens at the beginning

    return outputs["chosen_end_scores"].item()

def evaluate(model, critic, tokenizer, dataset, split):
  references = []
  generateds = []
  rewards = []
  for example in split:
    query = example["article"]
    reference = example['reference']
    prompt = dataset.get_prompt(example)
    output_ids = summarize(model, tokenizer, prompt)
    answer = decode(tokenizer, output_ids)[0]
    reward = calc_reward(critic, query, answer)
    references.append(reference)
    generateds.append(answer)
    rewards.append(reward)
  return references, generateds, rewards

報酬モデルによる評価

それでは評価してみましょう。まずは Step 1 の SFT モデルです。

sft_references, sft_generateds, sft_rewards = evaluate(sft_model, critic_model, tokenizer, short_dataset, eval_split)
sft_result = [sft_references, sft_generateds, sft_rewards]
sft_score = np.array(sft_result[2]).mean()
sft_score
# -14.613813163471022

次に Step 3 で学習した Actor です。

actor_references, actor_generateds, actor_rewards = evaluate(actor_model, critic_model, tokenizer, short_dataset, eval_split)
actor_result = [actor_references, actor_generateds, actor_rewards]
actor_score = np.array(actor_result[2]).mean()
actor_score
# -13.77621269812145

微妙なところかもしれませんが、検証セットでも報酬モデルのスコアは向上していますね。

ROUGE スコアでの評価はどうなるでしょうか。

ROUGE スコアでの評価

ROUGE による評価関数はこんな感じです。

import compare_mt.rouge.tokenize as rouge_tokenize
import MeCab
import os
os.environ["MECABRC"] ="/etc/mecabrc"
mecab = MeCab.Tagger ("-Ochasen")
mecab.parse("")

def parse_by_mecab(sentence, lemma=False):
  tokens = []
  node = mecab.parseToNode(sentence).next
  while node:
    feature = node.feature.split(',')
    token = feature[-3]
    if token == '*' or not lemma:
      token = node.surface
    tokens.append(token)
    node = node.next
  return [token for token in tokens if len(token) > 0]

def tokenize(text, stemmer):
  return parse_by_mecab(text, lemma=False)

rouge_tokenize.tokenize = tokenize

from compare_mt.rouge.rouge_scorer import RougeScorer
all_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)
import spacy
nlp = spacy.load('ja_ginza')

def process(x):
  return [str(sent) for sent in nlp(x.strip()).sents]

def compute_rouge(labels, preds):
  labels = [process(label) for label in labels]
  preds = [process(pred) for pred in preds]
  sample_rouge1 = 0
  sample_rouge2 = 0
  sample_rougeLsum = 0
  cnt=0
  for pred, label in zip(preds, labels):
    score = all_scorer.score("\n".join(label), "\n".join(pred))
    sample_rouge1 += score["rouge1"].fmeasure
    sample_rouge2 += score["rouge2"].fmeasure
    sample_rougeLsum += score["rougeLsum"].fmeasure
    cnt += 1
  sample_rouge1 = sample_rouge1 / cnt
  sample_rouge2 = sample_rouge2 / cnt
  sample_rougeLsum = sample_rougeLsum / cnt
  sample_rougeAve = (sample_rouge1 + sample_rouge2 + sample_rougeLsum) / 3.0
  return {"rouge1": sample_rouge1, "rouge2": sample_rouge2, "rougeLsum": sample_rougeLsum, "rougeAve":  sample_rougeAve}

Step 1 の SFT モデルのスコアです。

compute_rouge(sft_result[0], sft_result[1])
# {'rouge1': 0.4837918688472183,
#  'rouge2': 0.21163312974667375,
#  'rougeLsum': 0.44992122004620344,
#  'rougeAve': 0.38178207288003185}

次に Step 3 で学習した Actor です。

compute_rouge(actor_result[0], actor_result[1])
# {'rouge1': 0.47358949336504713,
#  'rouge2': 0.20581416734833877,
#  'rougeLsum': 0.43967977299465405,
#  'rougeAve': 0.3730278112360133}

うーん。微妙に下がってしまいました。。。

なんとなくですが、報酬モデルでのスコアは向上しているので、強化学習の枠組みは機能はしているけれど、 Critic が Actor の生成する要約の良し悪しを正しく評価できておらず、ROUGE スコアを伸ばすことが出来ていない、 という話のような気がします。

6 章の補足に記述したとおり、報酬モデルの学習は chosen が人手の参照要約、rejected が T5 の生成した要約なので、 モデルが生成する品質の要約の良し悪しを区別する能力が弱いような。chosen / rejected の差異と Actor が生成する要約の分散の乖離が大きすぎるとなかなか難しいのかもしれません。

微妙な結果になってしまいましたが、このサイズのモデルとこのデータセットでムキになっても仕方ないところもあり、 今回はこのくらいにしておこうと思います。

9. おわりに

今回は、LLM がらみで何かやろうということで、InstructGPT でも使われてる RLHF を試してみました。 微妙な結果になってしまいましたが、良さげなデータセットがあれば、またチャレンジしてみたいですね。 少し前には、りんなさんから 3.6B パラメータの PPO で学習したモデル19が公開されており、 こちらは強化学習に trlx20を使っています。trlx もそのうち試してみたいです。

次回をどうするか考えてるのですが、LangChain とかは多くの方が記事書いているので今更だよねと。 りんなさんの 3.6B パラメータの PPO モデルを使って何かできないか考えてみましょうかね。


  1. https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat 

  2. https://arxiv.org/abs/2203.02155 

  3. https://github.com/allenai/RL4LMs 

  4. スライド本体はこちらです。 https://www.slideshare.net/ShotaImai3/rlssdeepreinforcementlearning 

  5. https://arxiv.org/abs/1707.06347 

  6. https://huggingface.co/learn/deep-rl-course/unit0/introduction 

  7. Environment に確率的な部分があったり、方策が確率的な振る舞いしたりするので、同じ状態を起点にしても最終的な結果報酬は試行毎に大きく異なるということだと思います。ただ、テキスト生成の文脈だと Environment の外乱がなく、トークン選択を greedy としたりすると、ちょっと事情が違うのかもと思いました。 

  8. https://arxiv.org/abs/1502.05477 

  9. https://github.com/microsoft/DeepSpeed 

  10. https://arxiv.org/abs/1910.02054 

  11. https://arxiv.org/abs/2106.09685 

  12. 第23回の6章の作業が完了している前提で記述しています。実際に動かす方は適宜過去記事を参照して下さい。 

  13. assert で封印してあったオプションを勝手に開封してますが、その後のコミットで assert が除去されており、その間に目立った更新もなさそうだったので、「まぁ大丈夫だろう」と思ってます。皆さんは最新のやつを使ってください。 

  14. 実装に Dense ではなく Conv1D を使っているのが原因のようです。あくまで実装上の問題というか面倒なので対応しなかったということでしょう。 

  15. https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning#–others 

  16. https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning#-instablity-of-rlhf-training-and-others 

  17. https://www.deepspeed.ai/tutorials/zero/#zero-overview 

  18. https://www.deepspeed.ai/tutorials/zero-offload/ 

  19. https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo 

  20. https://github.com/CarperAI/trlx