オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第28回 Prefix Tuning の検証
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2023年8月24日

今回は Prefix Tuning の検証 のご紹介です。扱うモデルサイズがだんだんと大きくなるばかりの昨今。ファインチューニングは LoRA 的なコトでどうにかするとしても、それなりのサイズ感のモデルから派生したファインチューニング済みモデルを複数デプロイしようとすると、GPU メモリが足りません。そこを Prefix Tuning でどうにかしたいというお話です。

1. はじめに

前回 RLHF で散々苦労して、その後すぐに LIMA1 の論文を読んで魂抜けそうになりました。。。

最近、身の回りでは GPT-4 やら Function calling やら LangChain やら Llama2 やらの話題が多くてですね、 「学習済みモデルを拾ってきて手元のデータで学習!」とかしていると周囲とのズレを感じずにはいられません2 。。。 とはいえ、流行りものは沢山の人が記事を書くので、今回も日の当たらないところを拾いに行きます。

そんな訳で今回は Prefix Tuning3 です。話としては第21回の Prompt Tuning と同じような話になります。 事前学習済みモデルのパラメータを固定して、少量の追加パラメータをカチャカチャ付け替えて様々なタスクに対応する的な話です。

普通にファインチューニングするとタスク毎に事前学習済みモデルのサイズのコピーがマルっと出来てしまって、 複数タスクをデプロイしようとすると GPU メモリが足らなくなる。 なので固定の事前学習済みモデル+タスク毎の追加パラメータにすれば、嵩張る事前学習済みモデルが一つで良いので効率的だよね、ということで。

「それ、LoRA で良いじゃん。」という声が聞こえてきますが、FasterTransformer4 使いたくてですね。 LoRA で学習したモデルを FasterTransformer にデプロイしようとすると、 追加のパラメータを元モデルにフュージョンするしかなくて元の木阿弥状態になってしまって。。。

FasterTransformer は Prefix Tuning への対応が入っているので、 高速化の恩恵を受けつつ追加パラメータをカチャカチャ付け替え可能っぽいから試してみようということになりました。

そんな訳で今回は Prefix Tuning で学習をして、それを FasterTransformer 向けに変換するところまでやってみます。

まずは Prefix Tuning について、さらっと説明します。

2. Prefix Tuning

Prefix Tuning と Fine Tuning の違いは以下の図を見てもらえれば伝わるかと思います。

prefix_tuning

表解釈、要約、翻訳と下流タスクが3つあったとき、Fine Tuning だと元と同サイズのモデルが3つになるけれども、 Prefix Tuning は元のモデルは固定でタスク毎の追加パラメータを付け替えればOKという話です。

これだけ見ると第21回の Prompt Tuning と同じにも見えますが、実は少し違います。 Prompt Tuning の時は Transformer に入力するトークンの埋め込み表現の系列の先頭に追加パラメータを結合していたのですが、 入力系列の先頭に付加するのではなく、Transformer の各層のアクティベーション系列の先頭に付加する感じですね。

prompt_prefix

〇〇〇で表現しているのは入力系列の埋め込み表現と各層のアクティベーション系列で、 青系が固定パラメータによる演算結果、赤系が学習可能なパラメータです。 Prefix Tuning では Prompt Tuning に比較して変動するパラメータが増えているので、 その分精度を維持するのに有利ですね。Prefix Tuning の論文でも実験して、そのようになったようです。

performance_of_embedding_only

「アクティベーション系列の先頭に付加しようとすると、事前学習済みモデルのコードを修正しないといけないんでは?」という気もしますが、 transformers なんかのコードはその辺よく出来ていて、forward() の past_key_values を使えば OK ということみたいです。 そして FasterTransformer の実装にも対応が入っているので prefix weight を作ってしまえば、カチャカチャ付け替えは難しくなさそうです。

論文によると 0.1 % の追加パラメータで事前学習済みモデルと同等の性能を確保できたそうです。

入力系列を z、Transformer の層数を n とすると、入力系列の i 番目に対するアクティベーション
hi = [ hi1, hi2, …hin] はパラメータ Φ の事前学習モデル LM では以下のようになります。

h_i_LM

これが Prefix Tuning の場合は、プレフィックスのインデックスを Pidx として、

h_i_PrefixLM

となります。この時、Φ は固定されており θ のみが学習可能な状態です。

