オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第22回 MixCSE による教師なし文章ベクトル生成の検証
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2022年8月23日

今回は教師なしの文章ベクトル化手法である MixCSE の検証です。教師なし学習ですから教師ありの手法よりは精度的に不利でしょうが、局面によっては役に立つケースもあるのでは?と試してみることに。公開されているコードは transformers ベースなのですが、今回は Colab の TPU で動かしてみたので、その方法も紹介しますね。

1. はじめに

今回は教師なしの文章ベクトル化手法である MixCSE1 の検証をしてみました。

本連載では文章ベクトル化のモデルとして、 Sentence BERT を取り上げたこと(第9回, 第18回)がありますが、品質の良いベクトルを生成する為には大量かつ良質の教師データが必要でした。 法律や特許のような特定領域に特化した文章を扱う局面では、対象領域の文書で学習したモデルを使いたいところですが、特定領域限定の都合良いデータはなかなか手に入りません。それならば教師なしでどうにかならないか?と思った次第です。

2. MixCSE

今回ご紹介する MixCSE は教師なしの文章ベクトル化モデルです。

と言っても基本的には Contrastive Learning です。「似てるものを近づけ、似てないものを遠ざける」というやつですね。この連載だと第18回の “3.Multiple Negatives Ranking Loss” のところや、第19回で紹介したあたりでしょうか。「自己教師あり学習」といった方が良いかもしれません。

似てる似てないの基準になる anchor を hi、anchor に類似したサンプル(positive)を h’i、anchor に類似していないサンプル(negative)を h’j とすると、Contrastive Learning の損失関数は以下のようになります。

cs_loss

もう少し補足すると、

  • hi は コーパス D に含まれる文章 xi を BERT に通し、 “[CLS]” トークンに対応する出力を MLP を適用して得た埋め込み表現です。
  • positive サンプルである h’i も xi から hi と同様に生成します。ただし Dropout が効いているので値は異なります。
  • negative サンプルである h’j は コーパス D からxi 以外の文章をランダム抽出して同じように生成します。
  • hi、h’i、 h’j はそれぞれ l2 normalization して長さ 1 に揃えたものとします。
  • N はサンプリングする negative の数です。
  • τは temperature と呼ばれるハイパーパラメータです。

ところで上記の数式で hi に関する偏微分は以下のようになります。

derivative_hi

上式は (h’i - h’j) に比例、つまり hi を (h’i - h’j) の方向に動かそうとする訳ですが、これは書き方を変えると (h’i - hi) - (h’j - hi) ですので、

update_dicrection

学習は hi を h’i の方向に近づけつつ、h’j から遠ざけようとしていると見ることができます。 ここで hi に関する偏微分は exp(hiTh’j/τ)に比例なので hi と h’j の内積が大きくなると指数関数的に増大します。

hi が anchor で h’j が negative サンプルですから、anchor と識別しにくい negative サンプルからはより強い教師信号が得られるということになります。

次に C を見てみると、そもそも exp(hiTh’i/τ) ≫ exp(hiTh’j/τ) な上に、 学習が進むにつれて、前者はどんどん大きく( hi は h’i に近づく)、後者はどんどん小さく(hi は h’j から遠ざかる)ので、結局 C はどんどん大きくなり、学習はそのうち止まってしまうのが分かります。

この現象を回避する為、学習過程で人工的に作った識別の難しい(hard) negative を継続的に混入させて強い教師信号を維持するというのが MixCSE の中核のアイデアになります。作った hard negative を混ぜ込むから MixCSE なんでしょうね。

さて、肝心の Mix する hard negative (以後、mix negative と記述)は以下のように作ります。

midex_hard_negative

何も難しいことはなくて重み係数 λ で positive sample (h’i)と negative sample (h’j) を足し合わせて l2 normalization するだけですね。 λ はハイパーパラメータで、論文1では λ = 0.2 を使っています。あまり強気の設定にして mix negative が positive sample よりも anchor に近づいてしまうとマズイので抑えめの値にしたそうです。

anchor, positive, negative, mix negative を3次元のイメージで示すとこんな感じになります。

emb_distribution

図中の紫の網掛けはBERT による埋め込み表現が分布する範囲を示しています。 BERT による埋め込み表現には異方性があって埋め込む空間の一部に寄るんだそうです。それを l2 normalize で長さ 1 揃えて 3 次元で表現したので、球表面のお椀状になるわけです。

mix negative が anchor から見て positive より遠く、(random) negative より近くなっているのが見て取れます。

少し脱線

ここで図の北極にあたる分布の中心 z1 と分布範囲の端の角度を ω とし、分布上のある二点2の角度をθとします。 説明は省略しますが論文ではここから数式やら証明やら飛び交って以下の図が出てきます。d は埋め込み表現の次元数ですね。

mean_and_variance_of_theta

学習を進めると類似しない二点は互いを遠ざけようとするので、分布の範囲を示す ω は広がっていきます。 上図からは ω が π/2 を超えてくると、cosθの平均も分散も 0 に近づいていくのが分かります。cosθ ≒ 0 ということは、ほとんど直交ですね。 この二点を anchor と random negative だと捉えると、判別の難しいサンプルがほとんど皆無という状態です。

前述のとおり、強い教師信号を得るには判別の難しいサンプルが必要でしたから、やはり mix negative なしでは学習が進むに従い弱い教師信号しか得られなくなることがわかります。

話をもどしましょう。 mixed negative が増えたので損失関数も変わります。

loss_with_midex_hard_negative

とはいっても前述の損失関数の分母に mixed negative が追加されただけですね。SG(・) は “stop gradient” で誤差逆伝播が mixed negative を通らないことを意味します。SG(・) の必要性は数式が複雑なので省略しますが3、実際に SG(・) なしで実験してみると性能が落ちたそうです。

それでは実際に MixCSE を動かしてみましょう。

3.学習データの準備

MixCSE を動かす前に学習データを用意しないといけません。今回も Wiki40b を使いました。

それでは Colab のノートブックを開いて下さい。テキストデータを加工するだけなので、GPU 等のアクセラレータは不要です。

処理結果を GCS に保存するので認証を通しておきます。

from google.colab import auth
auth.authenticate_user()

今回も文章を文に分割するのに GiNZA を使いました。処理は重めなので、お好みで別なライブラリを使ってもよいでしょう。

!pip install ginza ja_ginza==5.1.0

Wiki40b は tensorflow-datasets のものを使いました。

