オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第26回 Rasa に transformers を組み込んでチャットボットを作ってみる
オージス総研 技術部 データエンジニアリングセンター
鵜野 和也
2023年4月25日

今回は OSS の Rasa でチャットボットを作ります。普通に作るだけでは、つまらないので transformers を組み込んでみました。また、Rasa を初めて触る方向けにその概要とチャットボットを作るときのコツ的なところもご紹介できればと思います。

1. はじめに

今回は OSS の Rasa1 を使ってチャットボットを作ります。

「世の中は ChatGPT が大きな話題をさらい、GPT-4, Stanford Alpaca, LoRA, RLHF, LLM.int-8, …等々。このタイミングで Rasa ! 何故?」と思われるかもしれませんが。。。

いやぁ、スゴイですよね、ChatGPT。 とはいえ「なんでもかんでも ChatGPT で作ればOK」という訳ではない(といいな)と思うので、 利用シーンを限定した局地戦に持ち込みつつ、状況に応じて ChatGPT (あるいはその亜種)も利用しつつ、という感じになるでしょうか。

ええっと Rasa の話でしたね。

今回は Rasa を普通に使ってもつまらないので内部で使用するモデルに transformers を利用する方法を解説します。 あとは、意外と身の回りでも「Rasa が思った通りに動いてくれない。。。」みたいな嘆きが聞こえてきたりしたので、その辺りのコツみたいなものも合わせてご紹介していきます。

それではまずは Rasa について軽く紹介していきましょう。

2. Rasa

Rasa はドイツの Rasa 社が開発・公開するオープンソースのチャットボット構築フレームワークです。 以下は全体アーキテクチャの図ですが、かなりの部分がオープンソースで公開されています。

archtecture

今回はオープンソース部分(Rasa Open Source と Rasa SDK) に絞って記述しますが、 はみ出すところが気になるので最初にそちらを見ていきましょうか。

Rasa の OSS じゃないところ

上図を見ていると重要そうな部品が欠けているように見えるので一つずつ見ていきます。

Tracker Store

これはユーザとチャットボット(以下、ボットと表記)の会話履歴を保持するコンポーネントです。デフォルトでは InMemoryTrackerStore が使用されますが、 ドキュメントを見ると以下のような実装が OSS として提供されているので大丈夫そうですね。

  • SQLTrackerStore
  • DynamoTrackerStore
  • RedisTrackerStore

LockStore

これは Rasa サーバを複数立てたとき、ある会話を複数サーバが同時に扱うことを防ぐロックを提供するコンポーネントです。 デフォルトでは InMemoryLockStore が使われます。他の実装として、以下が提供されており、ConcurrentRedisLockStore は有償ですが、 無印の RedisLockStore は OSS です。ですので、無印 RedisLockStore を使えばオープンソース版だけでも複数サーバ並行起動の運用はできそうです。

  • RedisLockStore
  • ConcurrentRedisLockStore (Rasa Pro Only)

FileSystem

これ、何でしょうね。。。多分、モデルの配置場所である Model Storage の話だと思います。 ローカルディスクから読み込んだり、Rasa サーバから取得したり、S3, GCS, Azure Storage から取得したりとオープンソースの範疇で色々用意されています。

そんな訳で有償版の Rasa Pro がないと使い物にならないという訳ではなさそうな気がします。

ひと安心したところで、Rasa の基本的な動作イメージと対応するコンポーネントを見ていきましょう。

Rasa の基本的な動作イメージ

あー、そうそう、今回は細かい説明を思い切って端折るので2、この辺りの記事3,4も参考にして頂くと良いと思います。

以下は Rasa の大まかな動作イメージになります。

rasa

大きな構造として、Rasa NLU, Rasa Core の二つがあります。順番に見ていきましょう。

Rasa NLU

Rasa NLU に関しては本連載の第2回で取り扱いましたね。メジャーバージョンアップを経ているので、だいぶ変わったところも あるでしょうが、その重要な役割は NLU(Natural Language Understanding)の名の通り、メッセージに含まれるユーザの意思を抽出するところにあります。

ユーザの意図は intent と entities として抽出されます。 Rasa でこれらを抽出するコンポーネントがそれぞれ IntentClassifier, EntityExtractor になります。

IntentClassifier

IntentClassifier の実態は文章分類モデルであり、ユーザから送られたメッセージテキストの分類結果が intent つまりユーザの意図として扱われます。 感覚的にはプログラミングしていて呼び出す関数を決めるような感じでしょうか(実際はちょっと違うのですが5)。

Rasa では IntentClassifier の実装として以下のような実装が提供されています(一部抜粋)。

  • LogisticRegressionClassifier : scikit-learn を使った実装
  • SklearnIntentClassifier : Sklearn の linear SVM を使った実装
  • KeywordIntentClassifier : キーワードマッチングによる簡易な実装
  • DIETClassifier : Transformer ベースの実装で EntityExtractor の責務を兼ねるもの

今回は後でボットのサンプルを作るのですが、 IntentClassifier として transformers の BertForSequenceClassification を使ったものを作ってみることにします。

EntityExtractor

EntityExtractor の実態は固有表現抽出のモデルです。抽出されたエンティティ(固有表現)は言ってみれば関数呼び出しの引数としての役割を持ちます。

Rasa では EntityExtractor の実装として以下のような実装が提供されています(一部抜粋)。

  • SpacyEntityExtractor : SpaCy を使った実装。日本語の場合は GiNZA を利用することが可能。
  • CRFEntityExtractor : CRF を使った実装。学習可能なので独自のエンティティを定義できます。
  • DucklingEntityExtractor : Duckling6 を用いて数値や日付を抽出します。
  • DIETClassifier : 上述
  • RegexEntityExtractor: 正規表現を使った実装。

こちらは BertForTokenClassification を使った実装を紹介しようと思ったのですが、どうしてもソースが長くなるのでやめました。 (「IntentClassifier のやり方がわかれば大丈夫でしょ」的な。)

その代わりに後述するサンプルでは DucklingEntityExtractor を使ってみました(んですが。。。)。

ついでに他のコンポーネントにも言及しましょう。

Tokenizer

これは説明要りませんよね。テキストをトークンに分割します。ただし後続のコンポーネントとの兼ね合いも考慮して選択することになるかと思います。

Featurizer

こちらはテキストレベル、トークンレベルで特徴量を抽出するためのコンポーネントです。Featurizer で抽出した特徴量が、 SklearnIntentClassifier や CRFEntityExtractor, 後述する TEDPolicy への入力として使われます。

標準のコンポーネントの一つとして LanguageModelFeaturizer が提供されており、これは内部で transformers を使っています。 普通に Rasa で transformers を使うなら、このコンポーネントを使うことになります。

BERT や GPT に対応しており、生成された埋め込み表現が後続のコンポーネントに流れ込みます。 ただ、モデルに対する Tokenizer がガッツリとハードコードされている7ので、そのままでは一部の日本語モデルが動作しないのと、 文章分類への入力としては BERT なら “[CLS]” トークンの埋め込み8、GPT なら全トークンの平均9が返るので、 分類モデルへの入力としてはどうなんだろうなと感じました。

そんな訳で後述するサンプルでは Sentence BERT を使ってテキストレベルの特徴量を抽出する Featurizer を作ってみます。

Rasa Core

Rasa Core は Rasa NLU で抽出された intent, entities, features 等とこれまでのユーザとボットのやり取りから、 ボットの次のアクションを決定する役割を担います。

この次のアクションを決める責務を担っているのが Policy です。

Policy

Policy は複数同時に使用することができ、以下のような実装が提供されています。

  • TED Policy : Transformer ベースの実装です。これについては少し詳しく後述しますね。
  • Rule Policy : ルールベースの実装です。設定されたルールに従ってアクションを選択します。
  • Memoization Policy : 学習データを記憶し、推論時の状況が記憶にマッチすれば記憶したとおりのアクションを選択します。
  • UnexpecTED Intent Policy : これは補助的な Policy で、会話の流れを踏まえて発生しそうな intent を学習します。Rasa NLU が推論した intent が「ありそう」と判断すれば何もせず、「なさそう」と思えば action_unlikely_intent という特殊なアクションを選択します。今回は使いませんでした。