ここまでは各層のアクティベーション系列の先頭に学習可能なパラメータ値を連結する体で説明してきましたが、実際は少し違っていて、 i がプレフィックス長以下の場合は P’θ[i,:] を MLPθ に通した値が使われていますね。 ここで P’θ[i,:] は Pθ[i,:] に対し行数が同じで列数が小さい行列です。 論文によると Pθ[i,:] をそのまま学習パラメータとすると学習が不安定になったので、こうしているとのこと。

学習が完了すれば MLPθ、P’θ[i,:] は不要なので Pθ[i,:] だけ保存しておいて、 推論時に連結してあげれば OK です。

では、ここからは実際に Prefix Tuning を動かしてみましょう。

3. Prefix Tuning の実行

ここからは実際にコードを動かしながら Prefix Tuning をしてみます。 今回も Colab で動かす想定でコードスニペットを入れていくので、 新たにノートブックを開き、アクセラレータは GPU を選んでおいて下さい。

Prefix Tuning のコードですが、りんなさんが GPT と GPT-NeoX に対応した prefix-tuning-gpt5 を公開して下さっているので、 それをありがたく使わせて頂きます。

とりあえず、git clone して、

!git clone https://github.com/rinnakk/prefix-tuning-gpt

依存関係をインストールです。

!cd prefix-tuning-gpt && pip install -r requirements.txt

サンプルのデータは jsonl 形式で以下のようなフォーマット。

!head -3 prefix-tuning-gpt/data/sample_data.jsonl
# {"text": "午後から雨が心配だったので遠出はせず、『ふれあいロード』を走って来ました!\n確実に春が近づいてることを...らしい。"}
# {"text": "銀時にとってその十年は、長いようで短かったように思う。過ぎてみれば、の話なのだが。\n道を分かつまでの...気付いたんだから仕方ない。"}
# {"text": "自分1人ではどうしようもならないのが、借金返済・多重債務の問題です。\n潟上市の人も、まずインターネット...紹介しています。"}

prefix-tuning-gpt のサンプルコードを見ると以下のようになっています。 入力されたテキストの行末に😃を入れるように学習させるみたいですね。

#https://github.com/rinnakk/prefix-tuning-gpt/blob/bd6027bf206ddc439b6b542cc7ff094ccfeb7d29/src/prefix_tuning_example.py#L45-L65
45|def load_data_from_filepath(filepath, tokenizer, max_seq_len=64):
46|    with open(filepath, encoding="utf-8") as f:
47|        data = []
48|        for line in f:
49|            sample = json.loads(line.strip())
50|            _text = sample["text"]
51|            
52|            _sents = _text.split("\n")
53|
54|            tmp_text = ""
55|            for sent in _sents:
56|                new_tmp_text = f"{tmp_text}\n{sent} 😃"
57|                new_tmp_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(new_tmp_text))
58|                if len(new_tmp_token_ids) > max_seq_len:
59|                    tmp_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(tmp_text))
60|                    if len(tmp_token_ids) > 0:
61|                        data.append(([], tmp_token_ids))
62|                    tmp_text = f"{sent} 😃"
63|                else:
64|                    tmp_text = new_tmp_text
65|        return data

以下のようにして学習を開始します。今回はサンプルをそのまま動かすだけなのでラクですね。20分ちょっとで終わりました。

!cd  prefix-tuning-gpt/src && \
deepspeed --include localhost:0 --module prefix_tuning_example \
    --model_type gpt-neox \
    --pretrained_model_dir rinna/japanese-gpt-neox-small \
    --data_filepath ../data/sample_data.jsonl \
    --train_data_size 1000 \
    --dev_data_size 10 \
    --batch_size 4 \
    --max_lr 0.0001 \
    --deepspeed \
    --deepspeed_config ./ds_config.json \
    --save_name smileface \
    --save_model

# [2023-07-25 02:24:01,413] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
# [2023-07-25 02:24:04,033] [WARNING] [runner.py:196:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
# [2023-07-25 02:24:04,048] [INFO] [runner.py:555:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --module --enable_each_rank_log=None prefix_tuning_example --model_type gpt-neox --pretrained_model_dir rinna/japanese-gpt-neox-small --data_filepath ../data/sample_data.jsonl --train_data_size 1000 --dev_data_size 10 --batch_size 4 --max_lr 0.0001 --deepspeed --deepspeed_config ./ds_config.json --save_name smileface --save_model
# [2023-07-25 02:24:06,714] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
# ...
<Dev> - Epoch 49 - 1217.455s - monitor: 3.5868, ppl: 36.297
checkpoint saved to data/model/smileface_prefix_encoder.seed42.2023-07-25-02-25-46.checkpoint
[2023-07-25 02:46:06,543] [INFO] [launch.py:347:main] Process 1561 exits successfully.

