オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第31回 OpenAI text-embedding-3-large と Cohere Rerank 3 の精度評価
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2024年6月26日

今回は OpenAI text-embedding-3-large と Cohere Rerank 3 の精度評価を行います。OpenAI の既存のテキスト埋込モデルの text-embedding-ada-002 に対し、どの程度精度が向上するか(しないか)を見ていきます。

1. はじめに

早いもので前回の更新から半年たってしまいました。。。

半年と少し前くらいに langchain と Azure OpenAI を使って、 ファイルをアップロードして「要約して」とか「XXについて教えて」とか言うと回答してくれるサンプルを作って上司に見せたら、 いつの間にか話が膨らんで「アレ、本気で使うから。」と。

本気で使うとなると、認証やらデバッグを容易にするためのログ出力やら、なんだかんだとバタバタしてたのが昨年末から今年初めくらい。 その後はアレだコレだといろいろあって気が付けば今です!

で、上述のシステムでも内部では RAG を使った QA を行う訳ですが、「もっと精度を上げるためにはどうしたらよいだろう」ということで、 今回の話につながります。

もちろん個々の具体例に対しては、それぞれ効果のある方策を考える必要があるわけですが、 全体的な底上げという観点ではモデルの更新が一番手っ取り早い方法と言えるでしょう。

そんな訳で今回は OpenAI の text-embedding-3-large と Cohere の Rerank 3 を使って text-embedding-ada-002 に対して精度向上するか見ていきましょう。

3. 今回使用するモデル

さて、今回使用するモデルについて簡単にご紹介しておきましょう。

OpenAI text-embedding-3-large

OpenAI から提供されているテキスト埋込モデルの最新版1で最大3072次元のベクトルを生成できます。

RAG で使用する場合は、生成されたベクトルをベクトル検索 DB に格納して検索することになりますが、 使用するベクトル検索 DB によってはフルサイズの 3072 次元に対応してないものもあるかと思います。 text-embedding-3-large には生成するベクトルの次元数を削減するオプションもあるので、 必要に応じて次元数を削減し、精度、計算量、格納サイズのバランスを取ることが可能なモデルになっています。

今回は 1536, 3072 の 2 パターンの次元数で比較してみました。

Cohere Rerank 3

Cohere から提供されているリランキングモデルの最新版2です。

リランキングモデルというのは、名前のとおりで検索結果の並べ替えをしてくれるモデルです。 具体的に言うと

  1. テキスト埋込モデルを用いたベクトル検索で上位 30 件を取得
  2. 抽出した 30 件から、リランキングモデルを使って上位 10 件を抽出する

という具合です。

テキスト埋込モデルのベクトル検索では単純に「テキストの持つ意味合いが似ているか」という抽出になりますが、 リランキングモデルでは、クエリーと並べ替え対象の文書群のペアを投入することで、 「クエリーに対する応答としてふさわしい文書はどれか」という、より本来の目的に即した抽出が可能になります。

ついでに RankGPT3 との比較もしたいと思います。

同様のもので、オープン系のモデルだと ColBERTv24 あたりでしょうか。 個人的には日本語に対応しており、高精度、かつ LLM での並べ替えより軽量なモデルを欲していたので、Rerank 3 に対する期待が膨らみます。

3. 実験

今回は @warper(Y F) さんが公開されている記事5のコードをほぼそのまま使わせていただきました。

  • 以下に記載するコードは基本的には元記事5からの転載です。一部好みで修正した箇所、筆者が記述した内容に関しては都度注釈を入れてます。

データは尼崎市のFAQデータセット6です。元記事の内容に従って回答データを見てみましたが、単語間に “ ” が入っているようです。

answers_df.head()
#   ID  Answer
#0  0   ■ 市 に は 乳幼児 ( 乳児 ) と その 親 が 集う 場 と して 、 次 の よう...
#1  1   ■ 地域 総合 センター 今 北 に は 、 十分な 駐車 場 が ございませ ん ので 、...
#2  2   ■ 市外 から 市 内 に 引越 し した とき の 届出 です 。 ■ 届出 期限 。 ・...
#3  3   ■ 尼崎 市立 クリーン センター で は 、 尼崎 市 内 から 発生 した 事業 系 一...
#4  4   ■ 巡回 健診 の 会場 の 1 つ と して 、 ハーティ 21 も あり ます 。 「 ...