また以下のような優先度があり、数字が大きいと優先度が高くなります。 複数の Policy が異なるアクションを選択した場合は、選択の確信度と優先度によって最終的な決定がなされます

  • 6 - RulePolicy
  • 3 - MemoizationPolicy or AugmentedMemoizationPolicy
  • 2 - UnexpecTEDIntentPolicy
  • 1 - TEDPolicy

RulePolicy と MemoizationPolicy は常に確信度 1.0 で予測結果を返すので、最高の優先度を持つRulePolicy のルールにマッチすれば、必ずそのアクションが実行されることになります。

さて、個人的に TEDPolicy がキモだと思うので、少し詳しく説明します。

3. TEDPolicy

TEDPolicy は Transformer Embedding Dialogue (TED) Policy の略で Transformer を使った Policy の実装です。

機能的には次のアクションの予測とエンティティの抽出のマルチタスクに対応していますが、 本章では次のアクションを選択する側面について見ていきます。

論文10や YouTube11 のコンテンツもあるので、より詳しく知りたい方はそちらを参照して下さい。

以下は TEDPolicy が “I like a pizza.” というメッセージを受けて、次のアクションを予測するときのイメージです。

rasa

まずは、Rasa NLU で “I like a pizza.” から抽出した intent(“order”) や entity(“pizza”) を One-hot ベクトル化します。 slot はユーザと Rasa のやり取りで拾い上げた情報を記憶しておく為の変数のようなものです。 prev action は直前のやり取りで選択したアクションですね。 この二つも One-hot にして全て連結して特徴量 ft を得ます(実際のところはもっと複雑12です)。

ここで ft は BERT に例えると入力シーケンスを構成する 1 トークンの埋め込み表現に相当します。

つまり、ユーザとボットの間で 0~3 の 4 ステップのやり取りがあった場合は、その特徴量をまとめた、 [f0, f1, f2, f3] が単方向 Transformer への入力になり、 次アクションの予測シーケンス [f0, f1, f2, f3] を得ます。

Transformer が単方向なのは自然言語処理の Decoder と同じで次のアクションを予測する時に未来情報を参照することを抑制する為ですね。

rasa

Loss の計算の仕方は第2回で紹介した StarSpace をベースにしたものになっています。 次のアクションを選択する分類問題なのですが、採りうる全てのアクションの埋め込み表現を学習パラメータとして保持していて、 Transformer の予測 fn の埋め込み表現と正解アクションの埋め込み表現の類似度が高くなるように、 各アクションの埋め込み表現がいい感じにバラけるようにパラメータを更新します。

実際に動かした方が分かりやすいと思うので、ここからは Colab 上で簡単なデモを構築して行くことにしましょう。

4. 環境のセットアップ

今回も Colab で動かす想定でコードを記載していきます。ランタイムのアクセラレータは GPU にして下さい。

まずは必要なパッケージをインストールしましょう。今回は transformers を利用したカスタムコンポーネントを作ることにしたので、 以下をインストールしました。

!pip install transformers==4.23.* \
             datasets==2.* \
             evaluate==0.4.* \
             mecab-python3==1.0.* \
             fugashi==1.* \
             ipadic==1.0.* \
             sentence-transformers==2.2.2

Rasa 関連をインストールします。transformers と Rasa を Colab で動かすためにアレコレと調整したら、 以下のようなバージョンの構成になりました。

!pip install rasa==3.4.2 \
             tzlocal==2.1 \
             jedi==0.16 \
             pluggy==1.0 \
             prompt-toolkit==3.0.28 \
             ipython==8.8.0 \
             nest_asyncio==1.5.6

実際動かすと packaging 関係でなにやらエラーが出たので、個別に 20.9 をインストールします。

!pip install packaging==20.9

次にエンティティの抽出に使う Duckling です。意外と時間がかかります。

!wget -qO- https://get.haskellstack.org/ | sh
!git clone https://github.com/facebook/duckling
!cd duckling && stack build

学習済みの日本語 Sentence BERT をコピーします。お手元にない人は第18回を参考に作ってみて下さい。

!gsutil cp -r gs://somewhere/sbert .

一旦、ここで再起動します。

5. カスタムコンポーネントの作成

とりあえずディレクトリを作っておきます。

!mkdir my_rasa

以下は Sentence BERT を使用した Featurizer です。

SbertFeaturizer

Sentence BERT を使ってテキストを特徴量にするカスタムコンポーネントです。学習がないので比較的シンプルですね。

%%writefile /content/my_rasa/sbert_featurizer.py
from functools import lru_cache
import numpy as np
import typing
import logging
from typing import Any, Text, Dict, List, Type

from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.graph import ExecutionContext, GraphComponent
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.features import Features
from rasa.shared.nlu.training_data.message import Message
from rasa.nlu.constants import (
    DENSE_FEATURIZABLE_ATTRIBUTES,
    FEATURIZER_CLASS_ALIAS,
)
from rasa.shared.nlu.constants import TEXT, FEATURE_TYPE_SENTENCE

from sentence_transformers import SentenceTransformer
import torch

logger = logging.getLogger(__name__)

@DefaultV1Recipe.register(
  DefaultV1Recipe.ComponentType.MESSAGE_FEATURIZER, is_trainable=False
)
class SbertFeaturizer(DenseFeaturizer, GraphComponent):

  @staticmethod
  def required_packages() -> List[Text]:
    return ["sentence_transformers"]

  @staticmethod
  def get_default_config() -> Dict[Text, Any]:
    return {
      **DenseFeaturizer.get_default_config(),
      "model_path": None
    }

  def __init__(self, config: Dict[Text, Any], name: Text) -> None:
    super().__init__(name, config)
    self._load_model_instance()

  @classmethod
  def create(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
  ) -> GraphComponent:
    return cls(config, execution_context.node_name)

  def _load_model_instance(self):
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.sbert = SentenceTransformer(self._config["model_path"])
    self.sbert.to(self.device)

  @lru_cache(maxsize=512)
  def _encode(self, text):
    return self.sbert.encode([text]).astype(float)

  def process(self, messages: List[Message]) -> List[Message]:
    for message in messages:
      for attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
        self._set_sentemb_features(message, attribute)
    return messages

  def process_training_data(self, training_data: TrainingData) -> TrainingData:
    self.process(training_data.training_examples)
    return training_data

  def _set_sentemb_features(self, message: Message, attribute: Text = TEXT) -> None:
    text = message.get(attribute)

    if text is None or len(text) == 0:
      return

    sentence_features = self._encode(text)

    final_sentence_features = Features(
      sentence_features,
      FEATURE_TYPE_SENTENCE,
      attribute,
      self._config[FEATURIZER_CLASS_ALIAS],
    )
    message.add_features(final_sentence_features)

  @classmethod
  def validate_config(cls, config: Dict[Text, Any]) -> None:
    pass

ポイントは以下のとおりです。

  • 学習可能なコンポーネントではないので DefaultV1Recipe.register の is_trainable を False としてコンポーネント登録します。
  • def get_default_config() -> Dict[Text, Any] を実装して、このコンポーネントの設定パラメータとそのデフォルト値を定義します。
  • def process(self, messages: List[Message]) -> List[Message] を実装して、特徴量を抽出するコードを記述します。1 件ずつ処理するナイーブな実装ですが、同じテキストの時は Sentence BERT の呼び出しを回避できるよう @lru_cache(maxsize=512) とかして見ました。TEDPolicy にはこのコンポーネントで生成した特徴量も流れ込むので、品質の良い特徴量を使用することで精度向上に寄与するかもしれません13
  • def process_training_data(self, training_data: TrainingData) -> TrainingData は process() を通すように記述しただけですね。このコンポーネント自体は学習対象ではありませんが、後続の学習にこのコンポーネントが生成する特徴量を反映させる為の物だと思います。

次は BERT を使用した IntentClassifier です。