少し脱線

私は最初「ds_config.json の内容はどうしたら良いの?」と思いましが、 ソースを読んでみると DeepSpeed の設定はコマンドライン引数を元に内部で形成し、 その内容が ./ds_config.json に書き出されるという仕様になってました。

この prefix-tuning-gpt ですが、誤解を恐れず言えば「大事なところ以外はかなり大雑把」に作ってあります。 なので、「あれ?」と思うところはあるのですが、そこは「適当に直して使ってね」というノリなんだと思います (スクリプトの名前からして prefix_tuning_example ですし)。

とは言え、コード自体は DeepSpeed を使って書かれていて、ちょっと修正すれば rinna/japanese-gpt-neox-3.6b-instruction-ppo でも普通に使えたので、大事なところはちゃんとしてました。

話を元に戻します

学習が終わると prefix weight が出来てます。前章の Pθ ですね。

!ls prefix-tuning-gpt/data/model
# smileface_prefix_encoder.seed42.2023-07-25-02-25-46.best.checkpoint
# smileface_prefix_encoder.seed42.2023-07-25-02-25-46.checkpoint

以下のようにして推論を動かします。"Prompt: (Use Ctrl+D or Ctrl+C to exit)“ の後に何かしらテキストを入力してみて下さい。

FILENAME="smileface_prefix_encoder.seed42.2023-07-25-02-25-46"
!cd prefix-tuning-gpt/src && CUDA_VISIBLE_DEVICES=0 python -m prefix_inference \
    --model_type gpt-neox \
    --pretrained_model_dir rinna/japanese-gpt-neox-small \
    --prefix_checkpoint_path ../data/model/{FILENAME}.best.checkpoint

# Building model...
# You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
# Building prefix-tuning model...
# Prompt: (Use Ctrl+D or Ctrl+C to exit)YOASOBIのアニメ主題歌、作品の世界観が投影されてていいですね。
# 2023-07-25 02:59:47.640026: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
# YOASOBIのアニメ主題歌、作品の世界観が投影されてていいですね。😃 音楽が好きな人、アニメが好きでアニメについて知りたいひとに読んで欲しい。 😃
# Prompt: (Use Ctrl+D or Ctrl+C to exit)

入力したテキストの続きが生成されるわけですが、ちゃんと「😃」が付与されてるのが確認できました。

prefix weight なしで推論すると以下のようになり、事前学習済みモデルが元のままであることがわかります。

!cd prefix-tuning-gpt/src && CUDA_VISIBLE_DEVICES=0 python -m prefix_inference \
    --model_type gpt-neox \
    --pretrained_model_dir rinna/japanese-gpt-neox-small
# Building model...
# You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
# Prompt: (Use Ctrl+D or Ctrl+C to exit)ただ子供が口ずさむにはキビシイね。ペガサス幻想みたいなのにしてあげて欲しい。
# 2023-07-25 03:04:36.320015: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
# ただ子供が口ずさむにはキビシイね。ペガサス幻想みたいなのにしてあげて欲しい。>そして、この掲示板で、このスレでの書き込みには、何と書き込むかはあなたの自由で結構です。 投稿したらその書き込み主が反応しなくなるのね。 そうするとレスの重複を削除しなきゃいけないから注意喚起のつもりだけど 投稿する前に書き込み内容とあなたのレスが合っているか確認しよう。 まあ、あなたにも「このレスはここでの書き込みには必要ない」という ルールがあるから「そこでは必要ない」というのはおかしいけど、 それ以前の問題だよね。 削除人さんにはまず投稿者のプロフィールを確認するべき。 そしてそこで、どういう事を書くか、何を削除したいか、 どういう理由で
# Prompt: (Use Ctrl+D or Ctrl+C to exit)

さて、ここからは prefix weight を FasterTransformer 向けに変換する手順を確認していきましょう。

4. prefix weight の FasterTransformer 変換

ここからは学習済みの prefix weight を Triton + FasterTransformer Backend の環境にデプロイできるように変換していきます。