これが気になったので、ベクトル化する手前で以下のようにして “ ” を除去しました。 よって元記事とのスコアの直接比較はできなくなっているので注意してください。

docs = [Document(page_content=answer.replace(" ", ""), metadata={"id": str(i)}) for i, answer in enumerate(answers)]

ベクトル検索 DB の生成関数です。ほぼ元記事のままですが、モデル違い、次元違いで何パターンか試すので関数化しました。

!mkdir -p ./cache/

def create_chromadb(model, docs, dims=None):
  if dims:
    CACHE_DIR = f"./cache/{model}-{dims}/"
    CHROMA_PATH = f"./chroma_{model}-{dims}"
    embeddings = OpenAIEmbeddings(model=model, dimensions=dims)
    store = LocalFileStore(CACHE_DIR)
    cached_embedder = CacheBackedEmbeddings.from_bytes_store(embeddings, store, namespace=f"{embeddings.model}-{dims}")
  else:
    CACHE_DIR = f"./cache/{model}/"
    CHROMA_PATH = f"./chroma_{model}"
    embeddings = OpenAIEmbeddings(model=model)
    store = LocalFileStore(CACHE_DIR)
    cached_embedder = CacheBackedEmbeddings.from_bytes_store(embeddings, store, namespace=embeddings.model)
  if not Path(CHROMA_PATH).exists():
    print("building chromadb from the docs....")
    chromadb = Chroma.from_documents(docs, cached_embedder, persist_directory=CHROMA_PATH)
  else:
    print("loading chromadb from disk cache....")
    chromadb = Chroma(persist_directory=CHROMA_PATH, embedding_function=cached_embedder)

  return chromadb

Rerank 3 ですが、諸事情により今回は AWS Marketplace の Cohere Rerank 3 Model - Multilingual7 を使用しました。 VMのコストは必要ですが、ソフトウェアの費用は 7 日間のフリートライアル期間があります!(以下、Rerank 3 に関連したコードは私が記述した内容になります。)

Rerank 3 を SageMaker にデプロイするので cohere_aws を使用しました。

from cohere_aws import Client
import boto3
import sagemaker as sage

以下の文字列は AWS Marketplace でサブスクライブした後の設定ページで確認できます。

cohere_package = "cohere-rerank-multilingual-v3--xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

今回は東京リージョンにデプロイしました。手順は大体以下のような感じです8

model_package_map = {
  "ap-northeast-1" : f"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}"
}

region = 'ap-northeast-1'
model_package_arn = model_package_map[region]

co = Client(region_name=region)

co.connect_to_endpoint(endpoint_name="cohere-rerank-multilingual-v3")

co.create_endpoint(
    arn=model_package_arn, 
    endpoint_name="cohere-rerank-multilingual-v3", 
    instance_type="ml.g5.2xlarge", 
    n_instances=1,
    role="SageMaker-XXXX",
)

デプロイが出来たら、動作確認です。

docs = [
 '津波の可能性がある場合はすみやかに高台に避難し、消して海岸には近づかないでください。',
 '日経平均株価は下落の傾向。今週はより一層円安が進むと予想されます。',
 'パスワードを忘れたときは画面下部のリンクをクリックしてメールアドレスを入力してください。再設定手順を記載したメールを送信します。'
]

response = co.rerank(documents=docs, query='パスワードを忘れた場合の再設定方法を教えて', top_n=2)

response
# ※見やすいように手で改行を入れてます。
#[
#  RerankResult<
#    text: パスワードを忘れたときは画面下部のリンクをクリックしてメールアドレスを入力してください。再設定手順を記載したメールを送信します。,
#    index: 2,
#    relevance_score: 0.9904406
#  >,
#  RerankResult<
#    text: 日経平均株価は下落の傾向。今週はより一層円安が進むと予想されます。,
#    index: 1,
#    relevance_score: 3.9751925e-05
#  >
#]

とりあえず、正しく呼び出せているようです。

ではここからようやく実験開始です。モデル名や次元数違いの繰り返しパターンは省略します。

text-embedding-3-large 3072 次元

text-embedding-3-large の 3072 次元だとこんな感じ。各クエリーに対し上位10件を取得します。