BertIntentClassifier

少し長いですが、transformers の BertForSequenceClassification を使って文章分類するカスタムコンポーネントです。

%%writefile /content/my_rasa/bert_intent_classifier.py
from __future__ import annotations
from datasets import Dataset, load_metric
from functools import partial
import glob
import logging
import math
import os
import re
import shutil
import sys
from typing import Any, Dict, Optional, Text, List
from rasa.engine.graph import GraphComponent, ExecutionContext
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.constants import DOCS_URL_COMPONENTS
from rasa.nlu.classifiers.classifier import IntentClassifier
from rasa.shared.nlu.constants import (
  INTENT,
  TEXT,
  INTENT_NAME_KEY,
  PREDICTED_CONFIDENCE_KEY,
)
import rasa.shared.utils.io
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch import nn
import transformers
from transformers import Trainer, TrainingArguments, HfArgumentParser, set_seed, TrainerCallback
from transformers import AutoTokenizer, BertForSequenceClassification, TrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.modeling_utils import PreTrainedModel
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

###################################
#  Define functions for training
###################################

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default='cl-tohoku/bert-base-japanese-whole-word-masking',
                                          metadata={"help": "The name or path of the pretrained model.."})

@dataclass
class DataArguments:
    max_seq_len: Optional[int] = field(default=512, metadata={"help": "The maximum sequence length."})

def parse_args_fn_template(config, output_dir):
  training_args = TrainingArguments(output_dir=output_dir)
  training_args.do_train = True
  training_args.per_device_train_batch_size = config.get("batch_size")
  training_args.per_device_eval_batch_size = config.get("batch_size")
  training_args.num_train_epochs = config.get("epochs")
  training_args.save_strategy = "epoch"
  training_args.evaluation_strategy = "epoch"
  training_args.logging_strategy = "epoch"
  training_args.load_best_model_at_end = True
  training_args.metric_for_best_model = "loss"
  training_args.greater_is_better = False
  training_args.fp16 = config.get("fp16")

  model_args = ModelArguments(
    model_name_or_path = config.get("model_name_or_path"),
  )
  data_args = DataArguments(
    max_seq_len = config.get("max_seq_len")
  )

  return training_args, model_args, data_args

def load_tokenized_dataset_fn_template(training_args, model_args, data_args,
        train_examples, dev_examples, test_examples):
  label_set = set([example[0] for example in train_examples])
  label2id = {label:i for i, label in enumerate(label_set)}
  id2label = {v:k for k, v in label2id.items()}

  logger.info("label2id : {}".format(label2id))
  meta_info = {"label2id": label2id, "id2label": id2label}

  is_sentence_pair = len(train_examples[0]) == 3

  if is_sentence_pair:
    train_dataset = [[[example[1], example[2]], label2id[example[0]]] for example in train_examples]
    dev_dataset = [[[example[1], example[2]], label2id[example[0]]] for example in dev_examples]
    if training_args.do_eval :
      test_dataset = [[[example[1], example[2]], label2id[example[0]]] for example in test_examples]
  else:
    train_dataset = [[example[1], label2id[example[0]]] for example in train_examples]
    dev_dataset = [[example[1], label2id[example[0]]] for example in dev_examples]
    if training_args.do_eval :
      test_dataset = [[example[1], label2id[example[0]]] for example in test_examples]


  tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, return_dict=False)

  train_dataset = Dataset.from_dict({"inputs": [example[0] for example in train_dataset],
                                   "labels": [example[1] for example in train_dataset],})
  dev_dataset = Dataset.from_dict({"inputs": [example[0] for example in dev_dataset],
                                  "labels": [example[1] for example in dev_dataset],})
  if training_args.do_eval :
    test_dataset = Dataset.from_dict({"inputs": [example[0] for example in test_dataset],
                                  "labels": [example[1] for example in test_dataset],})

  def tokenize_function(examples):
    inputs = [input for input in examples["inputs"]]  
    features = tokenizer.batch_encode_plus(inputs, padding="max_length", truncation=True, max_length=data_args.max_seq_len)
    return features

  tokenized_dataset_train = train_dataset.map(tokenize_function, batched=True, remove_columns=['inputs'])
  tokenized_dataset_train.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
  del train_examples, train_dataset

  tokenized_dataset_dev = dev_dataset.map(tokenize_function, batched=True, remove_columns=['inputs'])
  tokenized_dataset_dev.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
  del dev_examples, dev_dataset

  tokenized_dataset_test = None
  if training_args.do_eval :
    tokenized_dataset_test = test_dataset.map(tokenize_function, batched=True, remove_columns=['inputs'])
    tokenized_dataset_test.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
    del test_examples, test_dataset

  meta_info["num_train_examples"] = len(tokenized_dataset_train)
  meta_info["batch_size"] = (training_args.train_batch_size *
                              training_args.gradient_accumulation_steps *
                                 training_args.world_size)

  return tokenizer, tokenized_dataset_train, tokenized_dataset_dev, tokenized_dataset_test, meta_info


def run(parse_args_fn, load_tokenized_dataset_fn):
  training_args, model_args, data_args = parse_args_fn()

  logging.basicConfig(
      format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
      datefmt="%m/%d/%Y %H:%M:%S",
      handlers=[logging.StreamHandler(sys.stdout)],
  )
  logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 
  logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  )

  if is_main_process(training_args.local_rank):
    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
  logger.info("Training/evaluation parameters %s", training_args)

  set_seed(training_args.seed)

  logger.info("Loading and tokenizing datasets...")
  tokenizer, tokenized_dataset_train, tokenized_dataset_dev, tokenized_dataset_test, meta_info = load_tokenized_dataset_fn(
          training_args, model_args, data_args)

  logger.info("Loading model...")
  label2id = meta_info["label2id"]
  id2label = meta_info["id2label"]
  model = BertForSequenceClassification.from_pretrained(model_args.model_name_or_path, num_labels=len(label2id))
  model.config.label2id = label2id
  model.config.id2label = id2label

  logger.info("Building metrics function...")
  metric = load_metric("f1")
  def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average="weighted")

  logger.info("Building Trainer...")
  trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_dev,
    tokenizer = tokenizer,
    compute_metrics = compute_metrics,
    callbacks = None
  )

  if training_args.do_train:
    logger.info("Start training...")
    train_result = trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload

    output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
    if trainer.is_world_process_zero():
      with open(output_train_file, "w") as writer:
        print("***** Train results *****")
        for key, value in sorted(train_result.metrics.items()):
          print(f"  {key} = {value}")
          writer.write(f"{key} = {value}\n")

      trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

def build_batch(xs, batch_size):
  xbs = []
  num_batch = math.ceil(len(xs) / batch_size)
  for i in range(num_batch):
    head = i * batch_size
    tail = head + batch_size
    xb = xs[head:tail]
    xbs.append(xb)
  return xbs

def build_features(texts, tokenizer, max_len=512, device="cpu"):
  features = tokenizer([s.lower() for s in texts], add_special_tokens=True,
                 truncation=True, max_length=max_len, padding=True)
  return {k:torch.tensor(v).to(device) for k, v in features.items()}

##############################
#  Define Intent Classifier
##############################