公式ドキュメント6に手順の記載があるのですが動かなかったので手順を修正しています。 変換用のコードはここ7にある huggingface_jp_gptneox_convert.py なのですが、config.json の想定パラメータ名が GPT-NeoX ではなく GPT のものになってるみたいです。 それ以外は問題なさそうなので、少し改修して使いました。

まず FasterTransformer と fastertransformer_backend を取ってきます。

!git clone https://github.com/triton-inference-server/fastertransformer_backend
!git clone https://github.com/NVIDIA/FasterTransformer

変換に必要なので事前学習済みモデルの重みも取得します。

!git lfs clone https://huggingface.co/rinna/japanese-gpt-neox-small

transformers も必要ですね。

!pip install transformers

以下のようにして事前学習済みモデルのみを FasterTransformer の形式に変換します (使うのは "jp” のついていない素の huggingface_gptneox_convert.py なので注意して下さい)。

!python ./FasterTransformer/examples/pytorch/gptneox/utils/huggingface_gptneox_convert.py \
  -in_file ./japanese-gpt-neox-small \
  -saved_dir ./fastertransformer/\
  -model_name gptneox \
  -infer_gpu_num 1 \
  -weight_data_type fp16

# =============== Argument ===============
# saved_dir: ./fastertransformer/
# in_file: ./japanese-gpt-neox-small
# infer_gpu_num: 1
# processes: 4
# weight_data_type: fp16
# model_name: gptneox
# ========================================  

出力先には沢山の重みファイルと設定ファイルが生成されています。

!ls ./fastertransformer/1-gpu | tail -5

# model.layers.9.mlp.dense_h_to_4h.weight.0.bin
# model.layers.9.post_attention_layernorm.bias.bin
# model.layers.9.post_attention_layernorm.weight.bin
# model.lm_head.weight.bin
# model.wte.bin

設定ファイルの中身は以下のとおりです。

!cat ./fastertransformer/1-gpu/config.ini

# [gptneox]
# model_name = gptneox
# head_num = 12
# size_per_head = 64
# inter_size = 3072
# num_layer = 12
# rotary_embedding = 64
# vocab_size = 44416
# start_id = 2
# end_id = 3
# use_gptj_residual = 0
# weight_data_type = fp16

prefix weight の追加変換関数です。紙面上の収まりが良いようにインデントを詰めてますが、ほとんどそのままだったと思います。

# Copied and modified from
# https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gptneox/utils/huggingface_jp_gptneox_convert.py
import configparser
import os
from pathlib import Path
import numpy as np
from transformers import GPTNeoXForCausalLM
import torch

def get_weight_data_type(data_type):
  if data_type == "fp32":
    return np.float32
  elif data_type == "fp16":
    return np.float16
  else:
    assert False, f"Invalid weight data type {data_type}"

def prefix_prompt_convert(args, config, weight_data_type):
  saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
  prompt_in_file_list = args.prompt_in_file_list.split(',')

  task_list = []
  for idx, prompt_in_file in enumerate(prompt_in_file_list):
    weights=torch.load(prompt_in_file)
    task_name = prompt_in_file.split("/")[-1].split("_")[0]

    total_size = weights.nelement()
    n_layers = config['num_hidden_layers']
    n_head = config['num_attention_heads']
    size_per_head = config['hidden_size'] // n_head
    prefix_prompt_len = total_size // (2 * n_layers * n_head * size_per_head)

    task_list.append((task_name, prefix_prompt_len))
    # GPT NeoX
    weights=weights.view(prefix_prompt_len,n_layers,2,n_head,size_per_head) ## prefix_seq_len, num_layers, 2, num_heads, size_per_head
    # weights=weights.view(prefix_prompt_len,28,2,16,256) ## prefix_seq_len, num_layers, 2, num_heads, size_per_head
    weights=weights.permute(1,2,3,0,4) ## num_layers, 2, num_heads, perfix_seq_len, size_per_head
    local_head_num = n_head // args.infer_gpu_num
    weights_split = torch.split(weights, local_head_num, dim=2)
    for i in range(args.infer_gpu_num):
      output_file_path = saved_dir + "/model.prefix_prompt." + task_name + ".weight." + str(i) + ".bin"
      weights_split[i].detach().cpu().numpy().astype(weight_data_type).tofile(output_file_path)

  return task_list

def convert_prompt_weight(args):
  config = configparser.ConfigParser()
  saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
  config_ini_path = saved_dir + "config.ini"
  config.read(config_ini_path)

  model = GPTNeoXForCausalLM.from_pretrained(args.in_file)
  hf_config = vars(model.config)
  np_weight_data_type = get_weight_data_type(args.weight_data_type)

  task_list = []
  if args.prompt_in_file_list is not None:
    task_list = prefix_prompt_convert(args, hf_config, np_weight_data_type)
  if len(task_list) > 0:
    config['gptneox']['num_tasks'] = str(len(task_list))
    config['gptneox']['prompt_learning_type'] = str(2)
    for idx, (task_name, prompt_length) in enumerate(task_list):
      config[f'task_{idx}'] = {}
      config[f'task_{idx}']['task_name'] = task_name
      config[f'task_{idx}']['prompt_length'] = str(prompt_length)
    with open(config_ini_path, 'w') as configfile:
      config.write(configfile)

さて、今回はお手軽に huggingface_jp_gptneox_convert.py の関数をコピー&修正して Colab でそのまま動かすので、ダミーの args を用意します。 prefix weight が複数ある時は taskA_prefix_encoder.*.checkpoint,taskB_prefix_encoder.*.checkpoint のようにカンマ区切りで記述します。

MODEL_PATH = "./japanese-gpt-neox-small"
PROMPT_WEIGHT = "./prefix-tuning-gpt/data/model/smileface_prefix_encoder.seed42.2023-07-25-02-25-46.best.checkpoint"
SAVED_DIR = "./fastertransformer"

from dataclasses import dataclass
@dataclass
class Args:
  in_file: str
  prompt_in_file_list: str
  weight_data_type: str
  saved_dir: str
  infer_gpu_num: int

 args = Args(MODEL_PATH, PROMPT_WEIGHT, "fp16", SAVED_DIR, 1)
 args
 # Args(in_file='./japanese-gpt-neox-small', prompt_in_file_list='./prefix-tuning-gpt/data/model/smileface_prefix_encoder.seed42.2023-07-25-02-25-46.best.checkpoint', weight_data_type='fp16', saved_dir='./fastertransformer', infer_gpu_num=1)

prefix weight の変換処理を実行します。

convert_prompt_weight(args)

変換後のモデルに prefix weight が追加されてますね。

!ls ./fastertransformer/1-gpu/ | tail -5

# model.layers.9.post_attention_layernorm.bias.bin
# model.layers.9.post_attention_layernorm.weight.bin
# model.lm_head.weight.bin
# model.prefix_prompt.smileface.weight.0.bin
# model.wte.bin

設定ファイルには追加したタスクが登録されています。

!cat ./fastertransformer/1-gpu/config.ini

# [gptneox]
# model_name = gptneox
# head_num = 12
# size_per_head = 64
# inter_size = 3072
# num_layer = 12
# rotary_embedding = 64
# vocab_size = 44416
# start_id = 2
# end_id = 3
# use_gptj_residual = 0
# weight_data_type = fp16
# num_tasks = 1
# prompt_learning_type = 2
#
# [task_0]
# task_name = smileface
# prompt_length = 10

Colab なのでテスト実行することはできませんが、pseudo コードを書くとこんな感じです。 task_id の値は config.ini の内容に合わせて下さい。

task_id = 0

import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)