model="text-embedding-3-large"
chromadb = create_chromadb(model, docs, dims=3072)
retriever = chromadb.as_retriever(search_kwargs={"k": 10})
run_dict = run_test(retriever)
qrels = Qrels(qrels_dict)
run = Run(run_dict)
evaluate(qrels, run, ["hit_rate@5", "mrr@5", "ndcg@5"])
# {'hit_rate@5': 0.7691326530612245,
#  'mrr@5': 0.6290391156462585,
#  'ndcg@5': 0.5106589647539094}

text-embedding-3-large 3072 次元 + Rerank 3

次は text-embedding-3-large の 3072 に Rerank 3 を適用します。

langchain から Rerank 3 を使用するには CohereRerank クラスを使えばよいのですが、 どうも cohere と cohere-aws の Client は微妙に仕様が違うようだったので、CohereRerank をコピーして、 SageMaker 用のクラスを作りました。

確か self.client.rerank() するところで cohere-aws には model パラメータが定義されておらず、 そのあたりの修正をしたと思います。

from langchain.retrievers import ContextualCompressionRetriever

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import cohere
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

class SageMakerCohereRerank(BaseDocumentCompressor):

    client: Any = None
    top_n: Optional[int] = 3
    user_agent: str = "langchain:partner"

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        return values

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        rank_fields: Optional[Sequence[str]] = None,
        model: Optional[str] = None,
        top_n: Optional[int] = -1,
        max_chunks_per_doc: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        if len(documents) == 0:  # to avoid empty api call
            return []
        docs = [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]
        #print(f"docs={docs}")
        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
        results = self.client.rerank(
            query=query,
            documents=docs,
            top_n=top_n,
            rank_fields=rank_fields,
            max_chunks_per_doc=max_chunks_per_doc,
        )
        result_dicts = []
        for res in results.results:
            result_dicts.append(
                {"index": res.index, "relevance_score": res.relevance_score}
            )
        return result_dicts

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        compressed = []
        for res in self.rerank(documents, query):
            doc = documents[res["index"]]
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed

さて、元記事は上位 10 件をベクトル検索して、それをリランクして最終結果の上位 5 件で評価していますが、 それだと、リランクで 6 ~ 10 位の結果が 5 位以内に滑りこんだ場合のみスコアの押上げ効果がある形になります。

今回は、よりリランキングの効果を得るためにベクトル検索で上位 30 件を取得、それをリランキングして上位 10 件を抽出することにしました。 この条件なら、ベクトル検索では 27 位だったものが、リランキングで 5 位以内に飛び込んでくるケースがあるので、 より効果が期待できるかと思います。

こんな感じで reranker を作ります。リランクで上位 10 件を抽出するので top_n=10 としています。

reranker = SageMakerCohereRerank(client=co, top_n=10,)

検索処理と評価のコードは、{"k": 30}の指定が入っている点、ContextualCompressionRetriever を挟んでいる点を除いて同じですね。

model="text-embedding-3-large"
chromadb = create_chromadb(model, docs, dims=3072)
retriever = chromadb.as_retriever(search_kwargs={"k": 30})
retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
run_dict = run_test(retriever)
qrels = Qrels(qrels_dict)
run = Run(run_dict)
evaluate(qrels, run, ["hit_rate@5", "mrr@5", "ndcg@5"])
# {'hit_rate@5': 0.7244897959183674,
#  'mrr@5': 0.5769345238095237,
#  'ndcg@5': 0.45607721674578217}

text-embedding-3-large 3072 次元 + RankGPT

最後に RankGPT です。LLM には Azure OpenAI の gpt-4-32k を使いました。 LLM でのリランキングは処理時間がかかるので、データセットの先頭 50 件のみを使って評価しています。

RankGPT の実装には同僚にお願いして作ってもらったものを使いました9

llm_reranker
# LlmReranker(
#  chat_model=AzureChatOpenAI(
#    client=<openai.resources.chat.completions.Completions object at 0x7fe12e3aee50>,
#    async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fe12e3b9350>, model_name='gpt-4-32k',
#    temperature=0.01, model_kwargs={'top_p': 0.11}, openai_api_key=SecretStr('**********'), openai_proxy='', max_tokens=1024,
#    azure_endpoint='https://xxxxxxxx.openai.azure.com', deployment_name='xxxxxxxx', openai_api_version='2023-07-01-preview',
#    openai_api_type='azure'
#  ),
#  max_context_length=32768, max_tokens=1024, top_k=10, window_size=None, overlap=0, discard_trailing_tokens=True
#)