@DefaultV1Recipe.register(
  DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER, is_trainable=True
)
class BertIntentClassifier(GraphComponent, IntentClassifier):
  @staticmethod
  def get_default_config() -> Dict[Text, Any]:
    return {
      "model_name_or_path": "cl-tohoku/bert-base-japanese-whole-word-masking",
      "max_seq_len": 512,
      "batch_size": 8,
      "epochs": 5,
      "fp16": True,
      "test_size": 0.1,
      "clear_checkpoints": True
    }


  def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    model: Optional[PreTrainedModel] = None,
  ) -> None:
    self.component_config = config
    self._model_storage = model_storage
    self._resource = resource
    self._execution_context = execution_context
    self._tokenizer = tokenizer
    self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model is not None:
      model.eval()
      model.to(self._device)
    self._model = model


  @classmethod
  def create(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
  ) -> BertIntentClassifier:
    return cls(config, model_storage, resource, execution_context)


  def train(self, training_data: TrainingData) -> Resource:
    with self._model_storage.write_to(self._resource) as model_dir:

      examples = [[ex.get(INTENT), ex.get(TEXT)] for ex in training_data.intent_examples]
      train_examples, dev_examples = train_test_split(
                                       examples, test_size=self.component_config.get("test_size"))

      output_dir = str(model_dir)
      parse_args_fn = partial(parse_args_fn_template, config=self.component_config, output_dir=output_dir)
      load_tokenized_dataset_fn = partial(load_tokenized_dataset_fn_template, train_examples=train_examples,
                                            dev_examples=dev_examples, test_examples=None)

      run(
        parse_args_fn,
        load_tokenized_dataset_fn
      )

      if self.component_config.get("clear_checkpoints"):
        ckpts = glob.glob(f"{output_dir}/checkpoint-*")
        for c in ckpts:
          logger.debug(f"Removing checkpoint {c}...")
          shutil.rmtree(c)

    return self._resource


  def predict(self, xs):
    all_preds = []
    all_confs = []
    with torch.no_grad():
      batch_size = self.component_config["batch_size"]
      xbs = build_batch(xs, batch_size)
      for xb in xbs:
        features = build_features(xb, tokenizer=self._tokenizer, device=self._device)
        logits = self._model(**features).logits
        probs = nn.functional.softmax(logits, dim=1)
        preds = torch.argmax(probs, axis=1)
        confs = torch.squeeze(torch.gather(probs, 1, torch.unsqueeze(preds, -1)), dim=0)
        all_preds.append(preds)
        all_confs.append(confs)

    all_preds = torch.concat(all_preds)
    all_confs = torch.concat(all_confs)
    all_preds = [self._model.config.id2label[pred] for pred in all_preds.cpu().numpy()]
    all_confs = all_confs.cpu().detach().numpy().astype(float)
    return all_preds, all_confs


  def process(self, messages: List[Message]) -> List[Message]:
    texts = [message.get(TEXT) for message in messages]
    intents, confidences = self.predict(texts)
    for i, message in enumerate(messages):
      intent_name = intents[i]
      confidence = confidences[i]
      intent = {
        INTENT_NAME_KEY: intent_name,
        PREDICTED_CONFIDENCE_KEY: confidence,
      }
      if message.get(INTENT) is None or intent_name is not None:
        message.set(INTENT, intent, add_to_output=True)
        logger.debug(f"text={message.get(TEXT)}, intent={intent_name}, "
                      "confidence={confidence}")

    return messages


  @classmethod
  def load(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
  ) -> BertIntentClassifier:
    try:
      with model_storage.read_from(resource) as model_dir:
        logger.debug(f"loading tokenizer from {model_dir}...")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        config_path =  model_dir / "config.json"
        with open(config_path, "r") as f:
          config_data = f.read()
        logger.debug(f"loading BertForSequenceClassification from {model_dir}...")
        model = BertForSequenceClassification.from_pretrained(model_dir)
    except ValueError:
      logger.warning(
                f"Failed to load {cls.__class__.__name__} from model storage. Resource "
                f"'{resource.name}' doesn't exist."
      )
      tokenizer = None
      model = None

    logger.debug(f"Loading tokenizer and model were completed.")
    return cls(
      config, model_storage, resource, execution_context, tokenizer, model
    )

基本的には IntentClassifier に transformers の Trainer による学習を組み込んだだけです。 全ての学習サンプルを最大長に合わせてパディングしてたり手抜き箇所はありますが Rasa と transformers を組み合わせるサンプルとしては十分かなと思います。

ポイントは以下のとおりです。

  • 学習可能なコンポーネントなので DefaultV1Recipe.register の is_trainable を True としてコンポーネント登録します。
  • def get_default_config() -> Dict[Text, Any] を実装して、このコンポーネントの設定パラメータとそのデフォルト値を定義します。
  • def train(self, training_data: TrainingData) -> Resource を実装して、学習の処理を記述します。コンストラクタで渡される ModelStorage と Resource のインスタンスを保持しておき、学習したモデルを ModelStrage を使って Resouce に書き込むのがポイントです。書き込んだ結果は Rasa が他のコンポーネントの出力結果とまとめて tarball にまとめてくれます。
  • def process(self, messages: List[Message]) -> List[Message] では学習したモデルを使って intent の分類を行い Message に書き込みます。
  • def load(cls, config: Dict[Text, Any], model_storage: ModelStorage, resource: Resource, execution_context: ExecutionContext, **kwargs: Any) -> BertIntentClassifier ではパラメータとして渡された Resouce から ModelStorage を使って学習済みのモデルを読み込む処理を記述します。

EntityExtractor も同様のスタイルで BertForTokenClassification ベースの実装を作ることが出来ます。IntentClassifier と合わせて 2 回 BERT を動かすのももったいないので、DIETClassifier のようにして見ても良いかもしれません。

学習後のモデルをコンバートして Triton にデプロイし、Rasa 側から学習なしのコンポーネントで Triton 上のモデルを叩いても良いでしょうね。

ここからは Rasa でボットを作っていきますが、まずは Colab で動かすためのおまじないです。

6. Rasa のプロジェクトの作成と設定

今回は会議予約(風)のボットを作ってみましょう。ざっくりとした流れは、

  • 挨拶 : 例「おはよう。」
  • 会議の予約を依頼 : 例「会議の予約をお願いします。」
  • 日時の指定 : 例「明日にして下さい。」
  • 人数の指定 : 例「6人にして下さい。」
  • 会議時間の指定 : 例「1時間にして下さい。」
  • お礼 : 例「ありがとう。」

といった感じです。

以下のようにして Python のコード探索パスの更新とイベントループのセットアップを行います。

import sys
sys.path.append(".")

import os
import rasa
import nest_asyncio
nest_asyncio.apply()
print("event loop is ready.")
# event loop is ready.

以下のようにして新規プロジェクトを作成します。

from rasa.cli.scaffold import create_initial_project
project = "rasa_3x"
create_initial_project(project)

プロジェクトディレクトリの内容は以下のようになっています。

os.chdir(project)
print(os.listdir("."))
# ['config.yml', 'domain.yml', 'data', 'endpoints.yml', 'credentials.yml', 'actions', 'tests']

config.yml

次はボットで使用するコンポーネントの構成を決める config.yml の記述です。

%%writefile /content/rasa_3x/config.yml
recipe: default.v1
language: ja

pipeline:
  - name: my_rasa.sbert_featurizer.SbertFeaturizer
    model_path: /content/sbert
  - name: my_rasa.bert_intent_classifier.BertIntentClassifier
    epochs: 5
    clear_checkpoints: true
  - name: "DucklingEntityExtractor"
    url: "http://localhost:8000"
    dimensions: ["time", "number", "duration"]
    locale: "ja_JP"
    timezone: "Asia/Tokyo"
    timeout : 3
  - name: EntitySynonymMapper
  - name: ResponseSelector
    epochs: 100
    constrain_similarities: true
  - name: FallbackClassifier
    threshold: 0.1
    ambiguity_threshold: 0.1

policies:
  - name: MemoizationPolicy
  - name: RulePolicy
  - name: TEDPolicy
    max_history: 5
    epochs: 100
    constrain_similarities: true

pipeline は Rasa NLU の処理パイプラインの構成ですね。 カスタムコンポーネントはクラスの FQCN を指定すれば大丈夫です。 また、今回はカスタムコンポーネントの中で独自にトークナイズをしており、Duckling も Tokenizer 不要なので、Tokenizer を登録していません。

policies が Rasa Core が使用する Policy の設定になります。

続いて作成するボットのドメインの設定である domain.yml を記述します。

domain.yml

“ドメイン"というとフワっとした表現ですが、作成するボットが扱える intent や entity の定義、変数を格納する slot や form の定義、 ボットの応答メッセージの登録等になります。

%%writefile /content/rasa_3x/domain.yml
version: "3.1"

