オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第23回 BRIO による抽象型要約の検証
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2022年10月27日

今回は BRIO を使って抽象型要約に挑戦してみようと思います。 BRIO は特定のモデルに依存しない手法で、論文では PEGASUS と BART を使って実験をしています。今回は T5 を使って BRIO を試し、素の T5 と比較してみましょう。

1. はじめに

今回は BRIO1 を使って抽象型要約に挑戦してみようと思います。

そういえば、この連載で要約モデルを扱うのは初めてですね。要約には大きく分けて抽出型と抽象型の二種類があります。 簡単に説明すると、抽出型は長い文章全体から重要そうな文をつまみ食いして短くする手法、 抽象型は Seq2seq モデルで文章全体の意味をとらえた短い要約文を生成する手法です。

過去に抽出型要約は LexRank2 や劣モジュラ関数を使った手法3などを試したことがある(この連載の記事にはしてません)のですが、抽象型要約は試したことがなく、やってみたいと思っていました。

抽象型要約というのは結局、「文章を入力して、それを要約した文章を出力する」だけですから、データセットさえあれば T54 等でテキスト変換すればとりあえず出来てしまいます。「それでは新鮮味がないし。。。」と思っていたところで BRIO を見つけました。

BRIO は特定のモデルに依存しない手法で、論文では PEGASUS と BART を使って実験をしています。日本語の事前学習モデルの都合もあるので、今回は T5 を使って BRIO を試し、素の T5 と比較する形でやってみます。

BRIO に限らず要約モデルの評価指標には ROUGE スコアが用いられます。BRIO では評価だけでなく学習そのものにおいても重要な要素になるので、 まず、ROUGE についてご説明します。

2. ROUGE スコア

ROUGE(Recall-Oriented Understudy for Gisting Evaluation)5 は機械翻訳の評価指標である BLEU6 を参考にして考案されました。基本的な考え方としては「機械で生成した文章(要約候補)がお手本の文章(参照要約)にどの程度合致するか?」を測定します。この辺りは BLEU と同じですね。

ROUGE スコアにもいろいろと計算方法があるのですが、一番基本的なのが BLEU と同様に 文章中の N-Gram の合致を測る ROUGE-N です。単語一つの合致を評価するのが ROUGE-1、連続する 2 単語なら ROUGE-2 という具合です。

ROUGE-N

ROUGEスコアの論文5によると ROUGE-N として以下の計算式が記載されています。

rouge_socre

私はこの数式を見るたびに Count と Countmatch の意味合いで悩むのですが、結局のところは以下のようになります。

  • 分母は、「参照要約に含まれる全ての N-Gram」の出現回数の合計
  • 分子は、「参照要約に含まれる全ての N-Gram」について、「要約候補での出現回数」と「参照要約での出現回数」の最小値の合計7

つまり、参照要約に含まれる N-Gram が要約候補でどの程度拾われたかという Recall ですね。

なのですが、実際のところは上記の要領で Precision も計算して F1 スコアを算出するというのが一般的な使い方のようです。 今回紹介する BRIO でも F1 でスコアが記載されていますし、ROUGE スコアを算出するライブラリも多くが F1 スコアを返す(あるいはデフォルトになっている) ようです8

次に BRIO の論文では ROUGE-1, 2 の他に ROUGE-L のスコアも記載されているので、 そちらも確認してみましょう。

ROUGE-L

ROUGE-L は参照要約と要約候補の間の LCS(Longest Common Subsequence) を評価します。

LCS は文字通りで「二つの系列に共通する Subsequence で最も長いもの」なのですが、Subsequence の考え方には注意が必要です。元の系列が X = {x1, x2, x3, x4, x5} だとすると、

  • x1, x2, x3
  • x1, x3, x5

のようなもので、必ずしも連続した部分系列である必要はありません。

「LCS が長ければ二つの要約文は似ているだろう」というのが ROUGE-L の考え方になります。

また、ROUGE-L には Sentence-Level LCS と Summary-Level LCS の2種類があり、それぞれROUGE スコアの論文5の 3.1 節、 3.2 節に記載されています。 ROUGE スコアの計算を実装したライブラリでは、前者が rougeL 、後者が rougeLsum と呼ばれたりします。

BRIO では後者の Summary-Level LCS が用いられており、計算式は以下のようになります。

rouge_lsum

数式中の記号の意味合いは以下のとおりです。

  • 参照要約 R : u 個の文 {r1, r2,…,ri,…ru} で構成される。
  • 要約候補 C : k 個の文 {c1, c2,…,cj,…ck} で構成される。
  • m : 参照要約 R の単語数。
  • n : 要約候補 C の単語数。
  • LCS(ri, C) : ri と cj (j=1..k) の LCS(Longest Common Sequence)の和集合

うーん、LCS(ri, C) が分かりにくいですね。。。。 ここは論文5でも補足がありました。

具体例で説明すると以下のようになります。

R = {r1}, C = {c1, c2} があったとして。

  • r1 = w1, w2, w3, w4, w5
  • c1 = w1, w2, w6, w7, w8
  • c2 = w1, w3, w8, w9, w5

だとすると、

  • r1 と c1 の LCS : w1, w2
  • r1 と c2 の LCS : w1, w3, w5

になります。

  • 上記二つのLCSの和集合 : w1, w2, w3, w5

ですから、

  • RLCS = 4 / 5
  • PLCS = 4 / 10

になるわけです。βですが BRIO が利用する実装9を確認すると 1 でした。なので FLCS は普通の F1 スコアですね。

今回利用している実装では二重カウントの除外等もう少し複雑なことをやるのですが説明は省略します。気になる人はソースコードをチェックしてみて下さい9

それでは、評価指標の話はこれぐらいにして、本題である BRIO の説明に移ります。

3. BRIO

BRIO の話の前に T5 等の Seq2Seq モデルで抽象型要約をする場面を考えてみましょう。 モデルへの入力は要約対象の文章で、ラベルはその要約文(以降、参照要約と記述)です。これを最尤推定(MLE)でモデルが参照要約を出力する確率を最大化します。

seq2seq_generation

この連載でよく利用している T5 の学習は基本的にこのパターンですね。 推論時はデコーダーをシーケンスの位置 t-1 の出力を t の入力にする自己回帰で回し、 1 トークンづつ推論していきます。

ただ、自己回帰で要約を生成する途中、間違ったトークンを出力することもある訳で、その場合は後続トークンは間違いを前提にして出力されてしまいます。 この振る舞いは exposure bias10 と呼ばれ、モデルの性能に悪影響を及ぼす原因になります。

この推論途中で間違いを犯した場合にも性能を維持する為の対策として、モデルに要約文の品質を相対的に評価する能力を与えようというのが BRIO になります。

論文1では BART で構築した Seq2seq ベースの要約モデルがどの程度「要約文の品質を相対的に評価する能力」を有しているか、予備実験で確認しています。

preliminary_experiment

“BART” が Seq2seq モデルを MLE で最適化したもの、"Ours" が BRIO です。"High" は “BART” で入力文に対して複数生成した要約文のうち、最も品質が良かったもの、"Low" はその逆。品質は ROUGE-½/L での評価です。"Acc.“ はモデルが(ROUGE-½/L的に)高い品質の生成要約文に、より高い確率を割り当てた比率です。