def prepare_tensor(name, input, protocol):
    client_util = httpclient if protocol == "http" else grpcclient
    t = client_util.InferInput(
        name, input.shape, np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t

model_name = ...
model_version = ...
input_ids = ...
...

task_name_ids = np.zeros([input_ids.shape[0], 1]).astype(np.uint32) + task_id

inputs = [
  prepare_tensor("input_ids", input_ids, "grpc"),
  ...
  prepare_tensor("prompt_learning_task_name_ids", task_name_ids, "grpc")
]

result = triton_client.infer(model_name, inputs, model_version=model_version)

ただ small サイズで Prefix Tuning してもつまらないので、最後にオマケを付けておきます。

5. rinna/japanese-gpt-neox-3.6b-instruction-ppo の場合

prefix-tuning-gpt そのままだと少々都合が悪かったので以下の修正を適用しました。

  • 学習データのロード処理を改変。1行1サンプルのテキストデータを想定。
  • –zero_stage を追加。ZeRO の設定は DeepSpeed-Chat から流用しました。
  • –offload を追加。ZeRO Offload を有効にします。こちらも DeepSpeed-Chat から流用。
  • –fp16 を使用した場合、最初のロード時に直接 float16 でロードするようにしました(CUDA が OOM でこけたので)
%%writefile ./diff
diff --git a/src/prefix_tuning_example.py b/src/prefix_tuning_example.py
index d4823ec..55f433b 100644
--- a/src/prefix_tuning_example.py
+++ b/src/prefix_tuning_example.py
@@ -33,7 +33,7 @@ from torch.utils.data import RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from transformers import T5Tokenizer
 import deepspeed
-from deepspeed.ops.adam import FusedAdam as Adam
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

 from data_source import DataSource, collate_fn
 from model.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@@ -42,27 +42,15 @@ from model.prompt.prefix_tuning import PrefixEncoder, PrefixWrapper, init_prefix
 from util import StatisticsReporter, get_linear_schedule_with_warmup, print_rank_0, is_rank_0, count_parameters


-def load_data_from_filepath(filepath, tokenizer, max_seq_len=64):
-    with open(filepath, encoding="utf-8") as f:
-        data = []
-        for line in f:
-            sample = json.loads(line.strip())
-            _text = sample["text"]
-
-            _sents = _text.split("\n")
-
-            tmp_text = ""
-            for sent in _sents:
-                new_tmp_text = f"{tmp_text}\n{sent} 😃"
-                new_tmp_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(new_tmp_text))
-                if len(new_tmp_token_ids) > max_seq_len:
-                    tmp_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(tmp_text))
-                    if len(tmp_token_ids) > 0:
-                        data.append(([], tmp_token_ids))
-                    tmp_text = f"{sent} 😃"
-                else:
-                    tmp_text = new_tmp_text
-        return data
+def load_data_from_filepath(filepath, tokenizer):
+  with open(filepath, encoding="utf-8") as f:
+    data = []
+    for line in f:
+      sample = json.loads(line.strip())
+      token_ids = tokenizer(sample["text"]).input_ids
+      if len(token_ids) > 0:
+        data.append(([], token_ids))
+  return data


 def forward_step(model, prefix_encoder, tokenizer, batch_data):