intents:
  - greet:
      use_entities: []
  - reservation
  - inform
  - thankyou:
      use_entities: []
  - affirm:
      use_entities: []
  - deny:
      use_entities: []

entities:
  - time
  - number
  - duration

forms:
  reservation_form:
    required_slots:
    - time
    - number
    - duration

slots:
  time:
    type: text
    mappings:
    - type: from_entity
      entity: time
  number:
    type: text
    mappings:
    - type: from_entity
      entity: number
  duration:
    type: text
    mappings:
    - type: from_entity
      entity: duration

responses:
  utter_greet:
  - text: "こんにちは。ご用件はなんですか?"

  utter_ok:
  - text: "了解しました。"

  utter_thankyou:
  - text: "こちらこそ、ご利用ありがとうございました。"

  utter_ask_time:
  - text: "日時はいつが良いですか?"

  utter_ask_number:
  - text: "打合せの人数を教えて下さい。"

  utter_ask_duration:
  - text: "会議時間はどのぐらいですか?"

  utter_ask_continue_reservation:
  - text: "会議の予約はつづけたほうがよいですか?"  

  utter_default:
  - text: "よく分からなかったので、言い直してもらえますか?"  

actions:
  - action_make_reservation

session_config:
  session_expiration_time: 60
  carry_over_slots_to_new_session: true

だいたい直感的に分かりそうですが、form について補足しておきましょう。

以下、form の定義部分を再掲します。

forms:
  reservation_form:
    required_slots:
    - time
    - number
    - duration

reservation_form という form があり、必須の slot として time, number, duration を持ちます。slot は会話の中で拾った情報を保持する変数のようなものでしたね。

ユーザとボットのやり取り中に reservation_form を active にすると、ボットは config.yml に記述した utter_ask_{slot} のセリフを使ってユーザから情報を収集してくれます。そして slot が全て埋まったところで後述するカスタムアクションを通してバックエンドのサービスを呼び出すという仕組みです。

slot の値をどのように埋めるかを定義するのは slot の定義部分に記述します。

slots:
  time:
    type: text
    mappings:
    - type: from_entity
      entity: time
  number:
    type: text
    mappings:
    - type: from_entity
      entity: number
  duration:
    type: text
    mappings:
    - type: from_entity
      entity: duration

ここでは単純に抽出した enitity を同名の slot にテキストとして設定してますが、intent を使ったりカスタムのマッピングを定義することもできます。 詳しくは公式ドキュメント14を参照して下さい。

endpoints.yml には Rasa が連携する外部サービスの場所を記述します。バックエンドのサービスをキックするのに使うカスタムアクションは、 Action Server として別に起動されるので、その URL を設定しておきます。

%%writefile /content/rasa_3x/endpoints.yml
action_endpoint:
  url: "http://localhost:5055/webhook"

では、ここからは学習データを用意していきましょう。

7. 学習データの準備

Rasa の学習データは主に NLU Training Data, Stories, Rules の 3 つになります。順に見ていきましょう。

NLU Training Data

NLU Training Data は Rasa NLU の学習データです。初期化したプロジェクトでは data/nlu.yml が相当しますが、 複数のファイルに分けて記述することが可能です。

今回は以下のようにしました。できるだけ量を減らしたかったので、かなり恣意的に特徴を作りこんでますが、そこはご容赦ください。

%%writefile /content/rasa_3x/data/nlu.yml
version: "3.1"

nlu:
- intent: greet
  examples: |
    - おはようございます。
    - おはよう。
    - おはようさんです。
    - おはよう!
    - やぁ、おはよう。
    - おはよー
    - やぁ!おはよう。
    - おはようございます!
    - おはよう

- intent: thankyou
  examples: |
    - ありがとうございます。
    - ありがと。
    - ありがとう。
    - ありがとうございました。
    - どうもありがとうございました。
    - この度はありがとうございました。
    - どうもありがとうございます。
    - 助かりました。ありがとうございます。
    - ありがと!
    - この度はありがとうございます。

- intent: reservation
  examples: |
    - 会議の予約をお願いします。
    - 10日に会議を設定して下さい。
    - 会議予約を15時から1時間でお願いします。
    - 会議の予定を入れてもらえますか
    - 明後日に会議を予約を入れて。
    - 会議を予約して。
    - 会議の調整をお願いします。
    - 会議の予定を調整して。
    - 会議の予約を10時でお願いします。
    - 1時間くらい会議がしたいんですけど。

- intent: inform
  examples: |
    - 10時にして下さい。
    - 明日がいいです。
    - そうですね。6人にして下さい。
    - 明日の15時から1時間でいいです。
    - 30分でいいです。
    - 2時間ぐらいにして下さい。
    - 5日にして下さい。。
    - 4時30分から3人でいいです。
    - 1時間くらいにして下さい。
    - 明後日でいいです。

- intent: affirm
  examples: |
    - はい。続けて。
    - はい。続けてください
    - はい。
    - はい。どうぞ。
    - はい、よろしく。
    - はい。よろしくお願いします
    - はい、OKです
    - はい。大丈夫です。
    - はい、問題ありません。
    - はい、了解です。

- intent: deny
  examples: |
    - いいえ、そうではありません。
    - いいえ。
    - いいえ
    - いいえ、違います。
    - いいえ、もう良いです。
    - いいえ、もういいです。
    - いいえ、やめましょうか。
    - いいえ、やめてください
    - いいえ、承服しかねます。
    - いいえ、同意できません

今回は EntityExtractor に学習不可の DucklingEntityExtractor を使ったので nlu.yml にエンティティ抽出のアノテーションがありませんが、 CRFEntityExtractor のような学習可能なものを使用する場合は以下のように記述します。

    - では[11時](time)から[30分](duration)でお願いします。

Rules

これは RulePolicy の為の学習データになります。ルールベースなのに学習データというのも妙な気がしますが、 RulePolicy のソースを見ると確かに trainable = True になっていますね。詳しくは見ていませんが、コード15を確認すると確かに何かしらしてるみたいです。

ルールを記述する時のコツですが、短く、「この状況では必ずこうする」というものに限定するのが良いと思います。 それ以外は後述する Stories でカバーする感覚です。

今回は以下のようにしました。

%%writefile /content/rasa_3x/data/rules.yml
version: "3.1"

rules:

- rule: Greet
  steps:
  - intent: greet
  - action: utter_greet

- rule: Thank you
  steps:
  - intent: thankyou
  - action: utter_thankyou

- rule: Activate reservation form
  steps:
  - intent: reservation
  - action: reservation_form
  - active_loop: reservation_form

- rule: Submit form
  condition:
  - active_loop: reservation_form
  steps:
  - action: reservation_form
  - active_loop: null
  - action: action_make_reservation

- rule: Continue form after interrupt
  condition:
  - active_loop: reservation_form
  steps:
  - action: utter_ask_continue_reservation
  - intent: affirm
  - action: reservation_form
  - active_loop: reservation_form

- rule: Cancel form after interrupt
  condition:
  - active_loop: reservation_form
  steps:
  - action: utter_ask_continue_reservation
  - intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

特徴的なのは "Activate reservation form” と “Submit form” を分割したところでしょうか。これは

  • ユーザ : 会議の予約をお願いします。
  • ボット : 日時はどうしますか?
  • ユーザ : 夕べはカキフライを食べました。

のような form の slot を収集する途中で脇道にそれるパターンに対応する為です。

脇道にそれた後で復帰するかどうかを確認する際の動きも “Continue form after interrupt”, “Cancel form after interrupt” としてルール化しました。

Stories

Stories は TEDPolicy の学習データです。ユーザの様々な振る舞いを全てルールでマッチさせるのは困難なので、 Transformer を使ったモデルによって学習時に未経験の会話のパターンにも良い感じで対応させるようにします。

%%writefile /content/rasa_3x/data/stories.yml
version: "3.1"

stories:

- story: Greet and thankyou
  steps:
  - intent: greet
  - action: utter_greet
  - intent: thankyou
  - action: utter_thankyou