元記事の run_test() を少し改造して、num_samples=50, wait=30(1 サンプル毎に 30 秒スリープする) の指定ができるようにしています。

model="text-embedding-3-large"
chromadb = create_chromadb(model, docs, dims=3072)
retriever = chromadb.as_retriever(search_kwargs={"k": 30})
retriever = ContextualCompressionRetriever(base_compressor=llm_reranker, base_retriever=retriever)
run_dict = run_test(retriever, num_samples=50, wait=30)
qrels = Qrels(qrels_dict50)
run = Run(run_dict)
evaluate(qrels, run, ["hit_rate@5", "mrr@5", "ndcg@5"])
# {'hit_rate@5': 0.78, 'mrr@5': 0.6140000000000001, 'ndcg@5': 0.518597735068129}

4. 実験結果

以下、最終結果です。まず全件での上位5件のヒット率の比較を見てみましょう。

all

  • :text-embedding-3-large が優秀ですね。データによるところもあるかもしれませんが、計算量や格納サイズまで考慮すると text-embedding-3-large の次元数は半分サイズの 1536 でもよさそうです。
  • オレンジ:Cohere Rerank 3 を適用すると精度が下がってしまいました。。。期待してたので、これは残念です。コードは問題ないと思うのですが。。。

次に先頭50件のデータで LLM によるリランキングを加えた比較です。

50samples

  • :サンプル数が少ないせいか text-embedding-ada-002 のスコアがだいぶ下振れ( 0.74 -> 0.68 )しています。また text-embedding-3-large は 1536 次元と 3072 次元が逆転し 3072 次元が最良になりました。
  • オレンジ:Rerank 3 の結果を見ると text-embedding-ada-002 はスコアが伸びています( 0.68 -> 0.74 )が text-embedding-3-large は 1536, 3072 次元共にスコアを落としています。Rerank 3 の日本語性能は「イマイチをまぁまぁにはできる」けれども「まぁまぁを Excellent にするには力不足」という感じなのかもしれません。
  • :LLM によるリランキングは強力で text-embedding-ada-002, text-embedding-3-large の 1536, 3072 次元の全てスコアが伸びています。text-embedding-ada-002 が最良値(0.800)になっていますが、これは text-embedding-ada-002 ベクトル検索時の下位サンプルが text-embedding-3-large よりもうまい具合にバラついて正解を拾い、LLM がそれを 5 位以内に押し上げた結果かと思います。

Hit-Rate@5, MRR@5 のスコア一覧は以下のとおりです。

全件

all

先頭50件

50samples

5. おわりに

今回は最近のテキスト埋込モデルとリランキングモデルについて検証をしてみました。

とりあえずは text-embedding-3-large を 1536 次元あたりで使うのを基本線として、どうしても上積みが欲しい局面で RankGPT を使う感じでしょうか。 速度的には GPT-4o である程度の高速化が期待できますし、Claude 3 の haiku, sonnet を使ってみてもよいかもしれませんね。

少し駆け足になりましたが、今回はこれで終わりです。また何かしら書きたいことができたら記事にしたいと思います。

… おっといけない。

co.delete_endpoint()
co.close()

SageMaker の Rerank 3 エンドポイントを止め忘れてエライことになるところでした。皆さんもお気を付け下さいませ。


  1. https://openai.com/index/new-embedding-models-and-api-updates/ 

  2. https://cohere.com/blog/rerank-3 

  3. https://arxiv.org/abs/2304.09542 

  4. https://arxiv.org/abs/2112.01488 

  5. https://qiita.com/warper/items/95089ed57f1b88f29381 

  6. https://github.com/ku-nlp/bert-based-faqir 

  7. https://aws.amazon.com/marketplace/pp/prodview-ydysc72qticsw 

  8. SageMaker はざっくりセットアップ済みの状態の前提です。あと使用した AWS 環境がセキュリティ的に厳しい構成になっていたので、cohere_aws そのままでは接続できずにゴニョゴニョと。。。みなさんは普通に Cohere でサインアップして cohere パッケージを使っていただくのが一番ラクではないかと思われますー。 

  9. ソースだせなくてごめんなさい。プロンプトは日本語にしましたが、ほぼ論文3のままの実装です。ですが、langchain 0.2.1 には langchain_community.document_compressors.rankllm_rerank.RankLLMRerank として実装が入ってるので、そちらを使ってもらうのがよいかと思います。