@@ -199,13 +187,17 @@ def training(local_rank, config):

     # build model
     print_rank_0("Building model...")
+    if config.fp16:
+      torch_dtype = torch.float16
+    else :
+      torch_dtype = "auto"
     if config.model_type == "gpt":
-        base_model = GPT2LMHeadModel.from_pretrained(config.pretrained_model_dir)
+        base_model = GPT2LMHeadModel.from_pretrained(config.pretrained_model_dir, torch_dtype=torch_dtype)
         base_model_n_embd = base_model.config.n_embd
         base_model_n_layer = base_model.config.n_layer
         base_model_n_head = base_model.config.n_head
     elif config.model_type == "gpt-neox":
-        base_model = GPTNeoXForCausalLM.from_pretrained(config.pretrained_model_dir)
+        base_model = GPTNeoXForCausalLM.from_pretrained(config.pretrained_model_dir, torch_dtype=torch_dtype)
         base_model_n_embd = base_model.config.hidden_size
         base_model_n_layer = base_model.config.num_hidden_layers
         base_model_n_head = base_model.config.num_attention_heads
@@ -293,6 +285,7 @@ def training(local_rank, config):
         }
     ]
     if config.deepspeed:
+        Adam = DeepSpeedCPUAdam if config.offload else FusedAdam
         optimizer = Adam(
             optimizer_grouped_parameters,
             lr=config.max_lr,
@@ -544,6 +537,8 @@ if __name__ == "__main__":
     # deepspeed
     parser.add_argument("--fp16", action="store_true", help="use fp16 for deepspeed")
     parser.add_argument("--bf16", action="store_true", help="use bf16 for deepspeed")
+    parser.add_argument("--zero_stage", type=int, default=0, help="ZeRO optimization stage")
+    parser.add_argument("--offload", action="store_true", help="use ZeRO offload")
     parser = deepspeed.add_config_arguments(parser)

     config = parser.parse_args()
@@ -553,18 +548,6 @@ if __name__ == "__main__":
     deepspeed_config = {
         "train_micro_batch_size_per_gpu": config.batch_size,
         "gradient_accumulation_steps": config.n_accum_steps,
-        "optimizer": {
-            "type": "Adam",
-            "params": {
-                "lr": config.max_lr,
-                "betas": [
-                    config.adam_beta1,
-                    config.adam_beta2
-                ],
-                "eps": config.adam_eps,
-                "weight_decay": config.weight_decay
-            }
-        },
         "scheduler": {
             "type": "WarmupDecayLR",
             "params": {
@@ -591,6 +574,22 @@ if __name__ == "__main__":
         },
         "steps_per_print": config.log_every_n_steps
     }
+    device = "cpu" if config.offload else "none"
+    if config.zero_stage >= 2:
+      deepspeed_config["zero_optimization"] = {
+        "stage": config.zero_stage,
+        "offload_optimizer": {
+            "device": device,
+        },
+        "offload_param": {
+            "device": device,
+        },
+        "stage3_param_persistence_threshold": 1e4,
+        "stage3_max_live_parameters": 3e7,
+        "stage3_prefetch_bucket_size": 3e7,
+        "memory_efficient_linear": False
+      }
+
     if config.deepspeed and config.local_rank == 0:
         if os.path.exists(config.deepspeed_config):
             os.remove(config.deepspeed_config)
@@ -600,4 +599,4 @@ if __name__ == "__main__":
     import logging
     logging.basicConfig(level=logging.WARNING)

-    training(config.local_rank, config)
\ No newline at end of file
+    training(config.local_rank, config)

パッチを適用します。

!git clone https://github.com/rinnakk/prefix-tuning-gpt && \
  cd prefix-tuning-gpt && \
  git checkout bd6027bf206ddc439b6b542cc && \
  git apply ../diff

このコードを使って学習を動かします。

注意

Colab で動かした風の記述にしてますが、さすがに無料の Colab では動かないので、GCE で vCPUx4, Mem 64 GB, Tesla T4 x 2 の VM を作って試しました(vCPU/Mem はもっと削れそうです)。

%%bash
 ( cd ./prefix-tuning-gpt/src && deepspeed --num_gpus=2 \
    --module prefix_tuning_example \
    --model_type gpt-neox \
    --pretrained_model_dir rinna/japanese-gpt-neox-3.6b-instruction-ppo \
    --data_filepath ./path/to/data/file
    --train_data_size 10000 \
    --dev_data_size 100 \
    --batch_size 1 \
    --n_epochs 5 \
    --n_accum_steps 8 \
    --n_warmup_steps 100 \
    --prefix_seq_len 20 \
    --prefix_input_dim 64 \
    --prefix_hidden_dim 64 \
    --fp16 \
    --max_lr 0.0001 \
    --deepspeed \
    --save_name your_task_name \
    --save_model \
    --deepspeed_config ./ds_config \
    --zero_stage 2 \
    --offload)

ZeRO2 と ZeRO Offload の併用で動きました。上記の設定では prefix 長が 20 で 3.6B の約 0.3 % の学習パラメータ数です。 ログを確認するとこんな感じです。

Trainable backbone model parameters: 0
Trainable prefix encoder parameters: 13184320

論文は 0.1 % だったので、やや多めでしょうか。その良し悪しは評価してませんが、今回の手順で FasterTransformer 変換して、 Triton で推論できてます。

6. おわりに

今回は Prefix Tuning と FasterTransformer への適用のご紹介でした。 しかし「今時は OpenAI の API 呼んでモノ作るよねぇ」と思いながら書いてるのでイマイチ気持ちが盛り上がりませんw。 そんな訳なので次回は LangChain と OpenAI API を使ってなにかしら、やって見ようと思います。 ネットに山ほど記事が転がっているネタなので内容的に出来るだけ被らないようにしたいですが、どうなるでしょうか。。。


  1. https://arxiv.org/abs/2305.11206 1000件を丁寧に作って普通に回したら結構イケたよ。みたいな。「あぁ、そうなんですか。。。」ってなりました。 

  2. そっち系も手を出したいと思っているので、検索しても引っかけられない内容が書けそうになったら、なんか書くかも。 

  3. https://arxiv.org/abs/2101.00190 2年前の論文ですけど、えらい昔に感じられますね。。。 

  4. FasterTransformer を好んで使っていますが、CTranslate2 の方が高速みたいですね( https://github.com/OpenNMT/CTranslate2#gpu )。今回のような話の対応状況みて、CTranslate2 も考えたいです。。。けど、結局 OpenAI の API 使うことになるか。。。 

  5. https://github.com/rinnakk/prefix-tuning-gpt 

  6. https://github.com/triton-inference-server/fastertransformer_backend/blob/main/docs/gptneox_guide.md#run-gpt-neox-with-prompt-tuning 

  7. https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gptneox/utils/huggingface_jp_gptneox_convert.py