import tensorflow_datasets as tfds
import pickle
import spacy
ds = tfds.load('wiki40b/ja', try_gcs=True)

Wiki40b には “_START_ARTICLE_”, “_NEWLINE_” 等のシンボルが含まれるので、それをパースしてタイトルと本文の行を取り出す関数です。

def parse_example(example):
  article = example["text"].numpy().decode("utf-8")
  previous = ""
  title = ""
  lines = []
  for element in article.split("\n"):
    if previous == "_START_ARTICLE_":
      title = element
    if previous == "_START_PARAGRAPH_":  
      lines.extend(element.split("_NEWLINE_"))
    previous=element
  return title, lines

先程パースした各行を文に分割する関数です。

def lines2sents(lines, nlp):
  docs = nlp.pipe(lines, disable=['ner', 'bunsetu_recognizer'])
  sents = []
  for doc in docs:
    for sent in doc.sents:
      sents.append(str(sent))
  return sents

データセットの指定範囲から文のリストを抽出する関数です。

def build_sents(head=0, tail=None):
  nlp = spacy.load('ja_ginza')
  sents =[]
  i = 0
  for split in ds.keys():
    if tail is not None and i >= tail:
      break
    for example in ds[split]:
      if i < head:
        i += 1
        continue
      if head > 0 and i == head:
        print("Resume to head[{}] is completed.".format(head), flush=True)
      if i % 10000 == 0 or (i < 100 and i % 10 == 0):
        print("processing example[{}]".format(i), flush=True)
      if i > 0 and i % 10000 == 0:
        print("Reaload nlp...")
        nlp = spacy.load('ja_ginza')
      title, lines = parse_example(example)
      sents.extend(lines2sents(lines, nlp))
      i += 1
      if tail is not None and i >= tail:
        print("The number of processed examples is reached to the specified tail[{}].".format(tail), flush=True)
        break
  return sents, i

64,000 記事毎に整形結果を出力して GCS にコピーしておきます。 途中で切れてしまったら、range() の開始位置を調整して残りを処理すればよいでしょう。

chunk_size = 64000
num_articles = 828236
for head in range(0, num_articles, chunk_size):
  tail = head + chunk_size
  sents, tail = build_sents(head=head, tail=tail)
  fname = "wiki40b_{:06d}_{:06d}.txt".format(head, tail)
  print("Writing {}.".format(fname))
  with open(fname, "w") as f:
    f.write("\n".join(sents)+"\n")
    !gsutil cp {fname} gs://somewhere/MixCSE/wiki40b/

# processing example[0]
# processing example[10]
# processing example[20]
# processing example[30]
# processing example[40]
# processing example[50]
# processing example[60]
# processing example[70]
# processing example[80]
# processing example[90]
# The number of processed examples is reached to the specified tail[64000].
# Writing wiki40b_000000_064000.txt.
# ...

出来上がるとこんな感じです。

!gsutil ls gs://somewhere/MixCSE/wiki40b/
# gs://somewhere/MixCSE/wiki40b/wiki40b_000000_064000.txt
# gs://somewhere/MixCSE/wiki40b/wiki40b_064000_128000.txt
# gs://somewhere/MixCSE/wiki40b/wiki40b_128000_192000.txt
# ...
# gs://somewhere/MixCSE/wiki40b/wiki40b_704000_768000.txt
# gs://somewhere/MixCSE/wiki40b/wiki40b_768000_828236.txt

次に10文字未満と半角英数記号のみの文は削除しつつ1ファイルに結合しました。 ファイルの処理順がアレですが、どうせ最後にシャッフルするので気にしません。

import glob
import re
pattern = re.compile("^[ -~]*$")
with open("wiki40b.txt", "w") as fo:
  for fname in glob.glob("wiki40b_*.txt"):
    print(fname)
    with open(fname, "r") as f:
      lines = f.readlines()
      lines = [line for line in lines if len(line) > 10]
      lines = [line for line in lines if not pattern.match(line)]
      fo.write("".join(lines))
# wiki40b_128000_192000.txt
# wiki40b_448000_512000.txt
# wiki40b_000000_064000.txt
# ...

シャッフルします。

!shuf wiki40b.txt > wiki40b.txt.shuf
!mv wiki40b.txt.shuf wiki40b.txt

ファイルの行数はこうなりました。

!wc -l wiki40b.txt
# 14284437 wiki40b.txt

中身はこんな感じです。文単位で1行のフォーマットですね。

!head -5 wiki40b.txt
# 第1話のみ2月4日の無料放送で先行放送された。
# 2014年初めには、アーセナルFC、パリ・サンジェルマンFC、バイエルン・ミュンヘンなど複数クラブから関心を持たれるようになり、アーセン・ベンゲルは"ネクストセスク・ファブレガス"と評した。
# 2011年以降、腕時計の時差修正機能にさらに二つの方式が登場した。
# 春まで充分乾燥させてから使用するが、耐久性を高めるために使用前に燻したりする場合もある。
# また競技会場の一部が1996年開催のひろしま国体にも使われた。

できあがりを GCS にコピーしておきます。

!gsutil cp wiki40b.txt gs://somewhere/MixCSE/wiki40b/

文の長さがどの程度なのか 前述の build_sents() でとりあえず 1000 記事処理したファイル(wiki40b_000000_001000.txt)を作って確認して見ましょう。

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

lengthes = []
with open("wiki40b_000000_001000.txt", "r") as f:
  lines = f.readlines()
  lines = [line.strip() for line in lines]
  for line in lines:
    tokens = tokenizer.encode(line)
    lengthes.append(len(tokens))

import numpy as np
import matplotlib.pyplot as plt
plt.hist(lengthes, bins=50) 

hist_of_sentence_length

こんな感じになりました。 TPU の場合、入力シェイプを固定する必要があるので、短いものを最大長までパディングし、長いものは末端を切り捨てることになります。出来るだけ最大長を短くしたいが、末尾が切れてるのも増やしたくないしー、と悩んで今回は最大シーケンス長 72 で学習することにしました。

では、 MixCSE を動かしていきましょう。

4. MixCSE の学習

冒頭でも述べましたが、今回は TPU で動かしてみました。

TPU を使った理由は

  • Colab は GPU が混んでいてすぐにリソース利用上限に達してしまうが、TPU は比較的空いている(ような気がする)。
  • PyTorch / Transformers も TPU で動かせると(最近になって)知ったので、試してみたい。

といったところです。Colab で動かす分には TPU が一番パワーがあるので動かし方を覚えると今後もいいことあるかなと。