つまり、Seq2seq を MLE で最適化したモデル("BART”) では 「生成した要約文の ROUGE-½/L による評価」と「モデルがその要約文を生成する確率」の高低が一致する割合は 54.8 % に過ぎなかったということです(そして BRIO ではそれが 79.63% まで改善しています)。

これは MLE による学習でラベル(参照要約)にだけ高い確率を割り当て、それ以外は全部アウトという学習をしていることが原因です。 つまり「大谷がベーブルース以来の二桁本塁打かつ二桁勝利の偉業を達成」というお手本に対し、「大谷104年ぶりの偉業達成。二桁本塁打&二桁勝利」という要約も、「大谷の首振り人形!出来栄えについてSNSでは微妙な評価。」という要約も同じ扱いになるわけですね。確かにこれは改善の余地がありそうです。

そこで BRIO ではモデルに対して以下の二つの役割を与えています。

  • 生成モデルとして、要約対象の文章を参照要約に変換するように学習する。
  • 評価モデルとして、要約対象の文章に対する要約候補に対し、ROUGE-½/Lの評価値が高い場合に、より高い生成確率を割り当てるように学習する。

この生成モデルと評価モデルのマルチタスク学習により、BRIO は自己回帰での要約生成時に exposure bias に晒されても、ROUGE-½/L での評価値が高くなりそうな候補を優先して出力できるようになっています。

次にそれぞれのタスクの損失関数を見ていきましょう。

生成モデルの損失関数

まず生成モデルです。本章冒頭の図のとおりですね。入力文から要約文を生成するモデルを g 、モデルのパラメータを θ 、学習データを D 、参照要約 S の長さを l 、各トークンを {s1, s2, …, sl} とすると、ある学習サンプルに対する損失は以下のようになります。

cross_entropy_loss

最初のΣは参照要約の先頭から最後までの足し合わせ、二つ目のΣは少し分かりにくいかもしれませんが、参照要約の各トークンの位置で語彙集合をグルっと回る形ですね。 ptrue は普通なら s がラベルと合致したら 1 そうでなければ 0 なのですが、ラベルスムージングと呼ばれる手法を使うと以下のようになります。

label_smoothing

ここで N は語彙数です。要するに 正解なら 1 そうでなければ 0 だったところを 1 から β だけ切り分けて外れの語彙に均等に分け与えただけですね。 公開されているコード11を確認すると BRIO では β = 0.1 でラベルスムージングを適用していました。

さて、生成モデルはこれぐらいにして、評価モデルを見ていきましょう。

評価モデルの損失関数

評価モデルは以下のようなイメージで、生成済みの要約候補とその評価値(下図の M(・) )から損失を算出します。

seq2seq_evaluation

損失は contrastive loss になっていて、以下のとおりです。

contrastive_loss

  • Si, Sj : 二つの異なる要約候補です。
  • S : 参照要約です。
  • ROUGE(Si, S) > ROUGE(Sj, S) ∀i, j, i<j
    • 要するに 「i と j は ROUGE スコア12の良い順に並べたときの順位を示します」ということですね。順位の数字が小さければ ROUGE スコアは大きくなります。
    • ROUGE スコアが図中の M(・) に相当します。実際の学習では事前になんらかの要約モデルにより要約候補を生成、その ROUGE スコアを計算しておきます。それを ROUGE スコアの降順にソートした上で BRIO に投入するので、学習サンプルに含まれる要約候補リストのインデックスが i, j に相当する形になります。
  • λij : Si, Siの生成確率の間に確保するマージン。λi,j = (j - i) * λ です。
    • λはハイパーパラメータで i, j は ROUGE スコアの良い順に並べたときの順位です。順位の差が大きければマージンが大きくなる仕組みです。
  • f(S) は長さで正規化された S が生成される対数確率です。α はハイパーパラメータです。

論文1に “contrastive loss” と記載されているので、そう書いてますが、どちらかというと Contrastive Learning だった前回の損失関数よりも、第9回 の Triplet Loss に近い形ですね。絵にすると以下のような形です。

contrastive_loss_image

生成した要約候補を ROUGE スコアの良い順にならべ、順位 i と 順位 j の生成確率の差に順位差 (j-i) で決定されるマージン λij を確保するように学習する訳ですね。

最終的なマルチタスク学習の損失 Lmul は Lxent にハイパーパラメータ γ を掛けた Lctrl を足し合わせるだけです。

multi_task_loss

これなら、参照要約に近い文章を生成しつつ、その生成確率を評価指標による評価の高低と合致させられそうですね。

では、ここから実際にコードを修正しつつ動かしていきましょう。

4. 環境のセットアップ

ここからは実際にコードを修正しながら BRIO の学習を行ってみます。今回も Colab で動かす想定でコードスニペットを入れていくので、 新たにノートブックを開き、アクセラレータは GPU を選んでおいて下さい。

BRIO は特定の Transformer に依存しない学習方法で、公開されているソースコード13では PEGASUS と BART の実装が提供されています。 ですが、日本語の事前学習モデルの入手しやすさを考慮して今回は T5 を Transformer に使うようコードを修正して実験しました。

それでは必要なライブラリ等をインストールしていきましょう。まずは BRIO のコードを取得して、

!git clone https://github.com/yixinL7/BRIO
!cd BRIO && git checkout 135f0e5cc5671fe4faa45ff3e05969969686419a
# ...
# HEAD is now at 135f0e5 Update main.py

依存するライブラリをインストールします。

!cd BRIO && pip install -r requirements.txt

ROUGE スコアの計算で使う compare-mt を取得します。

!git clone https://github.com/neulab/compare-mt.git
!cd compare-mt && git checkout b6d8f79d02043243c3d8aa58373a0f4c55e17a69
# ...
# HEAD is now at b6d8f79 Fix bar chart color cycling (#134)

こちらも依存ライブラリと共にインストールします。

!cd ./compare-mt && pip install -r requirements.txt
!cd ./compare-mt && python setup.py install

MeCab 関係も入れておきます。

!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev
!pip install mecab-python3 fugashi ipadic

文章を文単位に分割する必要があったので、そこは GiNZA を使いました。

!pip install ginza ja_ginza==5.1.0

ここでランタイムを再起動しておいて下さい。

5. データセットの取得

この章と次の章は Colab で動かした風に記述していますが、GPU は不要です。時間もかかる処理なのでお手元の端末で動かすのが良いかもしれません。

学習するには要約の日本語データセットが必要なわけですが、今回は3行要約データセット14を使いました。

!git clone https://github.com/KodairaTomonori/ThreeLineSummaryDataset
!ls ThreeLineSummaryDataset/data
# develop.csv  test.csv  train.csv

このデータセットに含まれるのは LivedoorNews の記事 ID で記事本文とその要約はクローリングして収集する必要があります。 今回はこちらの記事15のコードをお借りして少し微修正して使わせて頂きました。

収集した記事は GCS に保存するので認証を通します。

from google.colab import auth
auth.authenticate_user()

以下は記事 ID からコンテンツを取得する関数です。

  • 追記 :
    最後の最後に気付いたのですが、この関数に少々問題があるようです。13. 要約の生成に少し書いておいたので参照して下さい。
from urllib.request import urlopen
from bs4 import BeautifulSoup
from bs4.element import NavigableString
from pprint import pprint
import time

INTERVAL = 10

def get_content(id):
    time.sleep(INTERVAL)
    URL = 'https://news.livedoor.com/article/detail/'+id+'/'
    try:
        with urlopen(URL) as res:
            output1 = ''
            html = res.read().decode('euc_jp', 'ignore')
            soup = BeautifulSoup(html, 'html.parser')
            lineList = soup.select('.articleBody p')
            for line in lineList:
                if len(line.contents) > 0 and type(line.contents[0]) == NavigableString:
                    output1 += line.contents[0].strip()
            if output1 == '': 
                return
            output1 += '\n'

            output0 = ''
            summaryList = soup.select('.summaryList li')
            for summary in summaryList:
                output0 += summary.contents[0].strip()+'\t'
            if output0 == '':
                return

            return (output0+output1)
    except Exception:
        print('Exception')

上記の関数を使ってクローリングする関数です。

import math
BUCKET = "somewhere"
def crawl(file_name, output_format="output_{}_{}.tsv", chunk_size=100, resume_from = 0, stop_at=None):
  ids = []
  with open(file_name, mode='r') as f:
    lines = f.readlines()
    for line in lines:
      id = line.strip().split(',')[3].split('.')[0]
      ids.append(id)
  num_contents = len(ids)
  print("num_contents : {}".format(num_contents))
  num_chunks = math.ceil(num_contents / chunk_size)
  print("num_chunks : {}".format(num_chunks))
  if resume_from > 0:
    print("Resuming from {}...".format(resume_from))
  heads = [chunk_size * i for i in range(0, num_chunks)]
  for head in heads:
    tail = head + chunk_size
    if tail > num_contents:
      tail = num_contents

    if tail <= resume_from:
      continue
    if stop_at and head >= stop_at:
      return

    output_file = output_format.format(head, tail) 
    print("Start crawling for {}...".format(output_file))
    with open(output_file, "a") as f:
      for idx in range(head, tail):
        if idx < resume_from:
          continue
        if stop_at and idx >= stop_at:
          return
        if idx % 30 == 0:
          print("  Processing {}".format(idx))
        output = get_content(ids[idx])
        if output:
          f.writelines(str(idx) + "\t" + output)
        else:
          print("    Can't get content for idx:{}".format(idx))
    print("Coping {} to gs://{}/brio/data/ ...".format(output_file, BUCKET))     
    os.system("gsutil cp {} gs://{}/brio/data/".format(output_file, BUCKET))

以下のようにして実行します。

crawl('ThreeLineSummaryDataset/data/train.csv', output_format="train_{}_{}.tsv", chunk_size=350, resume_from = 1050, stop_at=None)

同様にして、develop.csv, test.csv のデータも収集します。

6. raw データの形式に加工

ここからはクローリングで作った TSV を BRIO のフォーマットに加工していきます。

BRIO が使用するファイルは各 split (train/val/test)毎に以下のとおりです。 BRIO では {要約対象の文章、参照要約、事前に生成した複数の要約候補} で 1 サンプルになります。() 内は要約対象の件数を M 、生成する要約候補の数を N とした場合の件数になります。

  • ${split}.source : 平文形式の元文章 (M)
  • ${split}.source.tokenized : スペース区切りでトークナイズされた元文章 (M)
  • ${split}.target : 平文形式の参照要約 (M)
  • ${split}.target.tokenized : スペース区切りでトークナイズされた参照要約 (M)
  • ${split}.out : 平文形式の学習済みの要約モデルで生成された要約候補 (M×N)
  • ${split}.out.tokenized : スペース区切りでトークナイズされた要約候補 (M×N)

ですが、BRIO の README.md を見ると、

We use the PTB tokenizer provided by Standford CoreNLP (download here). Please note that tokenized texts are only used for evaluation. To tokenize a file, you may run (using test.source as an example)

とあり、トークナイズされたバージョン(*.tokenized)は検証でのみ使用するようです。 おそらくは他の実験との正確な比較の為かと思うのですが、今回は誰と比べる訳でもないのでトークナイズしたバージョンは作らないことにしました。

それでは、クローリングしたファイルを GCS から取得します、

!mkdir data
!mkdir raw
!gsutil -m cp gs://somewhere/brio/data/*.tsv ./data

クローリングで作る TSV は先頭列に記事 ID を入れるようにしたのですが、一部壊れている行があるようです。 おそらくは取得した文章に改行が入っていたんでしょう。とりあえず、先頭列が ID でない行は “ ” 区切りで前行に連結して補正しました。

import glob
for split in ["train","val", "test"]:
  print("split : {}".format(split))
  rows = []
  if split == "val":
    prefix = "dev"
  else:
    prefix = split 
  file_names = glob.glob("./data/{}_*.tsv".format(prefix))
  for file_name in file_names:
    with open(file_name, "r") as f:
      rows.extend([line.strip().split("\t") for line in f.readlines()])
  if not rows:
    continue

  previous = []
  records = []
  for row in rows:
    if not row[0].isdigit():
      for cell in row:
        previous.append(cell)
    else:
      previous = row
      records.append(row)

  records = [[rec[0], rec[1], rec[2], rec[3], " ".join(rec[4:]) ] for rec in records]
  print("num records:{}".format(len(records)))
  records = [rec for rec in records if len(rec[1]) + len(rec[2]) + len(rec[3]) < len(rec[4])]
  print("num records after length check :{}".format(len(records)))
  src = [rec[4] for rec in records]
  tgt = ["。".join(rec[1:4]) + "。" for rec in records]
  with open("./raw/{}.source".format(split), "w") as f:
    f.write("\n".join(src))
  with open("./raw/{}.target".format(split), "w") as f:
    f.write("\n".join(tgt)) 
# split : train
# num records:101219
# num records after length check :98399
# split : val
# num records:668
# num records after length check :645
# split : test
# num records:685
# num records after length check :655

出来上がったファイルの件数は以下のとおりです。3行要約データセットの CSV は 30万件近くあったのですが、 取得しにいくとすでに消えている記事も存在したようで、だいぶ減ってしまいました。。。

!wc -l raw/*
#     654 raw/test.source
#      654 raw/test.target
#    98398 raw/train.source
#    98398 raw/train.target
#      644 raw/val.source
#      644 raw/val.target
#   199392 total

加工済みのファイルを GCS に保存しておきます。

!gsutil cp -r raw gs://somewhere/brio/

前述のとおり、 BRIO の学習データには事前に生成した要約候補が必要です。今回は要約候補を生成するためのモデルを T5 で作成しました。

7. 初期要約モデルの学習

ここから要約候補生成につかう T5 の学習を行っていきます。 transformers の Trainer を使うので datasets を追加でインストールします。

!pip install datasets

前章で加工したデータを GCS から吸い上げて、

!gsutil cp -r gs://somewhere/brio/raw/ .

JSONL の形式に加工します。

import json
for split in ["train", "val", "test"]:
  with open("raw/{}.source".format(split), "r") as f:
    lines = f.readlines()
    src_lines = [line.strip() for line in lines]
  with open("raw/{}.target".format(split), "r") as f:
    lines = f.readlines()
    tgt_lines = [line.strip() for line in lines]
  examples = [{"source":src, "target":tgt} for src, tgt in zip(src_lines, tgt_lines)]
  with open("{}.jsonl".format(split), "w") as f:
    f.write("\n".join([json.dumps(example, ensure_ascii=False) for example in examples]))

行数はこんな感じになりました。

!wc -l *.jsonl
     654 test.jsonl
    98398 train.jsonl
      644 val.jsonl
    99696 total

データの中身はこんな感じのフォーマットです。文書を全部載せると長いので、"…“ で省略してます。

!head -1 train.jsonl
{"source": "女優の1995年に放送されたマクドナルドのCMで...。", "target": "遠藤久美子が...した。トイレで...という。在学中に...なった。"}

BRIO の検証では compare-mt の RougeScorer を使うので、T5 の学習もそれにならうことにしました。 RougeScorer は内部でトークナイズ処理をする仕様になっています。

import compare_mt.rouge.tokenize as rouge_tokenize
rouge_tokenize.tokenize("I have a pen.", stemmer=False)
# ['i', 'have', 'a', 'pen']

当然のことながら日本語に対応していないので、トークナイザを MeCab に置き換えます。

import os
os.environ["MECABRC"] ="/etc/mecabrc"
import MeCab
mecab = MeCab.Tagger ("-Ochasen")
mecab.parse("")

def parse_by_mecab(sentence, lemma=False):
  tokens = []
  node = mecab.parseToNode(sentence).next
  while node:
    feature = node.feature.split(',')
    token = feature[-3] # 標準形
    if token == '*' or not lemma:
      token = node.surface
    tokens.append(token)
    node = node.next
  return [token for token in tokens if len(token) > 0]

def tokenize(text, stemmer):
  return parse_by_mecab(text, lemma=False)

rouge_tokenize.tokenize = tokenize

確認してみましょう。

import compare_mt.rouge.tokenize as rouge_tokenize
rouge_tokenize.tokenize("私はペンを持っています。", stemmer=False)
# ['私', 'は', 'ペン', 'を', '持っ', 'て', 'い', 'ます', '。']

大丈夫そうですね。それでは学習をしてみましょう。 必要なライブラリをインポートしてトークナイザをロードします。

import torch
from datasets import load_dataset
import transformers
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import numpy as np

SRC_MAX_SEQ_LEN = 1024
TGT_MAX_SEQ_LEN = 72

tokenizer = transformers.T5Tokenizer.from_pretrained('megagonlabs/t5-base-japanese-web', return_dict=False)

先程加工した JSONL 形式のファイルをロードし、トークナイズしておきます。

dataset = load_dataset('json', 
                       data_files={"train": "./train.jsonl", 
                                   "val": "./val.jsonl"})

def tokenize_function(examples):
  sources = tokenizer.batch_encode_plus(examples["source"], max_length=SRC_MAX_SEQ_LEN, pad_to_max_length=True, truncation=True, return_tensors="np")
  targets = tokenizer.batch_encode_plus(examples["target"], max_length=TGT_MAX_SEQ_LEN, pad_to_max_length=True, truncation=True, return_tensors="np")["input_ids"]
  return {"input_ids": sources["input_ids"], "attention_mask": sources["attention_mask"], "labels": targets}

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['source', 'target'])
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

学習のパラメータは以下のようにしました。

training_args = Seq2SeqTrainingArguments(
    "./t5-3line-summalization",
    num_train_epochs = 5,
    evaluation_strategy = "steps",
    weight_decay=0.1,
    adafactor=True,
    learning_rate=1e-3,
    lr_scheduler_type="constant",
    warmup_steps = 10,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size  = 2,
    gradient_accumulation_steps = 32,
    eval_accumulation_steps     = 32,
    eval_steps = 150,
    logging_steps = 150,
    save_steps = 150,
    save_total_limit = 6,
    predict_with_generate = True,
)

検証に使うメトリクス関数です。 BRIO の実装に合わせ、参照要約や要約候補を文単位に区切って改行で連結して RougeScorer に投入しています。 文章を文に区切るのは、考えるのが面倒だったので GiNZA を使用しました。

from compare_mt.rouge.rouge_scorer import RougeScorer
all_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)

import spacy
nlp = spacy.load('ja_ginza')
def process(x):
    return [str(sent) for sent in nlp(x.strip()).sents]

def compute_metrics(eval_prediction):
    label_ids = eval_prediction.label_ids
    pred_ids = eval_prediction.predictions
    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    labels = [process(label) for label in labels]
    preds = [process(pred) for pred in preds]
    sample_rouge1 = 0
    sample_rouge2 = 0
    sample_rougeLsum = 0
    cnt=0
    for pred, label in zip(preds, labels):
      score = all_scorer.score("\n".join(label), "\n".join(pred))
      sample_rouge1 += score["rouge1"].fmeasure
      sample_rouge2 += score["rouge2"].fmeasure
      sample_rougeLsum += score["rougeLsum"].fmeasure
      cnt += 1
    sample_rouge1 = sample_rouge1 / cnt
    sample_rouge2 = sample_rouge2 / cnt
    sample_rougeLsum = sample_rougeLsum / cnt
    sample_rougeAve = (sample_rouge1 + sample_rouge2 + sample_rougeLsum) / 3.0
    return {"rouge1": sample_rouge1, "rouge2": sample_rouge2, "rougeLsum": sample_rougeLsum, "rougeAve":  sample_rougeAve}

事前学習モデルには Megagon Labs さんのモデルを使いました。

model = transformers.T5ForConditionalGeneration.from_pretrained('megagonlabs/t5-base-japanese-web')

チェックポイントが保存されるタイミングで GCS に保存するコールバックです。 先程、 Seq2SeqTrainingArguments に save_total_limit = 6 を設定して、ローカルディスク上のチェックポイントを 6 つまでにしていますが、 このコールバックを仕掛けることで GCS には全てのチェックポイントが保持されることになります。

from transformers import TrainerCallback
class GsSyncCallback(TrainerCallback):
  def on_save(self, args, state, control, **kwargs):
    !gsutil -m rsync -r -P ./t5-3line-summalization gs://somewhere/brio/t5-3line-summalization/

Trainer には transformers の Seq2SeqTrainer をカスタマイズしたクラスを使いました。

num_beams = 4
length_penalty = 2.0 # α
gen_max_len = 80 
gen_min_len = 48

from typing import Any, Dict, List, Optional, Tuple, Union
from torch import nn
from transformers import Seq2SeqTrainer
from transformers.integrations import is_deepspeed_zero3_enabled
class SummarizeSeq2SeqTrainer(Seq2SeqTrainer):
    '''
    Modify gen_kwargs 
    '''

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = {
            "max_length": gen_max_len + 2,
            "min_length": gen_min_len + 1,
            "num_beams": num_beams,
            "no_repeat_ngram_size" : 3,
            "length_penalty" : length_penalty,
            "early_stopping" : True,
            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
        }

        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, generated_tokens, labels)

さて、カスタマイズの理由です。検証時に generate() で要約候補を生成するのですが、このパラメータ設定次第で ROUGE スコアがかなり違ってきます。 BRIO のパラメータ設定に合わせたかったので、継承して prediction_step() をオーバライドしました。

後は Trainer を作って、

trainer = SummarizeSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    tokenizer = tokenizer,
    compute_metrics = compute_metrics,
    callbacks=[GsSyncCallback]
)

学習を開始します。

trainer.train()

リジュームの仕方

ちなみに GCS の保存結果からリジュームするときは以下のような感じです(GCS には全てのチェックポイントが保持されているのでその点は注意して下さい)。

!gsutil -m cp -r -P gs://somewhere/brio/t5-3line-summalization .

trainer.train(resume_from_checkpoint=True)

学習結果は以下のようになりました。

検証ロスの最良値

('2850/7685',
 {'eval_loss': 1.3976328372955322,
  'eval_rouge1': 0.5064882043998807,
  'eval_rouge2': 0.24530357459021362,
  'eval_rougeLsum': 0.473191052527153,
  'eval_rougeAve': 0.4083276105057491,
  'eval_runtime': 537.8097,
  'eval_samples_per_second': 1.199,
  'epoch': 1.76})

ROUGEスコア(ROUGE-1, 2, Lsum の平均)の最良値

('3600/7685',
 {'eval_loss': 1.4091293811798096,
  'eval_rouge1': 0.5177512866032165,
  'eval_rouge2': 0.25298135537197264,
  'eval_rougeLsum': 0.48208200105741933,
  'eval_rougeAve': 0.41760488101086946,
  'eval_runtime': 552.3013,
  'eval_samples_per_second': 1.168,
  'epoch': 2.24})

それでは最良の ROUGE スコアを記録したモデルを使って、要約候補を生成していきます。

8. 要約候補の生成

ROUGE スコアの平均が最良だったステップ 3600 を使います。

!gsutil cp -r gs://somewhere/brio/t5-3line-summalization/checkpoint-3600 .

改めて raw データも取得します。

!gsutil cp -r gs://somewhere/brio/raw .

必要なクラス等をインポートして、

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

トークナイザと前章で学習したモデルをロードします。

tokenizer = T5Tokenizer.from_pretrained('megagonlabs/t5-base-japanese-web')
model = T5ForConditionalGeneration.from_pretrained("./checkpoint-3600").to(device)
model.eval()

サマリを生成する関数です。途中からリジュームする機能と途中経過を GCS にアップロードする機能を付けました。

def generate_summaries(src_dir, tgt_dir, max_length = 80, min_length = 48, bsz = 2, resume_from = 0):
  count = 1
  with open(src_dir) as source, open(tgt_dir, 'a') as fout:

    def generate_candidates(slines):
      with torch.no_grad():
        dct = tokenizer.batch_encode_plus(slines, max_length=1024, return_tensors="pt", pad_to_max_length=True, truncation=True)
        summaries = model.generate(
                        input_ids=dct["input_ids"].to(device),
                        attention_mask=dct["attention_mask"].to(device),
                        num_return_sequences=16, num_beam_groups=16, diversity_penalty=1.0, num_beams=16,
                        max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
                        min_length=min_length + 1,  # +1 from original because we start at step=1
                        no_repeat_ngram_size=3,
                        length_penalty=2.0,
                        early_stopping=True,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        for hypothesis in dec:
          hypothesis = hypothesis.replace("\n", " ")
          fout.write(hypothesis + '\n')
          fout.flush()    

    sline = source.readline().strip().lower()
    slines = [sline]
    for sline in source:
      #print("length of slines = {} at count={}".format(len(slines), count), flush=True)
      if count % bsz == 0 and count >= resume_from:
        slines = slines[-bsz:]
        #print("Invoke generate_candidates at count={}, len of slines={}".format(count, len(slines)), flush=True)
        generate_candidates(slines)
        #print("Done generate_candidates at count={}".format(count), flush=True)
        slines = []
      if count % 100 == 0 and count >= resume_from:
        print("Uploading partial contents to GCS at {}...".format(count), flush=True)
        os.system("gsutil cp {} gs://somewhere/brio/raw/".format(tgt_dir))
      sline = sline.strip().lower()
      if len(sline) == 0:
        sline = " "
      slines.append(sline)
      count += 1
    if slines != []:
      generate_candidates(slines)

以下のようにして実行します。

generate_summaries(src_dir="./raw/train.source", tgt_dir="./raw/train.out")

リジュームの仕方

GCS の保存結果からリジュームするときは以下のような感じです。

!gsutil cp -r gs://somewhere/brio/raw .
resume_from = !wc -l ./raw/train.out
resume_from = int(resume_from[0].split(" ")[0]) // 16 + 1
print(resume_from)
generate_summaries(src_dir="./raw/train.source", tgt_dir="./raw/train.out", resume_from = resume_from)

generate_summaries() が終了したら、最後に追記された分も忘れずに GCS にアップロードしておいて下さい。

!gsutil cp ./raw/train.out gs://somewhere/brio/raw/

val, test も同様に処理します。

generate_summaries(src_dir="./raw/val.source", tgt_dir="./raw/val.out")
!gsutil cp ./raw/val.out gs://somewhere/brio/raw/

generate_summaries(src_dir="./raw/test.source", tgt_dir="./raw/test.out")
!gsutil cp ./raw/test.out gs://somewhere/brio/raw/

次はここまで準備したデータを前処理して BRIO が読み込める形にします。

9. 前処理

前処理の実行は GPU は必要ないので、アクセラレータなしのランタイムで良いと思います。

再び raw データを取得して、

!gsutil cp -r gs://somewhere/brio/raw .

RougeScorer のトークナイザを差し替えます。

import compare_mt.rouge.tokenize as rouge_tokenize
import os
os.environ["MECABRC"] ="/etc/mecabrc"
import MeCab
mecab = MeCab.Tagger ("-Ochasen")
mecab.parse("")

def parse_by_mecab(sentence, lemma=False):
  tokens = []
  node = mecab.parseToNode(sentence).next
  while node:
    feature = node.feature.split(',')
    token = feature[-3] # 標準形
    if token == '*' or not lemma:
      token = node.surface
    tokens.append(token)
    node = node.next
  return [token for token in tokens if len(token) > 0]

def tokenize(text, stemmer):
  return parse_by_mecab(text, lemma=False)

rouge_tokenize.tokenize = tokenize

必要なライブラリをインポートして、

import json
from compare_mt.rouge.rouge_scorer import RougeScorer
import os
import spacy
nlp = spacy.load('ja_ginza')
all_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)

*.source, *.target, *.out ファイルを読み込み 1 サンプル分のデータを yield する関数です。 今回はトークナイズ済みのデータは作るらないことにしたので、その辺りの修正を加えています。

from collections import namedtuple
Args = namedtuple('Args' , ['src_dir', 'tgt_dir', 'cand_num', 'split', 'dataset'])

def collect_diverse_beam_data(args):
    split = args.split
    src_dir = args.src_dir
    tgt_dir = os.path.join(args.tgt_dir, split)
    cands_untok = []
    cnt = 0
    with open(os.path.join(src_dir, f"{split}.source")) as src_untok, open(os.path.join(src_dir, f"{split}.target")) as tgt_untok:
        with open(os.path.join(src_dir, f"{split}.out")) as f_2:
             for y in f_2:
                y = y.strip()
                cands_untok.append(y)
                if len(cands_untok) == args.cand_num:
                    src_line_untok = src_untok.readline()
                    src_line_untok = src_line_untok.strip()
                    tgt_line_untok = tgt_untok.readline()
                    tgt_line_untok = tgt_line_untok.strip()
                    yield (src_line_untok, tgt_line_untok, cands_untok, os.path.join(tgt_dir, f"{cnt}.json"), args.dataset)
                    cands_untok = []
                    cnt += 1

前述の関数が生成したデータを受け取って、 ROUGE スコアを計算し、JSON で保存する関数です。

def build_diverse_beam(input, nlp):
  src_line_untok, tgt_line_untok, cands_untok, tgt_dir, dataset = input
  texts = [src_line_untok, tgt_line_untok] + cands_untok
  docs = nlp.pipe(texts, disable=['ner', 'bunsetu_recognizer'])

  sents_of_docs = []

  for doc in docs:
    sents_of_docs.append( [str(sent) for sent in doc.sents])

  article_untok = sents_of_docs[0]
  abstract_untok = sents_of_docs[1]
  cands_untok = sents_of_docs[2:]
  _abstract = "\n".join(abstract_untok)

  def compute_rouge(hyp):
    score = all_scorer.score(_abstract, "\n".join(hyp))
    return (score["rouge1"].fmeasure + score["rouge2"].fmeasure + score["rougeLsum"].fmeasure) / 3

  candidates_untok = [(x, compute_rouge(x)) for x in cands_untok]

  output = {
    "article_untok": article_untok, 
    "abstract_untok": abstract_untok,
    "candidates_untok": candidates_untok,
  }

  with open(tgt_dir, "w") as f:
        json.dump(output, f)

前述の関数を以下のように実行します。なんとなく nlp をリロードしてますが不要かもしれません。

!mkdir -p ./livedoor-3lines/diverse/train
!mkdir -p ./livedoor-3lines/diverse/val
!mkdir -p ./livedoor-3lines/diverse/test

for split in ["train", "val", "test"]:
  count = 0
  args = Args(src_dir="./raw", tgt_dir="./livedoor-3lines/diverse", split=split, dataset="livedoor-3lines", cand_num=16)
  nlp = spacy.load('ja_ginza')
  for input in collect_diverse_beam_data(args):
    if count % 500 == 0:
      print("count = {}".format(count))
    if count % 2000 == 0:
      print("Reload nlp...")
      nlp = spacy.load('ja_ginza')
    build_diverse_beam(input, nlp)
    count += 1

生成したファイルは GCS に退避しておきます。

!cd ./livedoor-3lines && tar zcf diverse.tar.gz ./diverse
!gsutil cp ./diverse.tar.gz gs://somewhere/brio/livedoor-3lines/

あれこれと準備してきましたが、まだ学習を始めることはできません。

これから BRIO に日本語や T5 を使うためのコード修正を行っていきます。

10. ソースコードの修正

ここからはソースコードを修正していきます。

カスタム T5 モデル

まず、今回は Transformer に T5 を使うので BRIO の学習に対応したカスタムクラスを用意しました。

%%writefile ./custom_t5.py
import torch
from transformers import T5ForConditionalGeneration, T5Config
from transformers.modeling_outputs import Seq2SeqLMOutput
import dataclasses
@dataclasses.dataclass
class DummyModel:
    def __init__(self):
        self.is_scoring_mode = True

    def scoring_mode(self):
        self.is_scoring_mode = True

    def generation_mode(self):
        self.is_scoring_mode = False

class T5Scorer(T5ForConditionalGeneration):

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model = DummyModel()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # invoke orginal forward() in generation mode.
        if not self.model.is_scoring_mode:
            return super().forward(
                input_ids,
                attention_mask,
                decoder_input_ids,
                decoder_attention_mask,
                head_mask,
                decoder_head_mask,
                cross_attn_head_mask,
                encoder_outputs,
                past_key_values,
                inputs_embeds,
                decoder_inputs_embeds,
                labels,
                use_cache,
                output_attentions,
                output_hidden_states,
                return_dict)

        # invoke customized forward() in scoreing mode.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Expand hidden states of encoder and flattten candidates. 
        cand_num = decoder_input_ids.size(1)
        hidden_states = encoder_outputs[0] # [batch, seq_len, hidden]
        hidden_states = torch.repeat_interleave(hidden_states, cand_num, dim=0) # [batch * cand_num, seq_len, hidden]
        attention_mask = torch.repeat_interleave(attention_mask, cand_num, dim=0)
        decoder_input_ids = decoder_input_ids.view(-1, decoder_input_ids.size(-1))
        decoder_attention_mask = decoder_attention_mask.view(-1, decoder_attention_mask.size(-1))

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

BRIO の実装では Transformer に scoring モードと generation モードがあり、それを切り替えながら学習します。

  • generation モードで期待されるのは素の T5 と同じ動きです。なので super の forward() を呼ぶだけです。
  • scoring モードでは、decoder_input_ids に複数の要約候補のトークンID が [batch, cand_num, seq_len, hidden] のシェイプで投入されるので、[batch * cand_num, seq_len, hidden] の形に変形し、input_ids をそれに合わせて torch.repeat_interleave() で水増しするだけですね。

BrioDataset

次にデータセットクラスの修正です。

トークナイザのインポートと、

!cat BRIO/data_utils.py | awk 'NR==5{print NR"|"$0}'
# 5|from transformers import BartTokenizer, PegasusTokenizer

!sed -i '5s/BartTokenizer, PegasusTokenizer/T5Tokenizer/' BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR==5{print NR"|"$0}'
# 5|from transformers import T5Tokenizer

ロードする箇所を T5 のものに差し替えます。

!cat BRIO/data_utils.py | awk 'NR>=29 && NR<=36{print NR"|"$0}'
#29|            else:
#30|                self.num = len(self.files)
#31|        if is_pegasus:
#32|            self.tok = PegasusTokenizer.from_pretrained(model_type, verbose=False)
#33|        else:
#34|            self.tok = BartTokenizer.from_pretrained(model_type, verbose=False)
#35|        self.maxlen = max_len
#36|        self.is_test = is_test

!sed -i -e '31,34d' -e '30a\ \ \ \ \ \ \ \ self.tok = T5Tokenizer.from_pretrained(model_type, verbose=False)' BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR>=29 && NR<=33{print NR"|"$0}'
#29|            else:
#30|                self.num = len(self.files)
#31|        self.tok = T5Tokenizer.from_pretrained(model_type, verbose=False)
#32|        self.maxlen = max_len
#33|        self.is_test = is_test

” “ 区切りのトークナイズされた形式のデータは使わないことにしたので削除します。

!cat BRIO/data_utils.py | awk 'NR>=62 && NR<=72{print NR"|"$0}'
#62|        if self.maxnum > 0:
#63|            candidates = data["candidates_untok"][:self.maxnum]
#64|            _candidates = data["candidates"][:self.maxnum]
#65|            data["candidates"] = _candidates
#66|        if self.sorted:
#67|            candidates = sorted(candidates, key=lambda x:x[1], reverse=True)
#68|            _candidates = sorted(_candidates, key=lambda x:x[1], reverse=True)
#69|            data["candidates"] = _candidates
#70|        if not self.is_untok:
#71|            candidates = _candidates
#72|        cand_txt = [" ".join(abstract)] + [" ".join(x[0]) for x in candidates]

!sed -i -e '64,65d' -e '68,71d' BRIO/data_utils.py 
!cat BRIO/data_utils.py | awk 'NR>=62 && NR<=66{print NR"|"$0}'
#62|        if self.maxnum > 0:
#63|            candidates = data["candidates_untok"][:self.maxnum]
#64|        if self.sorted:
#65|            candidates = sorted(candidates, key=lambda x:x[1], reverse=True)
#66|        cand_txt = [" ".join(abstract)] + [" ".join(x[0]) for x in candidates]

T5 は PEGASUS と同様に decoder_token_ids の先頭に 0 が必要なので修正します。

!cat BRIO/data_utils.py | awk 'NR>=68 && NR<=75{print NR"|"$0}'
#68|        candidate_ids = cand["input_ids"]
#69|        if self.is_pegasus:
#70|            # add start token
#71|            _candidate_ids = candidate_ids.new_zeros(candidate_ids.size(0), candidate_ids.size(1) + 1)
#72|            _candidate_ids[:, 1:] = candidate_ids.clone()
#73|            _candidate_ids[:, 0] = self.tok.pad_token_id
#74|            candidate_ids = _candidate_ids
#75|        result = {

!sed -i -e '69d' -e '70,74s/^    //' BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR>=68 && NR<=75{print NR"|"$0}'
#68|        candidate_ids = cand["input_ids"]
#69|        # add start token
#70|        _candidate_ids = candidate_ids.new_zeros(candidate_ids.size(0), candidate_ids.size(1) + 1)
#71|        _candidate_ids[:, 1:] = candidate_ids.clone()
#72|        _candidate_ids[:, 0] = self.tok.pad_token_id
#73|        candidate_ids = _candidate_ids
#74|        result = {
#75|            "src_input_ids": src_input_ids, 

今回は学習のリジューム処理として停止時点で処理していたエポックの先頭から再開し、停止時点まで読み飛ばす処理を追加することにしました。

この場合、DatarLoader で shuffle=True とすると停止時点のエポックで学習済みのサンプルを再開後に読んでしまう可能があると思ったので、 BrioDataset に shuffle 機能を足して、DataLoader を shuffle = False とすることにしました ( shuffle した読み出し順が全エポック共通になってしまいましたが、そこは妥協の範囲かなと。また複数プロセスで動かすときのことは考えてません。)。

random の import を追加して、

!cat BRIO/data_utils.py | awk 'NR>=1 && NR<=6{print NR"|"$0}'
#1|from torch.utils.data import Dataset
#2|import os
#3|import json
#4|import torch
#5|from transformers import T5Tokenizer
#6|

!sed -i '5aimport random' BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR>=1 && NR<=6{print NR"|"$0}'
#1|from torch.utils.data import Dataset
#2|import os
#3|import json
#4|import torch
#5|from transformers import T5Tokenizer
#6|import random

コンストラクタに shuffle を追加して、

!cat BRIO/data_utils.py | awk 'NR>=16 && NR<=16{print NR"|"$0}'
#16|    def __init__(self, fdir, model_type, max_len=-1, is_test=False, total_len=512, is_sorted=True, max_num=-1, is_untok=True, is_pegasus=False, num=-1):

!sed -i -e '16s/num=-1)/num=-1, shuffle=False)/' BRIO/data_utils.py  
!cat BRIO/data_utils.py | awk 'NR>=16 && NR<=16{print NR"|"$0}'
#16|    def __init__(self, fdir, model_type, max_len=-1, is_test=False, total_len=512, is_sorted=True, max_num=-1, is_untok=True, is_pegasus=False, num=-1, shuffle=False):

インスタンスの生成時に shuffle した読み出し順を生成して保持しておきます。

!cat BRIO/data_utils.py | awk 'NR>=38 && NR<=40{print NR"|"$0}'
#38|        self.is_untok = is_untok
#39|        self.is_pegasus = is_pegasus
#40|

!sed -i -e '39a\ \ \ \ \ \ \ \ self.shuffle = shuffle' \
     -e '39a\ \ \ \ \ \ \ \ if shuffle:' \
     -e '39a\ \ \ \ \ \ \ \ \ \ self.sequence = list(range(self.num))' \
     -e '39a\ \ \ \ \ \ \ \ \ \ random.shuffle(self.sequence)' \
     -e '39a\ \ \ \ \ \ \ \ \ \ #print("sequence: {}".format(self.sequence[:10]))' \
     BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR>=38 && NR<=45{print NR"|"$0}'
#38|        self.is_untok = is_untok
#39|        self.is_pegasus = is_pegasus
#40|        self.shuffle = shuffle
#41|        if shuffle:
#42|          self.sequence = list(range(self.num))
#43|          random.shuffle(self.sequence)
#44|          #print("sequence: {}".format(self.sequence[:10]))
#45|

__getitem__() するときに shuffle しておいた読み出し順を参照して対象データを特定します。

!cat BRIO/data_utils.py | awk 'NR>48 && NR<=52{print NR"|"$0}'
#49|    def __getitem__(self, idx):
#50|        if self.isdir:
#51|            with open(os.path.join(self.fdir, "%d.json"%idx), "r") as f:
#52|                data = json.load(f)

!sed -i -e '49a\ \ \ \ \ \ \ \ if self.shuffle:' \
     -e '49a\ \ \ \ \ \ \ \ \ \ idx = self.sequence[idx]' \
     BRIO/data_utils.py
!cat BRIO/data_utils.py | awk 'NR>=48 && NR<=54{print NR"|"$0}'
#48|
#49|    def __getitem__(self, idx):
#50|        if self.shuffle:
#51|          idx = self.sequence[idx]
#52|        if self.isdir:
#53|            with open(os.path.join(self.fdir, "%d.json"%idx), "r") as f:
#54|                data = json.load(f)

model.py

model.py は Bart や PEGASUS など Transformer の実体を隠蔽化する層ですね。ここも T5 に書き換えます。

インポートを書き換えて、

!cat BRIO/model.py | awk 'NR>=5 && NR<=6{print NR"|"$0}'
#5|from modeling_bart import BartScorer
#6|from modeling_pegasus import PegasusScorer

!sed -i -e '5,6d' -e '4afrom custom_t5 import T5Scorer' BRIO/model.py 
!cat BRIO/model.py | awk 'NR>=5 && NR<=6{print NR"|"$0}'
#5|from custom_t5 import T5Scorer
#6|

カスタマイズした T5 のインスタンスを生成するようにします。

!cat BRIO/model.py | awk 'NR>=41 && NR<=46{print NR"|"$0}'
#41|        super(BRIO, self).__init__()
#42|        if is_pegasus:
#43|            self.model = PegasusScorer.from_pretrained(mname, cache_dir="./local_cache")
#44|        else:
#45|            self.model = BartScorer.from_pretrained(mname, cache_dir="./local_cache")
#46|        self.pad_token_id = pad_token_id

!sed -i -e '42,45d' -e '41a\ \ \ \ \ \ \ \ self.model = T5Scorer.from_pretrained(mname, cache_dir="./local_cache")' BRIO/model.py
!cat BRIO/model.py | awk 'NR>=41 && NR<=43{print NR"|"$0}'
#41|        super(BRIO, self).__init__()
#42|        self.model = T5Scorer.from_pretrained(mname, cache_dir="./local_cache")
#43|        self.pad_token_id = pad_token_id   

次は学習ループのコードを修正していきます。

main.py

インポートするトークナイザを書き換えて、

!cat BRIO/main.py | awk 'NR==9{print NR"|"$0}'
#9|from transformers import BartTokenizer, PegasusTokenizer

!sed -i '9s/BartTokenizer, PegasusTokenizer/T5Tokenizer/' BRIO/main.py
!cat BRIO/main.py | awk 'NR==9{print NR"|"$0}'
#9|from transformers import T5Tokenizer

evaluation()

この関数は学習と別に検証を動かす為のもののようです。

デフォルト値の設定とトークナイザの変更です。

!cat BRIO/main.py | awk 'NR>=63 && NR<=75{print NR"|"$0}'
#63|def evaluation(args):
#64|    # load data
#65|    if args.config == "cnndm":
#66|        cnndm_setting(args)
#67|    elif args.config == "xsum":
#68|        xsum_setting(args)
#69|    else:
#70|        base_setting(args)
#71|    if args.is_pegasus:
#72|        tok = PegasusTokenizer.from_pretrained(args.model_type)
#73|    else:
#74|        tok = BartTokenizer.from_pretrained(args.model_type)
#75|    collate_fn = partial(collate_mp_brio, pad_token_id=tok.pad_token_id, is_test=True)

!sed -i -e '65,74d' -e '64a\ \ \ \ base_setting(args)\n\ \ \ \ tok = T5Tokenizer.from_pretrained(args.model_type)' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=63 && NR<=67{print NR"|"$0}'
#63|def evaluation(args):
#64|    # load data
#65|    base_setting(args)
#66|    tok = T5Tokenizer.from_pretrained(args.model_type)
#67|    collate_fn = partial(collate_mp_brio, pad_token_id=tok.pad_token_id, is_test=True)

evaluation() で使う DataLoader のバッチサイズとワーカー数を変更しました。

!cat BRIO/main.py | awk 'NR>=70 && NR<=71{print NR"|"$0}'
#70|    batch_size = 4
#71|    dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)

!sed -i -e '70s/4/2/' -e '71s/4/1/' BRIO/main.py  
!cat BRIO/main.py | awk 'NR>=70 && NR<=71{print NR"|"$0}'
#70|    batch_size = 2
#71|    dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)

” “ でトークナイズされた文章を、トークナイズされてない平文に置き換えます。 RougeScorer の内部で MeCab で分かち書きされるので、 これで大丈夫なはずです。

!cat BRIO/main.py | awk 'NR>=111 && NR<=113{print NR"|"$0}'
#111|                    sents = sample["candidates"][max_ids[j]][0]
#112|                    # print(" ".join(sents), file=f_out)
#113|                    score = rouge_scorer.score("\n".join(sample["abstract"]), "\n".join(sents))

!sed -i -e '111s/candidates/candidates_untok/' -e '113s/abstract/abstract_untok/' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=111 && NR<=113{print NR"|"$0}'
#111|                    sents = sample["candidates_untok"][max_ids[j]][0]
#112|                    # print(" ".join(sents), file=f_out)
#113|                    score = rouge_scorer.score("\n".join(sample["abstract_untok"]), "\n".join(sents))

!cat BRIO/main.py | awk 'NR>=121 && NR<=122{print NR"|"$0}'
#121|                        for s in sample["abstract"]:
#122|                            print(s, file=f)

!sed -i -e '121s/abstract/abstract_untok/' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=121 && NR<=122{print NR"|"$0}'
#121|                        for s in sample["abstract_untok"]:
#122|                            print(s, file=f)

生成モードで評価する際のバッチサイズを変更して、

!cat BRIO/main.py | awk 'NR>=134 && NR<=134{print NR"|"$0}'
#134|        bsz = 8

!sed -i -e '134s/8/2/' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=134 && NR<=134{print NR"|"$0}'
#134|        bsz = 2

文章を文単位に分割する処理を GiNZA に置き換えました。

!cat BRIO/main.py | awk 'NR>=184 && NR<=185{print NR"|"$0}'
#184|        def process(x):
#185|            return sent_tokenize(" ".join(word_tokenize(x.strip())))

!sed -i -e "183a\ \ \ \ \ \ \ \ import spacy\n\ \ \ \ \ \ \ \ nlp = spacy.load('ja_ginza')" \
     -e '185s/return.*$/return \[str\(sent\) for sent in nlp\(x\.strip\(\)\)\.sents\]/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=184 && NR<=187{print NR"|"$0}'
#184|        import spacy
#185|        nlp = spacy.load('ja_ginza')
#186|        def process(x):
#187|            return [str(sent) for sent in nlp(x.strip()).sents]

test()

この関数は学習中に検証を実施する際に使用されています。

こちらもトークナイズしていない文章で評価するように書き換えて、

!cat BRIO/main.py | awk 'NR>=245 && NR<=246{print NR"|"$0}'
#245|                sents = sample["candidates"][max_ids[j]][0]
#246|                score = rouge_scorer.score("\n".join(sample["abstract"]), "\n".join(sents))

!sed -i -e '245s/candidates/candidates_untok/' -e '246s/abstract/abstract_untok/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=245 && NR<=246{print NR"|"$0}' 
#245|                sents = sample["candidates_untok"][max_ids[j]][0]
#246|                score = rouge_scorer.score("\n".join(sample["abstract_untok"]), "\n".join(sents))

evaluate() と同様に文単位に分割する処理を差し替えます。

!cat BRIO/main.py | awk 'NR>=273 && NR<=274{print NR"|"$0}'
#273|        def process(x):
#274|            return sent_tokenize(" ".join(word_tokenize(x.strip())))

!sed -i -e "272a\ \ \ \ \ \ \ \ import spacy\n\ \ \ \ \ \ \ \ nlp = spacy.load('ja_ginza')" \
     -e '274s/return.*$/return \[str\(sent\) for sent in nlp\(x\.strip\(\)\)\.sents\]/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=273 && NR<=276{print NR"|"$0}'
#273|        import spacy
#274|        nlp = spacy.load('ja_ginza')
#275|        def process(x):
#276|            return [str(sent) for sent in nlp(x.strip()).sents]

run()

学習処理の本体です。

Colab の GPU ランタイムの寿命では学習が終わらないので、動作確認できた後は GCP に VM を立てて実行します。 ですが、そちらもプリエンプティブルにして料金を節約したいので、リジューム処理を追加しました。

まずはパラメータのデフォルト値を変更します。

!cat BRIO/main.py | awk 'NR>=331 && NR<=338{print NR"|"$0}'
#331|def run(rank, args):
#332|    if args.config == "cnndm":
#333|        cnndm_setting(args)
#334|    elif args.config == "xsum":
#335|        xsum_setting(args)
#336|    else:
#337|        base_setting(args)
#338|    # task initialization

!sed -i -e '332,337d' -e '331a\ \ \ \ base_setting(args)' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=331 && NR<=333{print NR"|"$0}'
#331|def run(rank, args):
#332|    base_setting(args)
#333|    # task initialization

トークナイザを差し替えて、

!cat BRIO/main.py | awk 'NR>=346 && NR<=349{print NR"|"$0}'
#346|    if args.is_pegasus:
#347|        tok = PegasusTokenizer.from_pretrained(args.model_type)
#348|    else:
#349|        tok = BartTokenizer.from_pretrained(args.model_type)

!sed -i -e '346,349d' -e '345a\ \ \ \ tok = T5Tokenizer.from_pretrained(args.model_type)' BRIO/main.py  
!cat BRIO/main.py | awk 'NR>=346 && NR<=346{print NR"|"$0}'
#346|    tok = T5Tokenizer.from_pretrained(args.model_type)

次に BrioDataset でシャッフルするように修正しました。 これはリジューム時に既視のデータを見ないようにする為ですね。

!sed -i -e '349s/is_pegasus)/is_pegasus, shuffle=True)/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=349 && NR<=349{print NR"|"$0}'
#349|    train_set = BrioDataset(f"./{args.dataset}/{args.datatype}/train", args.model_type, max_len=args.max_len, is_sorted=False, max_num=args.max_num, total_len=args.total_len, is_pegasus=args.is_pegasus, shuffle=True)

BrioDataset 側にシャッフルの責務を映したので DataLoader の方はシャッフルをしないようにします。

!sed -i -e '360s/shuffle=True/shuffle=False/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=360 && NR<=360{print NR"|"$0}'
#360|        dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)

バッチサイズとワーカ数の書き換えです。

!cat BRIO/main.py | awk 'NR>=354 && NR<=362{print NR"|"$0}'
#354|        dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn, sampler=train_sampler)
#355|        val_sampler = torch.utils.data.distributed.DistributedSampler(
#356|        val_set, num_replicas=world_size, rank=rank)
#357|        val_dataloader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn_val, sampler=val_sampler)
#358|        val_gen_dataloader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn_val, sampler=val_sampler)
#359|    else:
#360|        dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
#361|        val_dataloader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn_val)
#362|        val_gen_dataloader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn_val)

!sed -i -e '354s/4/1/' -e '357s/4/1/' -e '358s/8/2/' -e '358s/4/1/' -e '360s/4/1/' -e '361s/4/1/'  -e '362s/8/2/' -e '362s/4/1/' BRIO/main.py
!cat BRIO/main.py | awk 'NR>=354 && NR<=362{print NR"|"$0}'
#354|        dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn, sampler=train_sampler)
#355|        val_sampler = torch.utils.data.distributed.DistributedSampler(
#356|        val_set, num_replicas=world_size, rank=rank)
#357|        val_dataloader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn_val, sampler=val_sampler)
#358|        val_gen_dataloader = DataLoader(val_set, batch_size=2, shuffle=False, num_workers=1, collate_fn=collate_fn_val, sampler=val_sampler)
#359|    else:
#360|        dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)
#361|        val_dataloader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn_val)
#362|        val_gen_dataloader = DataLoader(val_set, batch_size=2, shuffle=False, num_workers=1, collate_fn=collate_fn_val)

ここからリジューム処理を追加していきます。

まずは args.model_pt が設定されていたら、その時点の各種情報を読み込む処理を追加します。

!cat BRIO/main.py | awk 'NR>=405 && NR<=406{print NR"|"$0}'
#405|    # start training
#406|    for epoch in range(args.epoch):

!sed -i -e '405a\ \ \ \ step_state = None' \
     -e '405a\ \ \ \ resumed_epoch = 0' \
     -e '405a\ \ \ \ if len(args.model_pt) > 0:' \
     -e '405a\ \ \ \ \ \ \ \ base_dir = args.model_pt.split("/")[0]' \
     -e '405a\ \ \ \ \ \ \ \ s_optimizer.load_state_dict(torch.load(os.path.join("./cache", base_dir, "optimizer.bin"), map_location=f"cuda:{gpuid}"))' \
     -e '405a\ \ \ \ \ \ \ \ step_state = torch.load(os.path.join("./cache", base_dir, "step.bin"))' \
     -e '405a\ \ \ \ \ \ \ \ resumed_epoch = step_state["epoch"]' \
     -e '405a\ \ \ \ \ \ \ \ all_step_cnt = step_state["all_step_cnt"]' \
     -e '405a\ \ \ \ resume_completed = True' \
     -e '405a\ \ \ \ if step_state and step_state["all_step_cnt"] > 0:' \
     -e '405a\ \ \ \ \ \ \ \ print("Start resuming to {}".format(step_state))' \
     -e '405a\ \ \ \ \ \ \ \ resume_completed = False' \
     -e '405a\ \ \ \ \ \ \ \ ' \
     -e '406s/args.epoch/resumed_epoch, args.epoch/' \
     BRIO/main.py
!cat BRIO/main.py | awk 'NR>=405 && NR<=419{print NR"|"$0}'
#405|    # start training
#406|    step_state = None
#407|    resumed_epoch = 0
#408|    if len(args.model_pt) > 0:
#409|        base_dir = args.model_pt.split("/")[0]
#410|        s_optimizer.load_state_dict(torch.load(os.path.join("./cache", base_dir, "optimizer.bin"), map_location=f"cuda:{gpuid}"))
#411|        step_state = torch.load(os.path.join("./cache", base_dir, "step.bin"))
#412|        resumed_epoch = step_state["epoch"]
#413|        all_step_cnt = step_state["all_step_cnt"]
#414|    resume_completed = True
#415|    if step_state and step_state["all_step_cnt"] > 0:
#416|        print("Start resuming to {}".format(step_state))
#417|        resume_completed = False
#418|        
#419|    for epoch in range(resumed_epoch, args.epoch):

ロードしたステップ数まで読み飛ばす処理の追加です。

!cat BRIO/main.py | awk 'NR>=426 && NR<=429{print NR"|"$0}'
#426|        for (i, batch) in enumerate(dataloader):
#427|            if args.cuda:
#428|                to_cuda(batch, gpuid)
#429|            step_cnt += 1

!sed -i -e '426a\ \ \ \ \ \ \ \ \ \ \ \ if step_state:' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if epoch <= resumed_epoch and epoch_step < step_state["epoch_step"] :' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ step_cnt += 1' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if step_cnt == args.accumulate_step:' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ step_cnt = 0' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ epoch_step += 1' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ continue' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else:' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if not resume_completed :' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ print("Resume Complated to epoch: {}, epoch_step={}, all_step_count={}.".format(epoch, epoch_step, all_step_cnt))' \
     -e '426a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ resume_completed = True' \
     BRIO/main.py
!cat BRIO/main.py | awk 'NR>=426 && NR<=440{print NR"|"$0}'
#426|        for (i, batch) in enumerate(dataloader):
#427|            if step_state:
#428|                if epoch <= resumed_epoch and epoch_step < step_state["epoch_step"] :
#429|                    step_cnt += 1
#430|                    if step_cnt == args.accumulate_step:
#431|                        step_cnt = 0
#432|                        epoch_step += 1
#433|                    continue
#434|                else:
#435|                  if not resume_completed :
#436|                      print("Resume Complated to epoch: {}, epoch_step={}, all_step_count={}.".format(epoch, epoch_step, all_step_cnt))
#437|                      resume_completed = True
#438|            if args.cuda:
#439|                to_cuda(batch, gpuid)
#440|            step_cnt += 1

ステップ情報の保存処理を追加しておきます。

!cat BRIO/main.py | awk 'NR>=520 && NR<=525{print NR"|"$0}'
#520|                    if is_mp:
#521|                        recorder.save(model.module, "model_cur.bin")
#522|                    else:
#523|                        recorder.save(model, "model_cur.bin")
#524|                    recorder.save(s_optimizer, "optimizer.bin")
#525|

!sed -i -e '524a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ torch.save({"epoch": epoch, "epoch_step": epoch_step, "step_cnt": step_cnt, "all_step_cnt": all_step_cnt}, os.path.join(recorder.dir, "step.bin"))' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=520 && NR<=526{print NR"|"$0}'
#520|                    if is_mp:
#521|                        recorder.save(model.module, "model_cur.bin")
#522|                    else:
#523|                        recorder.save(model, "model_cur.bin")
#524|                    recorder.save(s_optimizer, "optimizer.bin")
#525|                    torch.save({"epoch": epoch, "epoch_step": epoch_step, "step_cnt": step_cnt, "all_step_cnt": all_step_cnt}, os.path.join(recorder.dir, "step.bin"))
#526|

ついでに GCS にコピーするようにして、

!sed -i -e '525a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ os.system("gsutil cp -r ./cache/* gs://somewhere/brio/cache/")' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=524 && NR<=527{print NR"|"$0}'
#524|                    recorder.save(s_optimizer, "optimizer.bin")
#525|                    torch.save({"epoch": epoch, "epoch_step": epoch_step, "step_cnt": step_cnt, "all_step_cnt": all_step_cnt}, os.path.join(recorder.dir, "step.bin"))
#526|                    os.system("gsutil cp -r ./cache/* gs://somewhere/brio/cache/")
#527|

チェックポイントの保存間隔を検証と独立して設定できるようにしました。

!sed -i -e '517a\ \ \ \ \ \ \ \ \ \ \ \ if all_step_cnt % args.save_interval == 0 and all_step_cnt != 0 and step_cnt == 0:' BRIO/main.py  
!cat BRIO/main.py | awk 'NR>=517 && NR<=520{print NR"|"$0}'
#517|                        %(result["sample_rouge1"], result["sample_rouge2"], result["sample_rougeLsum"]))
#518|            if all_step_cnt % args.save_interval == 0 and all_step_cnt != 0 and step_cnt == 0:
#519|                # save current model
#520|                if is_master:

!sed -i -e '527a\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ print("Current model(all_step_cnt ={}) saved.".format(all_step_cnt), flush=True)' BRIO/main.py 
!cat BRIO/main.py | awk 'NR>=526 && NR<=528{print NR"|"$0}'
#526|                    torch.save({"epoch": epoch, "epoch_step": epoch_step, "step_cnt": step_cnt, "all_step_cnt": all_step_cnt}, os.path.join(recorder.dir, "step.bin"))
#527|                    os.system("gsutil cp -r ./cache/* gs://somewhere/brio/cache/")
#528|                    print("Current model(all_step_cnt ={}) saved.".format(all_step_cnt), flush=True)

ようやくですが、学習を開始する準備が整いました。

11. BRIO の学習

それでは学習を開始していきましょう。

学習の起点としては要約候補の生成に使用した T5 を用います。

ROUGE スコアとしてはステップ 3600 が最良だったのですが、BRIO も元ネタとしては同じデータで学習するので過学習が心配です。 それを考えると、BRIO の起点とするチェックポイントは若いものが良いと思い、検証ロスが最良であったステップ 2850 を使いました。

!gsutil cp -r gs://somewhere/brio/t5-3line-summalization/checkpoint-2850 .

8 章、9 章で生成したデータを取得して所定の位置に配置します。 9 章で作ったデータ (diverse.tar.gz) と 8 章で作ったデータ (raw/*) で行数カウントが 1 行分ズレていますが最終行に改行が無いためなので問題ありません。

!gsutil cp gs://somewhere/brio/livedoor-3lines/diverse.tar.gz .
!mkdir -p ./livedoor-3lines
!cd ./livedoor-3lines && tar zxf ../diverse.tar.gz

!ls ./livedoor-3lines/diverse/train/ | wc -l
# 98399
!ls ./livedoor-3lines/diverse/val/ | wc -l
# 645
!ls ./livedoor-3lines/diverse/test/ | wc -l
# 655

!gsutil cp -r gs://somewhere/brio/raw .
!mv ./raw/* ./livedoor-3lines/diverse
!wc -l ./livedoor-3lines/diverse/*\.*
#   10480 ./livedoor-3lines/diverse/test.out
#      654 ./livedoor-3lines/diverse/test.source
#      654 ./livedoor-3lines/diverse/test.target
#  1574384 ./livedoor-3lines/diverse/train.out
#    98398 ./livedoor-3lines/diverse/train.source
#    98398 ./livedoor-3lines/diverse/train.target
#    10320 ./livedoor-3lines/diverse/val.out
#      644 ./livedoor-3lines/diverse/val.source
#      644 ./livedoor-3lines/diverse/val.target
!mv ./livedoor-3lines/diverse/test.target ./livedoor-3lines

学習のチェックポイントは cache に出力されるので、ディレクトリを作っておきます。

!mkdir -p ./cache

コードの検索パスを追加して、

import sys
sys.path.append(".")
sys.path.append("./BRIO")

RougeScorer のトークナイザを差し替えておきます。

import compare_mt.rouge.tokenize as rouge_tokenize
import os
os.environ["MECABRC"] ="/etc/mecabrc"
import MeCab
mecab = MeCab.Tagger ("-Ochasen")
mecab.parse("")

def parse_by_mecab(sentence, lemma=False):
  tokens = []
  node = mecab.parseToNode(sentence).next
  while node:
    feature = node.feature.split(',')
    token = feature[-3] # 標準形
    if token == '*' or not lemma:
      token = node.surface
    tokens.append(token)
    node = node.next
  return [token for token in tokens if len(token) > 0]

def tokenize(text, stemmer):
  return parse_by_mecab(text, lemma=False)

rouge_tokenize.tokenize = tokenize

必要なライブラリをインポートして、

import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import numpy as np
import os
import random
from compare_mt.rouge.rouge_scorer import RougeScorer
import transformers
from transformers import T5Tokenizer 
from utils import Recorder
from data_utils import to_cuda, collate_mp_brio, BrioDataset
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from functools import partial
from model import RankingLoss, BRIO
import logging
from label_smoothing_loss import label_smoothing_loss
from nltk import sent_tokenize, word_tokenize
from config import cnndm_setting, xsum_setting
from tqdm import tqdm

logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
logging.getLogger("transformers.tokenization_utils_fast").setLevel(logging.ERROR)

from main import run

ハイパーパラメータの設定です。以下に記述のないパラメータは main.py の base_setting() の値が適用されます。

import dataclasses
@dataclasses.dataclass
class Args:
    pass

args = Args()
args.cuda=True
args.gpuid=[0] 
args.evaluate=True 
args.do_reranking=True 
args.do_generation=True 
args.log=True
args.model_pt="" 
args.config=""

args.batch_size = 1
args.epoch = 4
args.accumulate_step = 8
args.max_lr = 2e-3
args.warmup_steps = 10000 #  lr = args.max_lr * min(all_step_cnt ** (-0.5), all_step_cnt * (args.warmup_steps ** (-1.5)))
args.margin = 0.001 # λ
args.gold_margin = 0
args.gold_weight = 0
args.mle_weight = 0.1 # γ
args.rank_weight = 10 # γ
args.eval_interval = 2500 # 約 5% 刻みで検証を実施
args.save_interval = 200  # 独自に追加したパラメータ
args.report_freq = 100    # この値はデフォルト値
args.model_type = "./checkpoint-2850" # Appendix B. の記載を参考に3行要約データセットで学習済みの T5 を起点とする
args.pretrained = None
args.dataset = "livedoor-3lines"
args.total_len = 1024
args.max_len = 72
args.max_num = 14 # 16 にしたいが OOM 回避の為、妥協したセッティング。
args.length_penalty = 2.0 # α
args.gen_max_len = 80
args.gen_min_len = 48

上記の設定だと学習レートは以下のようになります。 T5 の学習ではいつも固定レートの 0.001 を使うので、 かなり低く感じましたが、事前に T5 としてファインチューニング済みのモデルをさらに BRIO で追い込む形になるので、 これくらいがちょうどなのかも知れませんね。

learning_rate

以下のようにして、学習を開始します。

run(0, args)

学習中、args.report_freq 毎に以下のようなログが出力されます。

id: 0
similarity: tensor([[-0.0223, -0.0244, -0.0232, -0.0203, -0.0239, -0.0181, -0.0194, -0.0246,
         -0.0249, -0.0196]], device='cuda:0', grad_fn=<SliceBackward0>)
gold similarity: tensor([-0.0355], device='cuda:0', grad_fn=<MulBackward0>)
epoch: 1, batch: 700, avg loss: 1.204277, avg ranking loss: 0.084403, avg mle loss: 3.602421
learning rate: 0.000001

出力項目について補足すると、以下のとおりです。

  • id :
    学習を開始すると cache フォルダに {yy}-{mm}-{dd}-{id} の形式でサブディレクトリが作成されます。その id ですね。
  • similarity :
    要約候補に対する f(S) の先頭10件分です。これが降順になるようにしたい訳です。
  • gold similarity :
    参照要約に対する f(S) です。
  • mle loss :
    Lxent に相当します。
  • ranking loss :
    Lctr に相当します。
  • loss :
    mle_loss と ranking loss の加重和です。Lmul に相当します。

また、 args.eval_interval 毎に以下の検証ログが出力されます。

test similarity: [-0.01016753 -0.01040084 -0.0118106  -0.01449606 -0.01021425 -0.01469764
 -0.01524294 -0.01420222 -0.01971403 -0.01808763 -0.01124788 -0.01862179
 -0.0173974  -0.02327717]
best ranking loss - epoch: 0, batch: 2499
val ranking loss: 0.797767
val ranking rouge1: 0.515866, rouge2: 0.246008, rougeLsum: 0.479792
best generation loss - epoch: 0, batch: 2499
val generation loss: 0.795983
val generation rouge1: 0.517141, rouge2: 0.252297, rougeLsum: 0.481579

こちらの出力項目は、以下のようになります。

  • test similarity :
    要約候補に対する f(S) です。バッチ 1000 回毎に出力されるので、今回の構成だと 1 件だけ出力されます。
  • best ranking loss :
    検証データにおける Lctr の最良値を記録したことを示します。
  • val ranking rouge1/2/Lsum :
    検証データの各サンプルにおいて f(S) が最大になった要約候補で計算した rouge ½/Lsum です。
  • val ranking loss :
    上記の ROUGE スコアを使って 1 - (rouge1 * route2 + rougeLsum) / 3 で計算されます。
  • best generation loss :
    検証データの要約対象から生成した要約候補で Lxent の最良値を記録したことを示します。
  • val generation loss :
    生成した要約候補で計算した Lxent です。
  • val generation rouge1/2/Lsum :
    生成した要約候補で計算した rouge ½/Lsum です。

リジュームの仕方

ちなみに GCS の保存結果からリジュームするときは以下のような感じです。

  • BRIO のコード修正、Python のコード検索パスの追加とRougeScorer のトークナイザの差し替えを再実行して下さい。
  • best model の出力を制御するスコアは記憶していないので、リジュームするたびに best モデルが出力されてしまいます。
  • cache 直下の出力先フォルダはシステム時刻と cache 直下のファイル/ディレクトリ数から命名されます。
!mkdir -p ./cache
!gsutil -m cp -r gs://somewhere/brio/cache/22-09-05-0 ./cache
args.model_pt="22-09-05-0/model_cur.bin" 
run(0, args) # 学習経過は ./cache/{yy}-{mm}-{dd}-1 に出力されます。
# ...
# Start resuming to {'epoch': 0, 'epoch_step': 2400, 'step_cnt': 0, 'all_step_cnt': 2400}
# Resume Complated to epoch: 0, epoch_step=2400, all_step_count=2400.
# id: 1
# similarity: tensor([[-0.0085, -0.0174, -0.0110, -0.0118, -0.0192, -0.0190, -0.0237, -0.0261,
#          -0.0306, -0.0322]], device='cuda:0', grad_fn=<SliceBackward0>)
# ...

学習曲線

そんなこんなで学習曲線は以下のようになりました。図中の ROUGE スコアは "val generation rouge1/2/Lsum” の値です。

learning_curve

ちょうど 30000 ステップが generation loss の最高値で以下のようになってました。

test similarity: [-0.01142967 -0.01344618 -0.02023575 -0.01525153 -0.01401975 -0.01695879
 -0.0154583  -0.01747351 -0.01811885 -0.01853216 -0.01582579 -0.02105819
 -0.02090364 -0.02213614]
best ranking loss - epoch: 2, batch: 5401
val ranking loss: 0.792863
val ranking rouge1: 0.523876, rouge2: 0.254159, rougeLsum: 0.488263
best generation loss - epoch: 2, batch: 5401
val generation loss: 0.783452
val generation rouge1: 0.538140, rouge2: 0.273463, rougeLsum: 0.502484

それではテストデータを使って精度を確認してみましょう。

12. テストデータでの精度

当該のモデルが含まれているディレクトリを GCS から吸い上げて、

!gsutil -m cp -r gs://somewhere/brio/cache/22-09-10-1 ./cache

以下のようにして実行します。

!mkdir -p ./result
from main import evaluation
args.model_pt="22-09-10-1/model_generation.bin" 
evaluation(args)
# 22-09-10-1
# 328it [02:46,  1.97it/s]
# ranking rouge1: 0.511026, rouge2: 0.239297, rougeL: 0.473959
#   0%|          | 0/655 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:2110: FutureWarning: The # `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
#   FutureWarning,
# /usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py:1818: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
#  next_indices = next_tokens // vocab_size
# 100%|█████████▉| 654/655 [10:44<00:00,  1.02it/s]
# evaluation rouge1: 0.537812, rouge2: 0.273945, rougeL: 0.501094

入力文から生成した要約候補を参照要約と突き合わせた evaluation rouge1/2/L の値を、要約候補の生成に使った T5 の要約モデル(ステップ 3600) と精度を比較すると以下のようになりました。

test_score

Rouge1/2/Lsum それぞれで 3 pt 強の改善を確認できました。論文1では BART に比較して 5 pt 弱の改善が報告されているので、 改善幅としては小さなものになりましたが、今回の学習データ量が小さいことで、早い段階で過学習に陥ってしまったような気がします。

せっかくモデルを作ったので最後に要約を生成してみましょう。

13. 要約の生成

パラメータの設定状態はこんな感じです。

print(args.model_type)
# ./checkpoint-2850
print(args.model_pt)
# 22-09-10-1/model_generation.bin

学習済みのモデルを一旦 BRIO としてロードし、T5 の部分だけ取り出して保存しなおします。

tokenizer = T5Tokenizer.from_pretrained(args.model_type)
model = BRIO(args.model_type, tokenizer.pad_token_id, is_pegasus=args.is_pegasus)
model.load_state_dict(torch.load(os.path.join("./cache", args.model_pt), map_location='cuda:0'))
t5_scorer = model.model
t5_scorer.save_pretrained("./brio_t5-3line-summalization")

保存したものは普通の T5 のモデルとしてロードできます。

model = transformers.T5ForConditionalGeneration.from_pretrained('./brio_t5-3line-summalization')

テストデータから 1 件読み込んで、

  • 追記 :
    以下のテキストと引用元の URL の内容を見比べるとわかるのですが、どうも本文中のハイパーリンクをうまく扱えてないようで、冒頭の段落の後部が欠けてしまっています。実際に試される方はこの辺りも修正して頂くと良いかもしれません。
with open("livedoor-3lines/diverse/test.source", "r") as f:
  lines = [line.strip() for line in f.readlines()]
lines[0] # 引用元 : https://news.livedoor.com/article/detail/12303877/
# 公開中の映画『ミュージアム』より、新ビジュアルが公開され、妻夫木が特殊メイクを施し、
# スキンヘッドで霧島を演じたことが大きな話題を呼んでいる本作。新ビジュアルは、
# 霧島を追う刑事・沢村(同役について妻夫木は、「自分なりに考えたり、過去のサイコパス
# 映画を観たりもしましたが、この手の役柄は頭で考えても難しいんじゃないかと思ったんです」
# と自分なりの考えを語り、「悩んだ挙句、すべてゼロに戻して、まずマスクを被ってみること
# にしました。自我を捨てて役を楽しんで演じれば、そこから出た芝居は嘘じゃなくなる。
# 真実に変わるんじゃないか」と挑んだ心境を吐露した。一方、霧島と対峙する沢村刑事を
# 演じた小栗も「沢村という男の中には反省や悔しさ、怒りなど様々な感情が入り組んでいた
# ので、この役を生きるのはなかなかしんどい作業ではありました」と苦労を明かす。
# 霧島の罠にはまり地下室に監禁されるシーンもあるが、「自分自身も監禁状態に追い込もう
# と思っていたので、ホテル住まいをしたんです。あの一週間は、食事も劇中で食べている
# ハンバーガーだけ。むしゃぶりつくように食べていますが、本当にお腹が空いている状態
# だったので、素直にあの演技になりました」と過酷な撮影を振り返った。(編集部・小山美咲)

以下のようにして要約を生成します。

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dct = tokenizer.batch_encode_plus(lines[:1], max_length=args.total_len, return_tensors="pt", pad_to_max_length=True, truncation=True)
model.to(device)
summaries = model.generate(
                            input_ids=dct["input_ids"].to(device),
                            attention_mask=dct["attention_mask"].to(device),
                            max_length=args.gen_max_len + 2, 
                            min_length=args.gen_min_len + 1,
                            no_repeat_ngram_size=3,
                            num_beams=args.num_beams,
                            length_penalty=args.length_penalty,
                            early_stopping=True,
                        )
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
dec

# ['映画「ミュージアム」の新ビジュアルが公開された。
#  妻夫木聡が特殊メイクを施し、スキンヘッドで霧島を演じた。
# 「自我を捨てて役を楽しんで演じれば、真実に変わるんじゃないか」と挑んだ心境を吐露した。']

人の発言内容を要約するのはマズイ気がしなくもないですが、それっぽい要約文を生成できました!

ここまでの手順で作ったモデルは論文1で言うところの BRIO-Mul に相当するモデルになります。この BRIO-Mul で要約候補の生成を行い、 そのデータでもう一度 BRIO の学習を行うのが BRIO-Loop で、もう少しスコアが向上するようです。興味のある方は試してみて下さい!

今回は抽象要約タスクでしたが、生成したテキストのメトリクスを用意できれば要約以外の自己回帰生成タスクにも応用できるような気がします。

14. おわりに

今回は抽象型要約の手法である BRIO の検証を行いました。モデルには非依存の手法なので他のモデルで試したり、メトリクス関数を用意して他の生成タスクで試したりしても面白いかもしれませんね。次回は不確実性を考慮した分類ができる SNGP を紹介したいと思います。分類モデルなので既知の選択肢の中から選ばざるを得ないのですが、分からない時は「わからない」と言って欲しいですよね。そんなところに手が届く手法のようです。


  1. https://arxiv.org/abs/2203.16804 

  2. https://arxiv.org/abs/1109.2128 

  3. https://aclanthology.org/D15-1232/ 

  4. https://arxiv.org/abs/1910.10683 

  5. https://aclanthology.org/W04-1013/ 

  6. BLEU に関しては第7回で説明しているので、興味のある方は見てみてください。 

  7. 「僕は寿司が好き」という参照要約に対し「僕は寿司寿司寿司寿司寿司寿司寿司が好きィー」という要約候補でも「寿司」のカウントは 1 です。 

  8. 私は結構長いこと recall だと勘違いしていました。。。 compare-mt : https://github.com/neulab/compare-mt/blob/b6d8f79d02043243c3d8aa58373a0f4c55e17a69/compare_mt/scorers.py#L544 , sumeval : https://github.com/chakki-works/sumeval/blob/9a6eedc9634a8bcf145c45ad4516809ee1f28c7c/sumeval/metrics/rouge.py#L146 

  9. https://github.com/neulab/compare-mt/blob/b6d8f79d02043243c3d8aa58373a0f4c55e17a69/compare_mt/rouge/rouge_scorer.py#L185-L226 

  10. 学習時に経験したのは「学習データの分布」であるのに、推論時に突然「モデルが生成するトークンの分布」にさらされるということですね。https://arxiv.org/abs/1511.06732 

  11. https://github.com/yixinL7/BRIO 

  12. 論文1 では CNNDM データセットには ROUGE-½/L の平均、Xsum には ROUGE-½ の調和平均を用いています。 

  13. https://github.com/yixinL7/BRIO 

  14. https://github.com/KodairaTomonori/ThreeLineSummaryDataset 

  15. https://ailog.site/2021/10/07/2021/1007/