- story: Happy path
  steps:
  - intent: greet
  - action: utter_greet
  - intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - active_loop: null
  - action: action_make_reservation
  - intent: thankyou
  - action: utter_thankyou

- story: Interrupt after reservation and continue.
  steps:
  - intent: greet
  - action: utter_greet
  - or :
    - intent: reservation
      entities:
      - time: "3月22日"
    - intent: reservation
      entities:
      - number: "8"
    - intent: reservation
      entities:
      - duration: "30分"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - intent: affirm
  - action: reservation_form

- story: Interrupt after reservation and cancel.
  steps:
  - intent: greet
  - action: utter_greet
  - or :
    - intent: reservation
      entities:
      - time: "10日"
    - intent: reservation
      entities:
      - number: "4"
    - intent: reservation
      entities:
      - duration: "1時間"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

- story: Interrupt after inform and continue.
  steps:
  - intent: greet
  - action: utter_greet
  - or :
    - intent: reservation
      entities:
      - time: "3月22日"
    - intent: reservation
      entities:
      - number: "8"
    - intent: reservation
      entities:
      - duration: "30分"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - or :
    - intent: inform
      entities:
      - time: "3月22日"
    - intent: inform
      entities:
      - number: "8"
    - intent: inform
      entities:
      - duration: "30分"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - intent: affirm
  - action: reservation_form

- story: Interrupt after inform and cancel.
  steps:
  - intent: greet
  - action: utter_greet
  - or :
    - intent: reservation
      entities:
      - time: "3月22日"
    - intent: reservation
      entities:
      - number: "8"
    - intent: reservation
      entities:
      - duration: "30分"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - or :
    - intent: inform
      entities:
      - time: "3月22日"
    - intent: inform
      entities:
      - number: "8"
    - intent: inform
      entities:
      - duration: "30分"
  - action: reservation_form
  - active_loop: reservation_form
  - or :
    - slot_was_set:
      - requested_slot: time
    - slot_was_set:
      - requested_slot: number
    - slot_was_set:
      - requested_slot: duration
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

今回は以下のパターンの story を記述しました。

  • シンプルな挨拶
  • 素直に会議予約をする
  • 会議予約を依頼した直後に横道それて継続する
  • 会議予約を依頼した直後に横道それてキャンセルする
  • 会議予約の slot の値を収集する際に横道それて継続する
  • 会議予約の slot の値を収集する際に横道それてキャンセルする

上記の記述では or 構文を使って記述してみました。詳しくは確認していませんが、学習時に or を解析して全てのパターンの組み合わせで学習データを形成してくれるのだと思います16

厳密に言うと、

    - intent: reservation
      entities:
      - time: "3月22日"

で intent に reservation 、 entities に time を検出した状況の後に

    - slot_was_set:
      - requested_slot: time

と time の slot を要求している局面は発生しないはずですが、効率良くパターンを記述できることを優先しました。

これで全てのパターンを網羅出来ているわけではないでしょうが、まずは基本的なところからはじめて、 動かしながら漏れているパターンを適宜追加するのが取り組みやすいのではないかと思います。

あと言い訳です。

あたかも “3月22日” を time 、"30分" を duration として検出できるような書きっぷりですが、 Duckling は全然そんな風に動いてくれませんでした。あくまで私の願望が書いてあるだけです。 例えば「13時からがよいです。」では “13時” を time として抽出して欲しいんですが、"13" を duration として拾ってしまったりとか。。。 まぁ、そんな感じです。 ただ、TEDPolicy の学習では検出した entities を One-hot にして結合するので TEDPolicy の次アクションの学習には大きな影響ないんじゃないかと。

そういえば、第2回で Rasa NLU の記事を書いた後、けっこう本気で Duckling の日本語化をしようと思った時期がありましたね。 作業量的に大変そうなのと BERT を触りたくなったので放り出したのでした。。。あれから月日がたったのでワンチャンいけるか?と思ったのですがダメでした17。。。

学習データではないですが、あと二つ紹介させてください。

Test Stories

Stories のテストデータを記述することができます。

フォーマットは Stories に似ていますが、ユーザのセリフとその intent, 含まれる entities を記述できるようになっていますね。

%%writefile /content/rasa_3x/tests/test_stories.yml
version: "3.1"

stories:
- story: Greet and thankyou
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      ありがとう。
    intent: thankyou
  - action: utter_thankyou


- story: Happy path
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      会議室の予約をお願いします。
    intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - user: |
      [10時](time)にして下さい。
    intent: inform
  - action: reservation_form
  - active_loop: reservation_form
  - user: |
      [6](number)人にして下さい。
    intent: inform
  - action: reservation_form
  - active_loop: reservation_form
  - user: |
      [2時間](duration)ぐらいにして下さい。
    intent: inform
  - action: reservation_form
  - active_loop: null
  - action: action_make_reservation
  - user: |
      ありがとう。
    intent: thankyou
  - action: utter_thankyou

- story: Interrupt after reservation and continue.
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      [10時](time)に会議室の予約をお願いします。
    intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: number
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - user: |
      はい。どうぞ。
    intent: affirm
  - action: reservation_form

- story: Interrupt after reservation and cancel.
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      [10時](time)に会議室の予約をお願いします。
    intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: number
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - user: |
      いいえ、もういいです。
    intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

- story: Interrupt after inform and continue.
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      会議室の予約をお願いします。
    intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: time
  - user: |
      [明日](time)にして下さい。
    intent: inform
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: number
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - user: |
      はい。どうぞ。
    intent: affirm
  - action: reservation_form

- story: Interrupt after inform and cancel.
  steps:
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - user: |
      [6]人(number)で会議室の予約をお願いします。
    intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: time
  - user: |
      [明日](time)にして下さい。
    intent: inform
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: duration
  - user: |
      おはよう。
    intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - user: |
      いいえ、もういいです。
    intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

Custom Actions

最後はカスタムアクションです。form の slot が埋まった後でバックエンドのサービスを叩くところですね。 今回は収集した slot の値をユーザにテキストで示すだけにしています。

%%writefile /content/rasa_3x/actions/actions.py
import logging
from typing import Any, Text, Dict, List
from rasa_sdk import Action, Tracker
from rasa_sdk.events import AllSlotsReset
from rasa_sdk.executor import CollectingDispatcher

logger = logging.getLogger(__name__)