セットアップ

新しく notebook を開いて、アクセラレータには TPU を選んでください。

まず、cloud-tpu-client をインストールします。これが torch 1.9 に依存するようなので、 torch もバージョンを 1.9.0 に差し替えます。

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchtext==0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html

学習プロセスの起動に transformers の examples に含まれる xla_spawn.py を使うので、リポジトリを clone しておきます。

!git clone -b v4.18.0 https://github.com/huggingface/transformers.git

transformers と MeCab 関連をインストールします。

!pip install transformers==4.18.0
!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev
!pip install mecab-python3 fugashi ipadic

学習データの読込には datasets を使います。

!pip install datasets==2.2.0

今回、学習した文書ベクトル化モデルの検証には sentence-transformers に含まれるものを流用しました。

!pip install sentence-transformers==2.2.0

最後に MixCSE のコードを取得します。

!git clone https://github.com/BDBC-KG-NLP/MixCSE_AAAI2022
!cd MixCSE_AAAI2022 && git checkout a3c0ee2166b4607526629d7b5592fd2bee509e7b

GCS の認証を通しておきます。

import os
os.environ["USE_AUTH_EPHEM"] = '0'
from google.colab import auth
auth.authenticate_user()

学習データの準備

加工済みの学習データを GCS から取得します。またタイムスタンプが変わると datasets のキャッシュが効かなくなるので固定してしまいます。

!gsutil cp gs://somewhere/MixCSE/wiki40b/wiki40b.txt .
!touch -d "2022-5-14 18:00" wiki40b.txt

検証には 第18回で使った JSNLI データセットを使いました。

!wget https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip&name=JSNLI.zip
!mv *zip* jsnli_1.1.zip
!unzip jsnli_1.1.zip
!ls jsnli_1.1
# dev.tsv  README.md  train_w_filtering.tsv  train_wo_filtering.tsv

学習に使用するスクリプト

学習に使用するスクリプトですが、基本方針は以下のとおりです。

  • モデル(PretrainedModel) のコードは MixCSE のもの(MixCSE_AAAI2022/mixcse/models_mix.py)をそのまま利用。
  • それ以外は transformers/examples/pytorch/language-modeling/run_mlm.py をベースに改修して作成。

上記の run_mlm.py は本連載でも何度か使用した Trainer をベースに記述されています。 Trainer 自体が TPU 実行に対応しているので、「 TPU にするからこうしないといけない!」という箇所はあまりなかった気がします。

もちろん MixCSE にも同様のコード(MixCSE_AAAI2022/train.py)が含まれていて、TPU を意識した処理も記載されているのですが、 試してみると TPU では動きませんでした4

まずは、モデル関連のパラメータクラスです。このクラスは MixCSE_AAAI2022/train.py から拾ってきました。

%%writefile MixCSE_AAAI2022/model_args.py
# coding=utf-8
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import MODEL_FOR_MASKED_LM_MAPPING
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@dataclass
class ModelArguments:

    # Huggingface's original arguments
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # SimCSE's arguments
    temp: float = field(
        default=0.05,
        metadata={
            "help": "Temperature for softmax."
        }
    )
    pooler_type: str = field(
        default="cls",
        metadata={
            "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."
        }
    ) 

    output_size: float = field(
        default=768,
        metadata={
            "help": "Ootput size"
        }
    )

    lambdas: float = field(
        default=0.0,
        metadata={
            "help": "layers of gcnn"
        }
    )

    gcnn_layers: int = field(
        default=4,
        metadata={
            "help": "layers of gcnn"
        }
    )

    hard_negative_weight: float = field(
        default=0,
        metadata={
            "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."
        }
    )

    do_mlm: bool = field(
        default=False,
        metadata={
            "help": "Whether to use MLM auxiliary objective."
        }
    )
    mlm_weight: float = field(
        default=0.1,
        metadata={
            "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."
        }
    )
    mlp_only_train: bool = field(
        default=False,
        metadata={
            "help": "Use MLP only during training"
        }
    )

次にデータ関係のパラメータクラスです。このクラスも MixCSE_AAAI2022/train.py から拾ってきて、以下の項目を追加したくらいです。

  • cache_file_path :
    学習開始時点でデータセットを全部 tokenize するのですが、 Colab だと 2 時間かかります。この際、処理結果が cache されるのですが、ランタイムを再構築すると、生成された cache を GCS に退避して戻してあっても、タイムスタンプなのか何なのか cache を見ずに、また 2 時間コースだったので、明示的にファイル指定するようにしてしまいました。
  • skip_steps :
    こちらもランタイムを再構築して resume したときの話です。デフォルトだと再開時のステップ数に合わせて学習データを先頭から読み飛ばしてくれます。今回のデータでは 0.8 エポック周辺で resume すると、この読み飛ばしに 12 時間ほどかかりそうでした。読み飛ばしだけでランタイムの寿命が尽きてしまいます。 仕方ないのでデフォルトの仕組みでの読み飛ばしを諦めました。とはいえ resume 時に既視のデータを見てしまうのもイヤだったので、別途 skip_steps で指定した分を読み飛ばすことにしました。実際の読み飛ばしの仕組みは後述します。