class ActionMakeReservation(Action):
  def name(self) -> Text:
    return "action_make_reservation"

  def run(self, dispatcher: CollectingDispatcher, tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
    slots = tracker.slots
    output=(f"次の条件で会議予約をしました。\n"
            f" - 日時={slots['time']}\n"
            f" - 人数={slots['number']}\n"
            f" - 会議時間={slots['duration']}")
    logger.debug(output)
    dispatcher.utter_message(text=output)
    return [AllSlotsReset()]

ちなみに名前の action_make_reservation は domain.yml に登録しておく必要があります。

準備が整ったので実際に学習を動かしてみましょう。

8. Rasa の学習とテスト

以下のようにして学習を実行します。コマンドラインで実行するときの “rasa train” と等価の処理です。 10 分程かかったような気がします。

os.chdir("/content/rasa_3x")
config = "config.yml"
training_files = "data/"
domain = "domain.yml"
output = "models/"
model_path = rasa.train(domain, config, {training_files}, output)

# (0lqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqk(B
# (0x(B Rasa Open Source reports anonymous usage telemetry to help improve the product (0x(B
# (0x(B for all its users.                                                             (0x(B
# (0x(B                                                                                (0x(B
# (0x(B If you'd like to opt-out, you can use `rasa telemetry disable`.                (0x(B
# (0x(B To learn more, check out https://rasa.com/docs/rasa/telemetry/telemetry.       (0x(B
# (0mqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqj(B
...
# Your Rasa model is trained and saved at 'models/20230327-023052-glass-height.tar.gz'.

学習が終わったら先程のテストを実行してみましょう。"rasa test core" に相当します。

os.chdir("/content/rasa_3x")
stories = "tests/test_stories.yml"
output = "results"
from rasa.model_testing import test_core
await test_core(model=model_path.model, stories=stories, output=output)

以下のような混同行列が表示されました。大丈夫そうですね。

rasa

テスト結果のフォルダは以下のようになっていました。

!ls /content/rasa_3x/results
# failed_test_stories.yml     story_report.json
# stories_with_warnings.yml   TEDPolicy_confusion_matrix.png
# story_confusion_matrix.png  TEDPolicy_report.json

失敗した story はなかったようです。

!cat /content/rasa_3x/results/failed_test_stories.yml
# None of the test stories failed - all good!

警告も特に出ていません。

!cat /content/rasa_3x/results/stories_with_warnings.yml
# None of the test stories failed - all good!

ちなみにテストが失敗した場合は以下のように出力されます。

version: "3.1"
stories:
- story: Interrupt after reservation and cancel. (/content/rasa_3x/tests/test_stories.yml)
  steps:
  - intent: greet
  - action: utter_greet
  - entities:
    - time: 10時
    user: |-
      [10時](time)に会議室の予約をお願いします。
  - slot_was_set:
    - time: 10時
  - action: action_listen  # predicted: utter_greet
  - intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: number
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation
  - intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null
- story: Interrupt after inform and cancel. (/content/rasa_3x/tests/test_stories.yml)
  steps:
  - intent: greet
  - action: utter_greet
  - user: |-
      [6]人(number)で会議室の予約をお願いします。
  - intent: reservation
  - action: reservation_form
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: time
  - slot_was_set:
    - time: 明日
  - active_loop: reservation_form
  - slot_was_set:
    - requested_slot: duration
  - intent: greet
  - action: utter_greet
  - action: utter_ask_continue_reservation  # predicted: action_listen
  - intent: deny
  - action: utter_ok
  - action: action_deactivate_loop
  - active_loop: null

失敗のあった story で間違った箇所が “# predicted:” で示されているので、内容を確認して適宜 Stories にパターンを追加して潰していきましょう。

9. 学習済みのボットを利用

以下のようにして学習済みのボットを利用します。コマンドラインで言うと “rasa run actions”, “rasa shell –debug” をする感じです。

os.chdir("/content/rasa_3x")
from rasa.jupyter import chat
from rasa.core.utils import AvailableEndpoints
endpoints = AvailableEndpoints.read_endpoints("/content/rasa_3x/endpoints.yml")
import logging
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
  if logger.name.startswith("rasa"):
    #print(logger.name)
    logger.setLevel(logging.DEBUG)
  else:
    logger.setLevel(logging.INFO)
import subprocess
import requests
subprocess.Popen(["stack","exec","duckling-example-exe"], cwd="/content/duckling")
subprocess.Popen(["rasa","run","actions"], cwd="/content/rasa_3x")
chat(model_path.model, endpoints=endpoints)
# ...
# Your bot is ready to talk! Type your messages here or send '/stop'.
おはよう。
# DEBUG:rasa.core.processor:Received user message 'おはよう。' with intent '{'name': 'greet', 'confidence': 0.6915305852890015}' and entities '[]'
# DEBUG:rasa.core.processor:Predicted next action 'utter_greet' with confidence 1.00.
# DEBUG:rasa.core.processor:Predicted next action 'action_listen' with confidence 1.00.
# こんにちは。ご用件はなんですか?
会議の予約をお願いします。
# DEBUG:rasa.core.processor:Received user message '会議の予約をお願いします。' with intent '{'name': 'reservation', 'confidence': 0.6212062239646912}' and entities '[]'
# DEBUG:rasa.core.processor:Predicted next action 'reservation_form' with confidence 1.00.
# DEBUG:rasa.core.processor:Predicted next action 'action_listen' with confidence 1.00.
# 日時はいつが良いですか?
明日にして下さい。
# DEBUG:rasa.core.processor:Received user message '明日にして下さい。' with intent '{'name': 'inform', 'confidence': 0.7296549677848816}' and entities '[{'start': 0, 'end': 3, 'text': '明日に', 'value': '2023-03-28T00:00:00.000+09:00', 'confidence': 1.0, 'additional_info': {'values': [{'value': '2023-03-28T00:00:00.000+09:00', 'grain': 'day', 'type': 'value'}], 'value': '2023-03-28T00:00:00.000+09:00', 'grain': 'day', 'type': 'value'}, 'entity': 'time', 'extractor': 'DucklingEntityExtractor'}]'
# DEBUG:rasa.core.processor:Predicted next action 'reservation_form' with confidence 1.00.
# DEBUG:rasa.core.processor:Predicted next action 'action_listen' with confidence 1.00.
# 打合せの人数を教えて下さい。
6人でよいです。
# DEBUG:rasa.core.processor:Received user message '6人でよいです。' with intent '{'name': 'inform', 'confidence': 0.44254735112190247}' and entities '[{'start': 0, 'end': 1, 'text': '6', 'value': 6, 'confidence': 1.0, 'additional_info': {'value': 6, 'type': 'value'}, 'entity': 'number', 'extractor': 'DucklingEntityExtractor'}]'
# DEBUG:rasa.core.processor:Predicted next action 'reservation_form' with confidence 1.00.
# DEBUG:rasa.core.processor:Predicted next action 'action_listen' with confidence 1.00.
# 会議時間はどのぐらいですか?
30分間にして下さい。
# DEBUG:rasa.core.processor:Received user message '30分間にして下さい。' with intent '{'name': 'inform', 'confidence': 0.696469783782959}' and entities '[{'start': 0, 'end': 4, 'text': '30分間', 'value': 30, 'confidence': 1.0, 'additional_info': {'value': 30, 'type': 'value', 'minute': 30, 'unit': 'minute', 'normalized': {'value': 1800, 'unit': 'second'}}, 'entity': 'duration', 'extractor': 'DucklingEntityExtractor'}]'
# DEBUG:rasa.core.processor:Predicted next action 'reservation_form' with confidence 1.00.
# DEBUG:rasa.core.processor:Predicted next action 'action_make_reservation' with confidence 1.00.
# DEBUG:rasa.core.actions.action:Calling action endpoint to run action 'action_make_reservation'.
# DEBUG:rasa.core.processor:Predicted next action 'action_listen' with confidence 1.00.
# 次の条件で会議予約をしました。
# - 日時=2023-03-28T00:00:00.000+09:00
# - 人数=6
# - 会議時間=30
ありがとう。
# こちらこそ、ご利用ありがとうございました。
/stop

良い感じに動いている風ですね。

DEBUG ログは大量に出力されるので、Rasa NLU の解析結果と次アクションが決定したログだけを残しています。

ですが、出力ログをよく見ると「夜中の0時に会議を予約してどうするんだ!」という話になっています。。。 本来であれば、「13日がいいですね。15時からでお願いします。」のようなテキストから “13日” と “15時” を抜き出して “2023-04-13T15:00:00.000+09:00” のようにしたいところです。

これに関しては FormValidationAction18 を使うことで対応できるかと思います。この仕組みを使うと、

  • 必須ではない slot を作りたい。
    例:「八宝菜のレシピを教えて」の後で「料理のジャンルは何がいいですか?」と聞かないで欲しい場合です。
  • 選択不可の値を拒否したい。
    例: 「666会議室をお願いします。」と存在しない値を要求された時に「そんな会議室ありません。」と応じたい場合です19
  • list 型の slot に要素を追加したい。
    例: slot が [“佐藤”, “鈴木”] の状態で「会議のメンバーに山田さんを追加して下さい」と言われたら [“佐藤”, “鈴木”, “山田”] にして欲しい場合です。

というようなケースにも対応できると思います。

今回の会議予約(風)のボットとは無関係にはなりますが、最初の例を示しておくと以下のような感じです。slot は list 型で menu のエンティティを検出して、 cuisine が未設定状態だったら、「未設定で OK」の意思表示として [“NULL”] を設定しています。最後の例は試せてません(すいません)。

from typing import Dict, Text, Any, List, Union
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.forms import FormValidationAction

class ValidateRecipeForm(FormValidationAction):

  def name(self) -> Text:
    return "validate_recipe_form"

  def validate_menu(self, value: Text, dispatcher: CollectingDispatcher,
    tracker: Tracker, domain: Dict[Text, Any]) -> Dict[Text, Any]:
    slots = tracker.slots
    if value is not None and (slots['cuisine'] is None or len(slots['cuisine']) == 0):
      return {"cuisine": ["NULL"], "menu": value}
    return {"menu": value}

少し横道にそれましたが、折角なので form の途中で横やりをいれて続行するケースとキャンセルするケースもやってみましょう。

まずは続行する場合です。

# Your bot is ready to talk! Type your messages here or send '/stop'.
おはよう。
# こんにちは。ご用件はなんですか?
会議の予約をお願いします。
# 日時はいつが良いですか?
明日にして下さい。
# 打合せの人数を教えて下さい。
おはよう。
# こんにちは。ご用件はなんですか?
# 会議の予約はつづけたほうがよいですか?
はい。どうぞ。
# 打合せの人数を教えて下さい。
6人でよいです。
# 会議時間はどのぐらいですか?
30分間にして下さい。
# 次の条件で会議予約をしました。
#  - 日時=2023-03-28T00:00:00.000+09:00
#  - 人数=6
#  - 会議時間=30
ありがとう。
# こちらこそ、ご利用ありがとうございました。
/stop

utter_greet の仕込みのセリフ「こんにちは。ご用件はなんですか?」がアレな感じですが、ちゃんと継続意思を確認して想定の流れに戻れています。

今度はキャンセルしてみましょう。

# Your bot is ready to talk! Type your messages here or send '/stop'.
おはよう。
# こんにちは。ご用件はなんですか?
会議の予約をお願いします。
# 日時はいつが良いですか?
明日にして下さい。
# 打合せの人数を教えて下さい。
おはよう。
# こんにちは。ご用件はなんですか?
# 会議の予約はつづけたほうがよいですか?
いいえ、もういいです。
# 了解しました。
/stop

こちらも問題ないですね。最後に「思ったとおりに動かない!」という時にどうしたらいいか整理しておきましょう。

10. 思ったとおりに動かない時

Rasa で作ったボットが思った通りに動かないのは大別して以下のパターンに分けられると思います。順に見ていきましょう。

intent, entity を正しく認識できていない。

この場合は認識に失敗したサンプルを学習データに追加しましょう。学習データが不足しているなら追加します。 それでもダメなら、より強力なモデルを使った実装にコンポーネントを置き換えましょう。

Policy が次アクションの判定を間違えている。

前章の動作ログでは省略しましがたがデバッグログには以下のように、どの Policy がどういう判定をして最終的にどの Policy の意見を採用したかが出力されています。

DEBUG:rasa.engine.graph:Node 'run_MemoizationPolicy0' running 'MemoizationPolicy.predict_action_probabilities'.
DEBUG:rasa.core.policies.memoization:Current tracker state:
[state 1] user intent: greet | previous action name: action_listen
[state 2] user intent: greet | previous action name: utter_greet
[state 3] user intent: reservation | previous action name: action_listen
[state 4] user intent: reservation | previous action name: reservation_form | slots: {'time': (1.0,), 'number': (1.0,), 'duration': (1.0,)}
DEBUG:rasa.core.policies.memoization:There is no memorised next action
DEBUG:rasa.engine.graph:Node 'run_RulePolicy1' running 'RulePolicy.predict_action_probabilities'.
DEBUG:rasa.core.policies.rule_policy:Current tracker state:
[state 1] user intent: greet | previous action name: action_listen
[state 2] user intent: greet | previous action name: utter_greet
[state 3] user intent: reservation | previous action name: action_listen
[state 4] user intent: reservation | previous action name: reservation_form | slots: {'time': (1.0,), 'number': (1.0,), 'duration': (1.0,)}
DEBUG:rasa.core.policies.rule_policy:There is a rule for the next action 'action_make_reservation'.
DEBUG:rasa.engine.graph:Node 'run_TEDPolicy2' running 'TEDPolicy.predict_action_probabilities'.
DEBUG:rasa.core.policies.ted_policy:TED predicted 'action_make_reservation' based on user intent.
DEBUG:rasa.engine.graph:Node 'select_prediction' running 'DefaultPolicyPredictionEnsemble.combine_predictions_from_kwargs'.
DEBUG:rasa.core.policies.ensemble:Predicted next action using RulePolicy.
DEBUG:rasa.core.processor:Predicted next action 'action_make_reservation' with confidence 1.00.

RulePolicy が最強ですから、誤った選択をするのが RulePolicy であればルールを見直して下さい。 MemoizationPolicy の場合も同様で原因になっている story を修正する必要があるでしょう。 TEDPolicy が誤認識する場合は 7 章で述べたように、Stories に失敗するケースを追記して、テストで確認するのがよいと思います。

11. おわりに

今回は、チャットボットフレームワークである Rasa を紹介し、transformers を使ったカスタムコンポーネントを作ってみました。 transformers を使えば文章分類や固有表現抽出、テキスト生成等プリミティブなタスクはこなせますが、 実際にユーザとテキストでやり取りするアプリケーションを作るとなると Rasa のようなフレームワークを使うのが良いかと思っていました。 なのですが、世の中の雰囲気がガラッと変わっているので、次回は ChatGPT の流れに乗った流行りものに取り組みたいですね。 もう沢山の人に書き尽くされている気もするのでどうしたものだか。。。


  1. https://github.com/RasaHQ/rasa 

  2. 本連載は他の人が書いてないことや書かないことを書くのが良いかなと思っていて。あと、読みたい物や、やりたい事が他に沢山あるんです! 

  3. https://tech-blog.optim.co.jp/entry/2021/11/17/100000 

  4. https://qiita.com/Zect/items/e341e43fa9cf98529942 

  5. ここで抽出した intent を手掛かりの一部として後述する Policy がチャットボットの振る舞いを決定します。 

  6. https://github.com/facebook/duckling 

  7. https://github.com/RasaHQ/rasa/blob/99827178664ae4649465d208e9221f2b391901cc/rasa/nlu/utils/hugging_face/registry.py#L56-L65 

  8. https://github.com/RasaHQ/rasa/blob/99827178664ae4649465d208e9221f2b391901cc/rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py#L150 

  9. https://github.com/RasaHQ/rasa/blob/99827178664ae4649465d208e9221f2b391901cc/rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py#L171 

  10. https://arxiv.org/abs/1910.00486 

  11. https://www.youtube.com/watch?v=j90NvurJI4I&list=PL75e0qA87dlG-za8eLI6t0_Pbxafk-cxb&index=15 

  12. https://github.com/RasaHQ/rasa/blob/e7b09b80fe9502d913b3c15bc7ad2590b5004709/rasa/core/policies/ted_policy.py#L1765 のあたりを見てもらえれば良いかと思います。メッセージテキストやボットの返信メッセージテキストを Featurizer で特徴量にしたものや、現在アクティブな form 等の情報も含まれます。 

  13. 実際は TEDPolicy で TEXT や ACTION_TEXT の特徴量に使われる次元数を拡大したり、TEDPolicy そのものの層数を拡張したり、それなりに調整がいるかと思います。まぁ正直、ノリで作ったんですが、ある程度のデータ量の学習データを用意すれば、BERT を使わずとも CountVector ベースの特徴量で十分な精度が得られる気もしますね。。。 

  14. https://rasa.com/docs/rasa/domain#slots 

  15. https://github.com/RasaHQ/rasa/blob/99827178664ae4649465d208e9221f2b391901cc/rasa/core/policies/rule_policy.py#L782 

  16. 使いすぎると組み合わせが爆発して学習が遅くなるので注意して下さい。 

  17. 筆者の設定が悪いとかだったら教えて下さい。 

  18. https://rasa.com/docs/action-server/validation-action/#formvalidationaction-class 最初は FormValidateAction も作りこむ予定だったのですが、どうにもこうにも Duckling がどうにもならなかったので諦めましたー。 

  19. https://github.com/RasaHQ/rasa/blob/main/examples/formbot/actions/actions.py