%%writefile MixCSE_AAAI2022/data_training_args.py
# coding=utf-8
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import MODEL_FOR_MASKED_LM_MAPPING
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Huggingface's original arguments. 
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    eval_path: Optional[str] = field(
        default=None, metadata={"help": "Dataset for evaluation."}
    )
    cache_file_path: Optional[str] = field(
        default=None, metadata={"help": "Cache file path to tokenized dataset."}
    )
    skip_steps: Optional[int] = field(
        default=None, metadata={"help": "number of skipping steps when ignore_data_skip is True. Note: world_size should be 1."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # SimCSE's arguments
    train_file: Optional[str] = field(
        default=None, 
        metadata={"help": "The training data file (.txt or .csv)."}
    )
    max_seq_length: Optional[int] = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    mlm_probability: float = field(
        default=0.15, 
        metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."

次は Trainer をカスタマイズしたクラスです。

%%writefile MixCSE_AAAI2022/cl_trainer.py
# coding=utf-8
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn, Tensor
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled

class CLTrainer(Trainer):

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:

        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            scaler = self.scaler if self.do_grad_scaling else None
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        # Note:
        # copied from MixCSE_AAAI2022/mixcse/trainers.py (from here)
        if self.state.global_step % self.args.eval_steps == 0:
            with torch.no_grad():
                output = model(**inputs,return_dict=True)
                logits = output.logits.detach().cpu()
                pos_index = torch.arange(logits.size(0)).to(logits.device)
                pos_index = pos_index * logits.size(1) + pos_index
                pos_scores = logits.take(pos_index).sum().item()
                mix_num = logits.size(1) / logits.size(0)
                mix_scores = 0
                mix_index = pos_index
                for i in range(0,int(mix_num-1)):
                    mix_index += logits.size(0)
                    mix_scores += logits.take(mix_index).sum().item()
                neg_scores = (logits.sum().item() - pos_scores - mix_scores) / ((logits.size(1)-mix_num)*logits.size(0))

                pos_scores = pos_scores / logits.size(0)
                mix_scores = mix_scores / (logits.size(0) * (mix_num-1)) if mix_num > 1 else 0
                self.log({'step':self.state.global_step,'pos_scores':pos_scores,'neg_scores':neg_scores,'mix_scores':mix_scores})

        # (to here)
        return loss.detach()

    def compute_loss(self, model, inputs, return_outputs=False):
        if not model.training:
            # NOTE:
            # In evaluation, we have no "real" labels, so we return dummy loss.
            # label values are mere sentence ids whitch used to truncate results of padding sentences.
            # in compute_metrics()
            outputs = model(**inputs, sent_emb=True)
            loss = torch.Tensor([0.0]) # dummy value.
            return (loss, outputs) if return_outputs else loss
        return super().compute_loss(model, inputs, return_outputs)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if hasattr(self, '_sequence_sampler'):
            return self._sequence_sampler
        else:
            return super()._get_train_sampler()

改修した理由ですが、

  • training_step():
    pos_scores, neg_scores, mix_scores のログ出力を足しています。元ネタは MixCSE_AAAI2022/mixcse/trainers.py から拾ってきました。それぞれ、anchor と positive、(random) negative, mix negative との類似度の平均です。以下に示す論文の 図 4 (a) の元ネタがこれかと思います。

pos_neg_mix_scores

  • compute_loss():
    今回検証には sentence-transformers の ParaphraseMiningEvaluator を使いました。サンプル毎のラベルはありませんから、 loss の値は 0.0 のダミー値を返しています。

  • _get_train_sampler():
    これは前述の読み飛ばしと関連する話です。 今回のコードは datasets.load_dataset() でテキストファイルを読み込んでおり、 map-style の Dataset になります5。このタイプは dataset[idx] としてサンプルにアクセスでき、dataset 中のサンプルを読み出す順序は Sampler によって決定されます。Sampler はサンプルを読み出す順序を示すインデックス系列を生成するクラスです。後述しますが、今回は resume 時の読み飛ばし処理を入れた Sampler を self._sequence_sampler に仕込んだので、その対応ロジックです。

次は JSNLI のデータを読み込んで ParaphraseMiningEvaluator に投入できる形に整形する関数です。

%%writefile MixCSE_AAAI2022/load_data_for_parafrace_mining.py
# coding=utf-8

# NOTE:
# Copied from https://www.ogis-ri.co.jp/otc/hiroba/technical/similar-document-search/part18.html
def load_data_for_parafrace_mining(filename):
    sentences_map = {} # id -> sent
    sentences_reverse_map = {} # sent -> id
    duplicates_list = [] # (id1, id2)

    def register(sent):
        if sent not in sentences_reverse_map:
            id = str(len(sentences_reverse_map))
            sentences_reverse_map[sent] = id
            sentences_map[id] = sent
            return id
        else:
            return sentences_reverse_map[sent]

    with open(filename, "r") as f:
        lines = f.readlines()
        lines = [line.strip().split("\t") for line in lines]
        rows = [[line[0], line[1].replace(" ", ""), line[2].replace(" ", "")] for line in lines]
        for row in rows:
            label = row[0] 
            sent1 = row[1]
            sent2 = row[2]
            ids = [register(sent) for sent in [sent1, sent2]]
            if label == "entailment":
                duplicates_list.append(tuple(ids))
    return sentences_map, duplicates_list

続いて問題の Sampler ですね。

%%writefile MixCSE_AAAI2022/seq_sampler.py
# coding=utf-8

from torch.utils.data.sampler import Sampler
from typing import Iterator, Optional, Sized

class SequenceSampler(Sampler[int]):
    def __init__(self, sequence: Sized, skip_steps=None, batch_size=None) -> None:
        if skip_steps and batch_size:
            skip_examples = skip_steps * batch_size
            skip_examples = skip_examples % len(sequence)   
            self.sequence = sequence[skip_examples:] + sequence[:skip_examples]
        else:
            self.sequence = sequence

    def __iter__(self) -> Iterator[int]: 
            return iter(self.sequence)

    def __len__(self) -> int:
        return len(self.sequence)

処理としては、 1 epoch 分の読み出し順を Sampler の外で決定しておいて、skip_steps * batch_size 分を先頭から切り落として末尾につけているだけです。 Colab の TPU は 8 コアなので 8 並行で動かしたいところですが、その場合はもっといろいろしないといけないかと思います。 今回は後述する諸般の事情で 1 コアのみで実行、学習量も 1 epoch ということにしたので、これで問題ないでしょう。

最後に main() 関数です。長いですね。。。一応、元ネタになった run_mlm.py から乖離が大きい部分は補足説明をしますが、興味なければ読み飛ばしちゃってください。

%%writefile MixCSE_AAAI2022/run_mixcse.py
# coding=utf-8

# Note: This script is a modified version of https://github.com/huggingface/transformers/blob/v4.18.0/examples/pytorch/language-modeling/run_mlm.py

import logging
import math
import os
import sys
from dataclasses import dataclass, field

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn, Tensor
import numpy as np
from numpy import ndarray

from transformers.utils import is_sagemaker_mp_enabled, is_apex_available
if is_apex_available():
    from apex import amp

from datasets import load_dataset, Dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    default_data_collator,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from mixcse.models_mix import BertForCL

from sentence_transformers.evaluation import ParaphraseMiningEvaluator

logger = logging.getLogger(__name__)

from model_args import ModelArguments
from data_training_args import DataTrainingArguments
from cl_trainer import CLTrainer
from load_data_for_parafrace_mining import load_data_for_parafrace_mining
from seq_sampler import SequenceSampler
import random

def main():

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    set_seed(training_args.seed)

    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension, data_files=data_files)

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
       model = BertForCL.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                model_args=model_args
               )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # Prepare features
    column_names = datasets["train"].column_names
    sent2_cname = None
    # Unsupervised datasets
    sent0_cname = column_names[0]
    sent1_cname = column_names[0]

    def prepare_features(examples):
        total = len(examples[sent0_cname])
        for idx in range(total):
            if examples[sent0_cname][idx] is None:
                examples[sent0_cname][idx] = " "
            if examples[sent1_cname][idx] is None:
                examples[sent1_cname][idx] = " "
        sentences = examples[sent0_cname] + examples[sent1_cname]
        sent_features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length,
            truncation=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )
        features = {}
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]
        return features

    if training_args.do_train:
        train_dataset = datasets["train"].map(
            prepare_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            cache_file_name=data_args.cache_file_path,
        )

    # NOTE: Add evalation code using sentence-transformers's ParaphraseMiningEvaluator.(from here.)
    def prepare_eval_features(examples):
        sentences = examples["sentences"]
        features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length, 
            truncation=True,
            padding="max_length" if  data_args.pad_to_max_length else False,
        )
        # Add sentence id as labels. 
        # labels are used to truncate padding and reorder in compute_metrics()
        features["labels"] = [int(i) for i in examples["ids"]]
        return features

    # This is a dummy class to bypass prediction in ParaphraseMiningEvaluator.
    # embeddings are predicted in Trainer's evaluation.
    class ModelAdapter:
        def __init__(self, embeddings):
            self.embeddings = embeddings

        def encode(self, sentences: Union[str, List[str]],
                show_progress_bar = False,
                batch_size: int = 32,
                convert_to_numpy: bool = True,
                convert_to_tensor: bool = False,
                device: str = None,
                normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
            return self.embeddings

    if training_args.do_eval:
        sentences_map, duplicates_list = load_data_for_parafrace_mining(data_args.eval_path)
        evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list)

        # Pad empty sentences to align size of all batches due to avoid XLA recompilation.
        pad_len = training_args.per_device_eval_batch_size - len(evaluator.sentences) % training_args.per_device_eval_batch_size 
        sentences = evaluator.sentences + [""] * pad_len
        ids = evaluator.ids + ["-1"] * pad_len

        eval_dataset = Dataset.from_dict({"sentences": sentences, "ids": ids})
        eval_dataset = eval_dataset.map(prepare_eval_features, batched=True, remove_columns=["sentences", "ids"])

        def compute_metrics(eval_prediction):
            ids = eval_prediction.label_ids
            embeddings = eval_prediction.predictions
            # type of ids : <class 'numpy.ndarray'>, len=5824
            # type of embeddings : <class 'tuple'>, len=2
            #   shape of embeddings[0] : (5824, 32, 768), this is 'last_hidden_state'.
            #   shape of embeddings[1] : (5824, 768), this is 'pooler_output'
            embeddings = embeddings[1]

            # remove padding and reorder. 
            ids = [id for id in ids if id > -1]
            embeddings = [embeddings[id]  for id in np.argsort(ids)]

            model_adapter = ModelAdapter(embeddings)
            ap = evaluator(model_adapter,  output_path=training_args.output_dir)
            return {'AP': ap}
    # (to here.)

    # Initialize our Trainer
    trainer = CLTrainer(
    #trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if training_args.do_eval else None,
    )
    #trainer.model_args = model_args

    # Set sequantail sampler to avoid too slow skipping when continuing training from checkpoint.
    # Please set --ignore_data_skip flag and --skip_steps. but num_cores should be 1!
    if training_args.ignore_data_skip and data_args.skip_steps is not None:
        if training_args.world_size == 1 :
            num_samples = len(train_dataset)
            logger.info("Building shuffled indices for SequenceSampler from num_examples={} and skipping {} steps.".format(num_samples, data_args.skip_steps))
            sequence = [i for i in range(num_samples)]
            random.shuffle(sequence)
            logger.info("Shuffled indices = {}".format(sequence[:5]))
            trainer._sequence_sampler = SequenceSampler(sequence, skip_steps=data_args.skip_steps, batch_size=training_args.per_device_train_batch_size)
            logger.info("Skipped shuffled indices = {}".format(trainer._sequence_sampler.sequence[:5]))
        else:
            logger.warn("--skip_steps={} ignored. When using '--skip-steps', world_size should be 1.").format(data_args.skip_steps)


    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        results = eval_output

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_ap.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

    return results

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()    

上記のコードでトリッキーなところを部分的に補足

まず、以下の ModelAdapter は sentence-transfomers の ParaphraseMiningEvaluator と Trainer の橋渡しをする SentenceTransformer 互換のダミークラスです。

    # This is a dummy class to bypass prediction in ParaphraseMiningEvaluator.
    # embeddings are predicted in Trainer's evaluation.
    class ModelAdapter:
        def __init__(self, embeddings):
            self.embeddings = embeddings

        def encode(self, sentences: Union[str, List[str]],
                show_progress_bar = False,
                batch_size: int = 32,
                convert_to_numpy: bool = True,
                convert_to_tensor: bool = False,
                device: str = None,
                normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
            return self.embeddings

今回、精度計算に流用する sentence-transformers の ParaphraseMiningEvaluator は

  • 引数で受け取った SentenceTransformer を用いて Evaluator 自身が保持するデータをベクトル化、その後に精度計算する。

という仕様です。ですが、 Trainer で学習する際は

  • 計算済みの推論結果とラベルが compute_metrics() に渡されて精度計算する。

という流れです。つまり、compute_metrics() での精度計算に ParaphraseMiningEvaluator のロジックを使いたいけど、推論(文章のベクトル化)は別のところで終わってる。 という状態になるので、上記の ModelAdapter を使って ParaphraseMiningEvaluator 内部でのベクトル化をパススルーしています。

次に以下の処理です。

        # Pad empty sentences to align size of all batches due to avoid XLA recompilation.
        pad_len = training_args.per_device_eval_batch_size - len(evaluator.sentences) % training_args.per_device_eval_batch_size 
        sentences = evaluator.sentences + [""] * pad_len
        ids = evaluator.ids + ["-1"] * pad_len

ここは検証セットのサンプル数がバッチサイズで割り切れるように詰め物を入れています。 TPU で実行するときは入力のシェイプが固定でないとXLA の再コンパイルが必要となりかなり重いので、それを回避するためのコードです。 詰め物は後で除去できるよう id に -1 を設定してます。

最後に compute_metrics() です。

        def compute_metrics(eval_prediction):
            ids = eval_prediction.label_ids
            embeddings = eval_prediction.predictions
            # type of ids : <class 'numpy.ndarray'>, len=5824
            # type of embeddings : <class 'tuple'>, len=2
            #   shape of embeddings[0] : (5824, 32, 768), this is 'last_hidden_state'.
            #   shape of embeddings[1] : (5824, 768), this is 'pooler_output'
            embeddings = embeddings[1]

            # remove padding and reorder. 
            ids = [id for id in ids if id > -1]
            embeddings = [embeddings[id]  for id in np.argsort(ids)]

            model_adapter = ModelAdapter(embeddings)
            ap = evaluator(model_adapter,  output_path=training_args.output_dir)
            return {'AP': ap}

eval_prediction.label_ids にはサンプルの id が格納されるように仕込んでいます。 eval_prediction.predictions は長さ 2 の tuple になっていて二つ目が MixCSE で埋め込まれた検証データの文章ベクトル(embeddings)です。

ここで、id が -1 のものを目印に embeddings から前述の詰め物を除去、念のため id の昇順でソートした上で、 embeddings を ModelAdapter ラップして ParaphraseMiningEvaluatorに投入、計算された Average Precision をdict にして返します。

あとは、resume 時の読み飛ばしの処理ですね。デフォルトの読み飛ばしを無効にする ignore_data_skip が有効で skip-steps が設定され、1 コアでの実行時のみ、学習データからサンプルを読み出す順序(sequence)を作って SequenceSampler にセット、それを Trainer にねじ込んでいます。 seed の値を変えなければ random.shuffle(sequnece) の結果は毎回同じはず。

    # Set sequantail sampler to avoid too slow skipping when continuing training from checkpoint.
    # Please set --ignore_data_skip flag and --skip_steps. but num_cores should be 1!
    if training_args.ignore_data_skip and data_args.skip_steps is not None:
        if training_args.world_size == 1 :
            num_samples = len(train_dataset)
            logger.info("Building shuffled indices for SequenceSampler from num_examples={} and skipping {} steps.".format(num_samples, data_args.skip_steps))
            sequence = [i for i in range(num_samples)]
            random.shuffle(sequence)
            logger.info("Shuffled indices = {}".format(sequence[:5]))
            trainer._sequence_sampler = SequenceSampler(sequence, skip_steps=data_args.skip_steps, batch_size=training_args.per_device_train_batch_size)
            logger.info("Skipped shuffled indices = {}".format(trainer._sequence_sampler.sequence[:5]))
        else:
            logger.warn("--skip_steps={} ignored. When using '--skip-steps', world_size should be 1.").format(data_args.skip_steps)

それではいよいよ学習の実行です。

学習の実行

まずは、 TPU がらみの設定をします。 環境変数 XLA_USE_BF16 を設定することでTPU での計算を bfloat16 で行うことが出来ます6

import os
assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
os.environ["TPU_IP_ADDRESS"] = os.environ['COLAB_TPU_ADDR'].replace(":8470","")
os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;" + os.environ['COLAB_TPU_ADDR']
os.environ["XLA_USE_BF16"] = "1"

今回、学習からの出力先を /content/mixcse_bert_wiki40b に設定するので、 Huggingface 関連のキャッシュも同じディレクトリにまとめました。

os.environ['TRANSFORMERS_CACHE'] = '/content/mixcse_bert_wiki40b/cache'
os.environ['HF_DATASETS_CACHE'] = '/content/mixcse_bert_wiki40b/datasets_cache'

以下のようにして学習を実行します。

!export PYTHONPATH=${PYTHONPATH}:.:./MixCSE_AAAI2022 && \
  python ./transformers/examples/pytorch/xla_spawn.py \
    --num_cores 1 \
    ./MixCSE_AAAI2022/run_mixcse.py \
    --model_name_or_path 'cl-tohoku/bert-base-japanese-whole-word-masking' \
    --train_file 'wiki40b.txt' \
    --eval_path 'jsnli_1.1/dev.tsv' \
    --output_dir './mixcse_bert_wiki40b' \
    --cache_file_path './mixcse_bert_wiki40b/tokenized_train_dataset_cache' \
    --pooler_type 'cls' \
    --max_seq_length 72 \
    --pad_to_max_length True \
    --temp 0.05 \
    --lambdas 0.2 \
    --do_train \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --learning_rate 3e-5 \
    --eval_steps 4000 \
    --save_steps 4000 \
    --save_total_limit 6 \
    --evaluation_strategy 'steps' \
    --use_fast_tokenizer False \
    --label_names 'labels' \
    --ignore_data_skip True \
    --skip_steps 0 \
    --seed 42 \
    --overwrite_output_dir \

# WARNING:root:Waiting for TPU to be start up with version pytorch-1.9...
# WARNING:root:Waiting for TPU to be start up with version pytorch-1.9...
# WARNING:root:TPU has started up successfully with version pytorch-1.9
# WARNING:run_mixcse:Process rank: -1, device: xla:1, n_gpu: 0distributed training: False, 16-bits training: False
# ...
# [INFO|trainer.py:1290] 2022-05-16 07:47:40,085 >> ***** Running training *****
# [INFO|trainer.py:1291] 2022-05-16 07:47:40,085 >>   Num examples = 14284437
# [INFO|trainer.py:1292] 2022-05-16 07:47:40,085 >>   Num Epochs = 1
# [INFO|trainer.py:1293] 2022-05-16 07:47:40,085 >>   Instantaneous batch size per device = 64
# [INFO|trainer.py:1294] 2022-05-16 07:47:40,085 >>   Total train batch size (w. parallel, distributed & accumulation) = 64
# [INFO|trainer.py:1295] 2022-05-16 07:47:40,085 >>   Gradient Accumulation steps = 1
# [INFO|trainer.py:1296] 2022-05-16 07:47:40,085 >>   Total optimization steps = 223195
# {'step': 0, 'pos_scores': 18.673828125, 'neg_scores': 13.552824797453704, 'mix_scores': 15.11181640625, 'epoch': 0}
# {'loss': 0.013, 'learning_rate': 2.9932794193418313e-05, 'epoch': 0.0}
# {'loss': 0.0008, 'learning_rate': 2.986558838683662e-05, 'epoch': 0.0}
# {'loss': 0.0005, 'learning_rate': 2.9798382580254937e-05, 'epoch': 0.01}
# {'loss': 0.0004, 'learning_rate': 2.9731176773673245e-05, 'epoch': 0.01}
# ...

起動スクリプトには transformers に含まれる xla_spawn.py をそのまま利用しました。一部の引数について補足します。

  • num_cores:
    学習に使用するコア数です。 1 もしくは 8 が設定可能で、 Colab の TPU は 8 コアなので 8 を指定したいところですが、今回は以下の理由で 1 にしました。
    • 論文1にはバッチサイズを 64 としたと記載があり、その設定に合わせました。計算に bfloat16 を使い、バッチサイズ 64 であれば 1 コアに収まります。
    • MixCSE のコードには分散実行された場合に他のプロセスでの計算結果を集め実質的なバッチサイズを大きくする処理7が記載されているのですが、 xla_span.py で並行起動した場合にはこのブロックが実行されませんでした。
    • xla_span.py で並行起動する場合は前述の SequenceSampler のところももう少し工夫が必要になりそう。今回はバッチサイズ 64 と決めた時点で「 1 コアでいいや」となったので、細かく調べる気力がなくなりました。
    • 8 コアの場合、学習途中のチェックポイントから resume しようとすると、おそらくは OOM で落ちる8
  • model_name_or_path:
    この連載でいつも使っている cl-tohoku/bert-base-japanese-whole-word-masking を使用しました。
  • output_dir:
    学習中に生成されるチェックポイントの出力先です。ランタイムの寿命が来ると全部消えてしまうので、予め Google Drive をマウントしておくなり、適当な頃合いで止めて GCS に退避するなりして下さい。
  • cache_file_path:
    トークナイズしたデータセットのキャッシュファイルの名称を明示的に与えています。
  • pad_to_max_length:
    TPU で動かす場合はインプットのシェイプがそろっていないと XLA コンパイルを誘発するので True を設定しました。
  • temp:
    Contrastive Loss の temperature ですね。論文1に合わせて 0.05 にしました。
  • lambda:
    mix negative を作るときの λ です。論文1に合わせて 0.2 にしました。
  • label_names:
    evaluation 時のラベル(実際は文のID)が格納されている列名の指定です。指定しなくても動いたかも。
  • ignore_data_skip:
    Trainer のデフォルトの読み飛ばし処理を無効化するので True にしています。
  • skip_steps:
    今回用意した SequenceSampler で読み飛ばすステップ数です。初回起動なので 0 とします。
  • overwrite_output_dir:
    出力先(output_dir)が存在する場合に上書きして良いことを示す指定です。今回の main.py ではこのパラメータが未設定の場合、出力先に含まれる最新のチェックポイントから resume するようなコードになっています。

学習途中から再開する場合は、./mixcse_bert_wiki40b に退避した内容を戻してから、以下のように実行します(200000 ステップから再開する例です)。

    ...
    --output_dir './mixcse_bert_wiki40b' \
    ...
    --skip_steps 200000 \
    --seed 42 \
    #--overwrite_output_dir \

上記の実行ログの以下の行が学習中の anchor と positive, (random) negative, mix negative の類似度になります。

# {'step': 0, 'pos_scores': 18.673828125, 'neg_scores': 13.552824797453704, 'mix_scores': 15.11181640625, 'epoch': 0}

起動後のログから拾ってプロットすると以下のようになりました。

pos_neg_mix_scores

論文の図4 (a)で示されていたのと同じような値に落ち着いているので、良い感じに動いているのだと思います。

検証結果の確認

Paraphrase Mining での検証結果も見てみましょう。 ちなみに、第18回で JSNLI データセットを使って Sentence BERT を学習した際は AP = 12.87、F1(best) = 23.61 でした。

evaluation_result

あれ。。。全然ダメですね。。。AP = 1.69、F1(best) = 7.00 になってしまいました。

でも良く考えたら、検証に使った JSNLI データセットは人手で entailment(含意), contradiction(矛盾), neutral (中立) のラベルを付けたものなので、 「単語の選択は似ているけれど意味合い的には矛盾する」 ような高度な認識が必要になりそうです。

MixCSE のような dropout の間引き具合で作った positive , ランダム抽出した negative , それらを混ぜた mix negative での学習では上記のような認識は感覚的にも厳しいような気がしますね。

別のデータセットで推論させてみて、もう少し見てみることにしましょう。

推論の実行

まずは、 MixCSE を Python の検索パスに入れて、必要なクラスをインポートします。

import sys
sys.path.append("./MixCSE_AAAI2022")
from mixcse.models_mix import BertForCL
from run_mixcse import ModelArguments

つぎに torch と torch_xla をインポートして、 xm.xla_device() で取得したのが TPU ですね。

import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()

モデルのパラメータクラスに必要な設定をして、

model_args=ModelArguments()
model_args.use_fast_tokenizer = False
model_args.lambdas = 0.2

学習で生成されたチェックポイントをロードして、TPU に移します。

model = BertForCL.from_pretrained("./mixcse_bert_wiki40b/checkpoint-220000", model_args=model_args)
model.to(device)

トークナイザもロードします。

from transformers import AutoTokenizer
max_length = 72 
model_name_or_path = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

とりえあえず、推論してみましょう。推論時は sent_emb=True を設定する必要があります。

sentences = ["吾輩は猫である。"]
batch = tokenizer.batch_encode_plus(sentences, return_tensors='pt', padding='max_length', max_length = max_length)
for k in batch:
  batch[k] = batch[k].to(device)
with torch.no_grad():
    outputs = model(**batch, output_hidden_states=True, return_dict=True, sent_emb=True)
pooler_output = outputs['pooler_output'].cpu()
pooler_output.shape
# torch.Size([1, 768])

大丈夫そうですね。評価タスクとしては sentence-transformers の InformationRetrievalEvaluator を使います。

今度は検証実行時に文章をベクトル化するので、 先程ロードしたモデルに StentenceTransformer 互換の encode() メソッドを追加して、 インタフェースを合わせてしまいましょう。

TPU で動かす前提としたので、シーケンスは最大長にパディングし、最終バッチの端数にはダミーを詰めてシェイプを合わせています。

from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional
import numpy as np
from numpy import ndarray
from torch import nn, Tensor, device

 def encode(self, sentences: Union[str, List[str]],
               show_progress_bar = False,
               batch_size: int = 32,
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:

    self.eval()
    all_embeddings = []

    for start_index in range(0, len(sentences), batch_size):
        batch = sentences[start_index:start_index+batch_size]
        batch_len = len(batch)
        pad_len = batch_size - batch_len
        padding = [""] * pad_len
        batch = batch + padding

        batch = tokenizer.batch_encode_plus(
                batch,
                return_tensors='pt',
                truncation=True,
                padding='max_length',
                max_length = max_length)

        for k in batch:
            batch[k] = batch[k].to(model.device)

        with torch.no_grad():
            outputs = self.forward(**batch, output_hidden_states=True, return_dict=True, sent_emb=True)
        pooler_output = outputs['pooler_output'].cpu()

        embeddings = pooler_output
        embeddings = embeddings[:batch_len]
        all_embeddings.append(embeddings)

    return torch.cat(all_embeddings, axis=0)

import types
model.encode = types.MethodType(encode, model)

こちらも動作確認します。

sentences = ["吾輩は猫である。", "本日は晴天なり。", "どうもありがとうございました。"]

all_embeddings = model.encode(sentences, batch_size = 2)
all_embeddings.shape
# torch.Size([3, 768])

大丈夫そうですね。

データですが、本連載の第1回で用いた海外小説の翻訳者違いデータセットを用います9。 検索対象の文章集合 docs と それに対する検索クエリー集合 queries があり、文章集合中には各クエリーに対応する文章が 1 つずつ存在します。

たぶん、見てもらったほうが分かりやすいですね。queries と docs の件数は同じで。

print(len(queries))
# 918
print(len(docs))
# 918

queries はこんな感じ。

queries[:3]
# ['クリスマスから二日目の朝、私は時候の挨拶をしようと思い、友人のシャーロックホームズ宅を訪問した。',
#  'ホームズは紫の化粧着姿で、ソファにくつろいでいた。',
#  '右手の届く場所にパイプ掛けがあり、近くにはついさっきまで読んでいたらしい朝刊がしわくちゃになって積み上げられていた。']

docs がこんな感じです。

docs[:3]
# ['友人シャーロック・ホームズのもとを、私はクリスマスの二日後に訪れた。時候の挨拶をしようと思ったのだ。',
#  '彼は紫のガウンを着てソファの上でくつろいでいた。',
#  '右手の届くところにパイプ置きがあり、今読んでいるところなのだろう、手元にはぐちゃりと朝刊の山が積まれている。']

InformationRetrievalEvaluator に投入するためにさらに加工します。

 queries = {"q-{:05d}".format(i):q for i, q in enumerate(queries)}
 corpus = {"c-{:05d}".format(i):d for i, d in enumerate(docs)}
 relevant_docs = {"q-{:05d}".format(i):"c-{:05d}".format(i) for i in range(len(queries))}

加工したデータで InformationRetrievalEvaluator を作って、

from sentence_transformers.evaluation import InformationRetrievalEvaluator
ir_evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

先程、ロードしたモデルを投入して検証します。

!mkdir ./ir_eval
ir_evaluator(model,  output_path="./ir_eval") # 値は cos_sim-MAP@100 か dot_score-MAP@100 の大きい方
# 0.12420449861477104

結果は以下のとおりです。

pd.read_csv("./ir_eval/Information-Retrieval_evaluation_results.csv")

ir_evaluation_result

様々なメトリクスが得られますが、cos_sim-MRR@10 に着目すると 86.8 が得られました。第18回の Sentence BERT で計算すると90.6 でしたので 4 pt 弱劣りますが、健闘しているのではないかと思います。

MRR @10 とあるように検索結果の上位10位までを使ってスコア計算しているので、過去記事とは計算方法が少し異なります。計算方法を合わせて再計算し、改めて比較すると以下のようになりました。青部分は第9回からの転載です。

mrr

JSNLI の Sentence BERT には及ばないものの、WMD などの単語ベクトルを使った手法や TF-IDF よりも良いスコアなので、それなりの品質の文章ベクトルを生成できるようです。

5. おわりに

今回は、教師なしの文章ベクトル生成手法である MixCSE の検証を行いました。文体に特徴があるラベルなしコーパスを対象とする場合など、状況とタスクによっては教師あり手法にまさるケースもあるかもしれませんね。次回は抽象型要約の BRIO をご紹介しようかと思います。抽象型要約は以前から扱ってみたいと思っていたタスクなのですが、 T5 でテキスト変換するだけではベタすぎる、かと言って PEGASUS は Tensorflow のカスタム op を使うので面倒そう、ということで今まで手を出せていませんでした。BRIO は学習方法の工夫でモデルの構造には非依存10のようなので、やってみようかなと思いました。


  1. https://www.aaai.org/AAAI22Papers/AAAI-8081.ZhangY.pdf 

  2. anchor と random negative の1点と考えると話がわかりやすいでしょう。 

  3. 論文1の “The necessity of stop gradient for the mixed negatives” の節に記載があります。 

  4. import でエラーになったので TPU で動かした実績ないのでしょう。おそらくは 「 TPU に対応したコードベースに改修したのでその名残が残っている」くらいかと思います。こちらのコード起点に修正しようとすると検証処理の差し替えとか結構な改造になるので run_mlm.py をベースに models_mix.py を組み込んだほうが早いということになりました。 

  5. https://pytorch.org/docs/stable/data.html#dataset-types 

  6. https://pytorch.org/xla/release/1.11/index.html#xla-tensors-and-bfloat16 

  7. https://github.com/BDBC-KG-NLP/MixCSE_AAAI2022/blob/a3c0ee2166b4607526629d7b5592fd2bee509e7b/mixcse/models_mix.py#L172-L198 このブロックが実行されない場合、 num_cores = 8 かつ per_device_train_batch_size = 8 で 8x8 = 64 のような設定はバッチサイズ 8 で計算した 8 つの結果を平均するような処理になり、num_cores = 1 かつ per_device_train_batch_size = 64 とした場合とは別の計算になってしまうと思います。 

  8. MixCSE で試してないのですが、TPU 8 コアで動かしたくて livedoor News コーパスの分類をしました。8 コアで動いたのですが、(10 epoch 回しても 820 秒で終わるので必要ないんですけど)resume させると、8 個のプロセスが同時並行で state_dict をメモリに展開しようとして OOM になりました。Colab の TPU ランタイムはメモリが 12 GB しかないのでちょっと苦しいですね。 

  9. このデータセットはインターネット上で公開されていたコンテンツを筆者が手作業で取得、加工したもので公開はしてません。 

  10. BRIO の論文ではモデルとして PEGASUS と BART を使ってました。 T5 でもいけるかなと。