オブジェクトの広場はオージス総研グループのエンジニアによる技術発表サイトです

AI

はじめての自然言語処理

第11回 QuartzNet で音声認識 JetBot を試してみる
オージス総研 技術部 アドバンストテクノロジセンター
鵜野 和也
2020年10月22日

今回も音声認識です。前回は QuartzNet と JSUT データセットで音声→テキスト変換の検証を行いました。今回は問題を音声コマンド認識に簡単化し、自前のデータセットを作成、PCのマイクに話かけてストリーミングでの推論を試します。ついでに JetBot の操縦をしてみましょう。

1. はじめに

今回も 前回 に続いて音声認識です。使用するモデルは前回と同じ QuartzNet ですので、前回を未読の方は 第10回 の記事に目を通して戻ってきて頂けると、より理解がしやすいと思います。前回は音声→テキスト変換の話だったのですが、音声認識モデルを作るとどうしてもマイクに話かけて認識されるかどうか試したくなります。ですが前回の JSUT コーパスは単一女性の声のみですので、私の声が認識できるはずもなく。CSJコーパス1あたりを購入して試しても良いのですが、この連載は読んだ人が試せるようにしたいので問題を音声コマンド認識にすることにしました。

音声コマンド認識なら自分の声を認識するくらいのデータセットは比較的手軽に作れそうです。なのですが、音声コマンド認識ですと認識した結果で「テレビがつく」とか何か動かないとつまらないです。「何かないかな?」と思っていたところ、社内の誰かが作って遊んで飽き…ではなく技術検証作業が完了した JetBot が稼働状態のまま会社のキャビネットの中で眠っていました。

jetbot

QuartzNet は 15x5 のフルサイズモデルでも Jetson Nano でリアルタイムストリーミングでの処理が可能なので、音声コマンド認識用の小規模モデルなら USBバッテリー駆動でフルパワーがでてない JetBot でも動きそう。

そういう訳で今回は自前でデータセットを作成し、モデルを学習、JetBot にデプロイして操縦という流れで進めていきましょう2

まずは、音声コマンド分類データセットの作成です。

2. 音声コマンド分類データセットの作成

JetBot を声で操作する為の音声コマンド分類のデータセットを作成します。環境としては マイクの付いた PC、Windows 10、Chrome、Colab です。

大まかな流れとして、以下のような作業になります。

  1. ラベル生成スクリプトが出力するラベルを読み上げ
  2. その声を PC のマイクで拾い、WebSocket 経由で Colab 上で起動した録音サーバに飛ばし、 wav ファイルとして保存
  3. 保存した wav を無音部分で分割し 1. で出力したラベルと突き合わせながら、必要に応じて修正

まず、Colab 上で起動した録音サーバに接続する為、ngrock の準備をします。

ngrock

ngrock は簡単にいうと NAT やファイアーウォールの内側で動作しているサーバをセキュアなトンネルを通して手軽にインターネットに公開するサービスです。 Colab のランタイムで動作しているサーバプロセスに直接アクセスすることはできないので、ngrock でトンネルを作ってあげる訳です。

「ローカルサーバをインターネットに公開できる」というのは「悪意のあるアクセスを NAT やファイウォールの内側に呼び込む可能性を作っている」というのことなので、そこは注意して下さい。この手のサービスの利用を社内ルールで禁止している会社もあろうかと思います。3

まずは、 https://ngrok.com/ でサインアップを済ませて下さい。無料版で OK です。

サインアップしたら認証トークンの値を確認してください。

ngrok_authtoken

認証トークンの値が確認できたら、 Colab でノートブックを開きましょう。音声を拾って wav に書き出すだけなので、ランタイムのアクセラレータは None で構いません。

この章ではノートブックを複数並行で使用するので、このノートブックを “noteboook-1” とします。以後、コード例先頭に “# notebook-?” の行がある場合は当該のノートブックでそのセルを実行するものとします。

それでは、次に ngrok の実行コマンドを取得します。

# notebook-1
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip

ngrok コマンドを使って、先ほど確認した認証トークンを設定しておきます。

# notebook-1
!./ngrok authtoken your_auth_token_here

# Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml

秘密鍵と証明書の生成

Chrome から録音サーバへの WebSocket 接続には HTTPS(WSS) を使用したので、 以下のようにして秘密鍵と証明書を生成しておきます。 セキュリティ的な強度とかは考慮しておらず、とりあえず WSS で接続したいが為にやっています4

# notebook-1
!mkdir ssl
!openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "./ssl/cert.key" -out "./ssl/cert.pem" -batch

録音サーバ

録音サーバで使用するライブラリをインストールします。

# notebook-1
!pip install configargparse
!pip install samplerate

録音サーバのコードは以下のとおりです5

%%bash
# notebook-1
cat <<EOF > wsrecoder.py
#!/usr/bin/env python3
import sys
import os
import math
import configargparse
import tornado.ioloop
import tornado.web
import tornado.websocket
import wave
import numpy as np
import datetime
import samplerate


# How to make server key and certificate.
# openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "cert.key" -out "cert.pem" -batch

def log(message):
    print(message, flush=True)

def get_parser(parser=None, required=True):
  if parser is None:
    parser = configargparse.ArgumentParser(
      description='',
      config_file_parser_class=configargparse.YAMLConfigFileParser,
      formatter_class=configargparse.ArgumentDefaultsHelpFormatter)

  parser.add('--config', is_config_file=True, help='config file path')
  parser.add_argument('--sample_rate', default=16000, type=int, help="Server side sample rate.")
  parser.add_argument('--browser_sample_rate', default=48000, type=int, help="Client side sample rate.")
  parser.add_argument('--cert_file', default="./ssl/cert.pem", type=str, help="SSL cert file.")
  parser.add_argument('--key_file', default="./ssl/cert.key", type=str, help="SSL key file.")
  parser.add_argument('--websocket_port', default=8889, type=int, help="WebSocket port.")
  parser.add_argument('--frame_len', default=210.0, type=float, help="frame length (sec)")
  parser.add_argument('--message_len', default=2048, type=int, help="message size of websocket upload.")
  parser.add_argument('--dump_wav', default=1, type=int, help="whether to dump frame as wav.")
  parser.add_argument('--dump_npy', default=1, type=int, help="whether to dump frame as npy.")

  return parser

def calc_frame_size(args):
  log("frame_len(sec) = %f" % (args.frame_len))
  log("browser_sample_rate = %d" % (args.browser_sample_rate))
  log("message_len = %d" % (args.message_len))
  n_browser_frame_len = int(args.frame_len * args.browser_sample_rate)
  log("Round frame_len times brower_sample_rate(%d) to a multiple of message_len(%d)"
       % (n_browser_frame_len, args.message_len))
  n_browser_frame_len = args.message_len * (n_browser_frame_len // args.message_len)
  log("n_browser_frame_len = %d" % (n_browser_frame_len))

  resample_rate = args.sample_rate / args.browser_sample_rate
  pseudo_browser_frame = np.zeros(n_browser_frame_len, dtype=np.float32)
  pseudo_resampled = samplerate.resample(pseudo_browser_frame, resample_rate, 'sinc_best')
  n_frame_len = pseudo_resampled.shape[0]
  log("sample_rate = %d" % (args.sample_rate))
  log("resample rate = %f" % (resample_rate))
  log("n_frame_len = %d" % (n_frame_len))

  return n_browser_frame_len, n_frame_len

class WebSocketHandler(tornado.websocket.WebSocketHandler):

  def initialize(self, n_browser_frame_len, n_frame_len, message_len,
                 sample_rate, browser_sample_rate, dump_wav=False, dump_npy=False):
    log("Initialize WebSocket handler.")
    self.message_len = message_len
    self.n_browser_frame_len = n_browser_frame_len
    self.n_frame_len = n_frame_len
    self.sample_rate = sample_rate
    self.browser_sample_rate = browser_sample_rate
    assert n_browser_frame_len % message_len == 0
    self.n_messages_per_frame = n_browser_frame_len // message_len
    log("sample_rate : %d" % (sample_rate))
    log("browser_sample_rate : %d" % (browser_sample_rate))
    log("n_frame_len : %d" % (self.n_frame_len))
    log("n_browser_frame_len : %d" % (self.n_browser_frame_len))
    log("n_messages_per_frame : %d" % (self.n_messages_per_frame))
    self.dump_wav = dump_wav
    self.dump_npy = dump_npy
    self.pos = 0

  def open(self):
    self.buffer = []
    log("audio socket opened")

  def on_message(self, message):
    message = np.frombuffer(message, dtype='float32')
    log("on message : size=%d, buffer size=%d" % (len(message), len(self.buffer)))
    assert len(message)==self.message_len    
    self.buffer.append(message)

    if len(self.buffer) >= self.n_messages_per_frame:
      self.transcribe()

  def transcribe(self):
    browser_frame = np.array(self.buffer).flatten()
    log("browsr frame length = %s" % (browser_frame.shape))
    rate = self.sample_rate/self.browser_sample_rate
    log("resample rate = %f" % (rate))
    frame = samplerate.resample(browser_frame, rate, 'sinc_best')
    log("resampled frame length = %s" % (frame.shape))

    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    if self.dump_npy:
      np.save("frame_%s" % (timestamp), frame)
    if self.dump_wav:
      self.dump_as_wav(frame, timestamp)
    self.buffer.clear()

  def dump_as_wav(self, frame, timestamp):
    arr = (frame * 32767).astype(np.int16) # 32767 is max value of 16 bit int.
    filename = 'frame_%s.wav' % (timestamp)
    log("wrting wav file : %s" % filename)
    with wave.open(filename, 'wb') as wf:
      wf.setnchannels(1)
      wf.setsampwidth(2) # 2bytes(16bit precision)
      wf.setframerate(self.sample_rate)
      wf.writeframes(arr.tobytes('C'))

  def on_close(self):
    self.transcribe()
    log("audio socket closed")

  def check_origin(self, origin):
    return True


class ControlPageHandler(tornado.web.RequestHandler):

  def get(self):
    host=self.request.host
    html = """
<html>
<div id="control" style="padding: 10px">
  <font color="#ffffff">
    <button id="start_transcribe"
            style="border: 0px; padding: 10px; border-radius: 10px; background-color: #00bfff; margin: 10px">
        START
    </button>
    <button id="stop_transcribe"
            style="border: 0px; padding: 10px; border-radius: 10px; background-color: #ff69b4">
        STOP
    </button>
  </font>
</div>

<script>
  var current_stream = null;
  var context = null;
  var ws = null;

  var upload = function(stream) {
    current_stream = stream;
    context = new AudioContext();
    ws = new WebSocket('wss://%s/websocket');

    console.log(context.sampleRate);

    var input = context.createMediaStreamSource(stream)
    var processor = context.createScriptProcessor(0, 1, 1);

    input.connect(processor);
    processor.connect(context.destination);

    processor.onaudioprocess = function(e) {
      var voice = e.inputBuffer.getChannelData(0);
      ws.send(voice.buffer);
    };
  };

  function start_transcribe(){
    navigator.mediaDevices.getUserMedia({ audio: true, video: false }).then(upload)
  }

  function stop_transcribe(){
    ws.close();
   current_stream.getTracks().forEach(track => track.stop());
    context.close();
  }

  document.getElementById("start_transcribe").onclick = start_transcribe;
  document.getElementById("stop_transcribe").onclick = stop_transcribe;

</script>
</html>
""" % host
    self.write(html)

  def check_origin(self, origin):
    return True

def main(cmd_args):
  parser = get_parser()
  args, _ = parser.parse_known_args(cmd_args)

  n_browser_frame_len, n_frame_len = calc_frame_size(args)

  log("Starting SSL server on port %d" % args.websocket_port)

  app = tornado.web.Application([(r"/websocket", WebSocketHandler, {
      'n_browser_frame_len': n_browser_frame_len,
      'n_frame_len': n_frame_len,
      'message_len': args.message_len,
      'sample_rate': args.sample_rate,
      'browser_sample_rate': args.browser_sample_rate,
      'dump_wav': args.dump_wav,
      'dump_npy': args.dump_npy
    }),
    (r"/control", ControlPageHandler)
  ])

  http_server = tornado.httpserver.HTTPServer(app, ssl_options={
    "certfile": args.cert_file,
    "keyfile" : args.key_file,
  })

  http_server.listen(args.websocket_port)
  tornado.ioloop.IOLoop.instance().start()

if __name__ == '__main__':
    main(sys.argv[1:])  
EOF
chmod 755 wsrecoder.py

トンネルの生成と録音サーバの起動

一通りの準備ができたので、以下のようにしてトンネルを作ります。セルから出力された URL がトンネルの入り口です (筆者が試した時はときどきエラーになりましたが、その場合も何度か試すとうまくいきました)。

# notebook-1
get_ipython().system_raw('./ngrok http https://localhost:8889 &')
! curl -s https://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

# https://761d5a9eff82.ngrok.io

続いて録音サーバを起動します。

# notebook-1
!./wsrecoder.py

録音サーバの起動後、新しくタブを開いて https://761d5a9eff82.ngrok.io にアクセスすると(実際のURLは上記のセル出力の値に直して下さい)以下のコントロール画面が表示されるので、"START", “STOP” で録音の開始/停止ができます。

recorder_control

ラベルの生成と音声の録音

ここまでで音声を録音する準備が整いましたが、音声コマンド認識の学習をするには音声とコマンドのペアが必要になります。 録音サーバを起動したノートブックとは別に、新たに別のノートブックを開いて以下のセルを実行します。このノートブックを “notebook-2” とします。

%%bash
# notebook-2
cat << EOF > mkdata.sh
#!/bin/bash
cnt=0
NUM_SAMPLES=100
NUM_CLASSES=8


CLASSES=("前進" "後退" "停止" "右" "左" "もっと右" "もっと左" "ターボブースト")

time while [ \$cnt -lt \$NUM_SAMPLES ]
do
  index=\`expr \$RANDOM \\* \$NUM_CLASSES / 32768\`
  echo "\$cnt \${CLASSES[\$index]}"
  echo "\${CLASSES[\$index]}" >> labels.txt
  cnt=\`expr \$cnt + 1\`
  sleep 2
done
EOF
chmod 755 mkdata.sh

先ほどのコントロール画面で “START” をクリックして録音を開始してから(この際、マイク使用の許可を求められるので「許可」をクリックして下さい)、 mkdata.sh を起動すると、以下のようにラベルが 2 秒間隔で出力されるので、「ぜんしん」、「こうたい」、と読み上げて音声を記録していきます。

# notebook-2
!./mkdata.sh

# 0 前進
# 1 後退
# 2 もっと左
# 3 右
# ...

mkdata.sh の出力は 100 サンプルで終了するようになっています。終わったら、コントロール画面の “STOP” をクリックして下さい。録音サーバを起動したノートブック(“notebook-1”)で、録音された音声が “frame_20200625052006.wav” のようなファイル名で一塊の wav ファイルに出力されます。

mkdata.sh を実行したノートブック(“notebook-2”)では読み上げたラベルが labels.txt として記録されています。

# notebook-2
!head -4 labels.txt
# 前進
# 後退
# もっと左
# 右

Colab の左端のメニューからフォルダのアイコンを選んで、labels.txt をダウンロード&アップロードで “notebook-2” から “notebook-1” に持っていきます。 以降は “notebook-1” で作業を行います。

wav ファイルの切り分け

録音サーバ(“wsrecoder.py”)が動いていると続きの作業ができないので、"wsrecoder.py" を実行したセル左側の停止ボタンで停止させておいて下さい。 ここからは 100 サンプルが一塊になった wav ファイルをラベル単位に切り分けていきます。まずは必要なライブラリをインポートします。

# notebook-1
import numpy as np
import IPython.display
from pydub import AudioSegment
from pydub.silence import split_on_silence

保存した wav からタイムスタンプを切り出します。

# notebook-1
ts=!ls frame*.wav | sed -e 's/frame_//' -e 's/.wav//'
ts=ts[0]
ts
# '20200625052006'

データセット用のディレクトリを作り、"notebook-2" から持ってきたラベルのファイル名に対応する音声のタイムスタンプを加えて、作成したディレクトリに移動します。

# notebook-1
!mkdir -p data
!mv labels.txt labels_{ts}.txt
!mv *{ts}* data
!ls data
# frame_20200625052006.npy  frame_20200625052006.wav  labels_20200625052006.txt

一塊の wav ファイルを無音部分で複数ファイルに分割し、ラベルテキストを読み込みます。

# notebook-1
sound = AudioSegment.from_file("data/frame_%s.wav" % ts , format="wav")
chunks = split_on_silence(sound, min_silence_len=500, silence_thresh=-40, keep_silence=400)
for i, chunk in enumerate(chunks):
    chunk.export("data/segment_" + str(i) +".wav", format="wav")
with open("data/labels_%s.txt" % ts, "r") as f:
    lines = f.readlines()
    labels = [line.strip() for line in lines]

ラベルの先頭 10 サンプルは以下のようになっています。

# notebook-1
labels[0:10]

# ['前進', '後退', 'もっと左', '右', '前進', 'ターボブースト', 'もっと左', '右', '左']

先頭のラベルと分割した wav ファイルを読み込みます。

# notebook-1
index=0
print("%2d : %s" % (index, labels[index]))
IPython.display.Audio("data/segment_%d.wav" % index)

# 0 : 右

ノートブックの画面上では読み込んだ wav の再生コントロールが表示されるので、再生ボタンをクリックしてラベルと音声が合致していることを確認しましょう。

confirm_segment

上記のセルを index の値を変えながらズレをチェックしていきます。全件チェックするのは大変なので 20 サンプル毎ぐらいで十分です。 ズレや誤りが見つかったら、その都度手作業で修正していきます。

例えば、出だしの index=0 の wav に咳払い等が入ってしまった場合は、以下のようにして後続の wav ファイルのインデックスを前に詰めます。

# notebook-1
%%bash
for i in `seq 1 100`
do
  mv data/segment_$i.wav data/segment_`expr $i - 1`.wav
done

また、ラベルが「右」なのに何故か「ひだり!」と発声してしまった場合などはラベルの方を書き換えたりもしました。

確認が終わった分割後の wav は対応するタイムスタンプでサブディレクトリを作って整理しておきます。

!mkdir -p data/{ts}
!mv data/segment* data/{ts}

筆者の場合は以上の処理を繰り返して(とりあえず) 8 セット程作成しました。録音と推論の環境に違いがなければ、これくらいで十分だと思います。 実際に JetBot を操縦する場合は意外とモーター音が大きいので、録音作業時の傍らで JetBot のモータを稼働させてモーター音混じりのサンプルも作っておいた方がよいです(というか作る必要がありました。 JetBot のモーター音の鳴らし方は後述します)。

以下のようにして、 NeMo が読み込める JSON 形式に加工します。

%%bash
declare -A commands
commands=( \
 ["前進"]="forward" \
 ["後退"]="backword" \
 ["右"]="right" \
 ["左"]="left" \
 ["もっと右"]="hard_right" \ 
 ["もっと左"]="hard_left" \
 ["停止"]="stop" \
 ["ターボブースト"]="turbo_boost" 
)
dataset="jetbot_asr.json"
rm -f $dataset
for labels in data/labels*.txt
do
  ts=`echo $labels | sed -e 's/^.*_//' -e 's/.txt//'`
  idx=0
  for label in `cat $labels`
  do
    command=${commands[$label]}
    wav=data/$ts/segment_$idx.wav
    duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'`
    record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"$command\"}"
    echo $record >> $dataset
    idx=`expr $idx + 1`
  done
done

できあがったデータはこんな感じになります。今回はコマンド分類なので “command” 属性に正解ラベルを設定する形式になります。

!head -5 jetbot_asr.json

# {"audio_filepath": "data/20200625050543/segment_0.wav", "duration": 0.889000, "command": "left"}
# {"audio_filepath": "data/20200625050543/segment_1.wav", "duration": 0.842000, "command": "stop"}
# {"audio_filepath": "data/20200625050543/segment_2.wav", "duration": 0.978000, "command": "right"}
# {"audio_filepath": "data/20200625050543/segment_3.wav", "duration": 0.810000, "command": "turbo_boost"}
# {"audio_filepath": "data/20200625050543/segment_4.wav", "duration": 0.801000, "command": "hard_right"}

無音データの作成

次に無音部分のデータを作ります。推論時はマイクが拾う音声を連続的に推論してコマンド分類する訳ですが、 JetBot に何も話かけていない場合は「何もしなくてよい」というクラスに分類してもらう必要があります。

発話せずに上記の“notebook-1"の手順を実行して、雑音ありと雑音なしの 2 パターンを作りました。

!ls *wav

# frame_20200624045836.wav  frame_20200624050547.wav

雑音ありの wav を固定長で分割して with_noize フォルダに移動します。

!sox frame_20200624045836.wav split_.wav trim 0 1.2 : newfile : restart
!mkdir -p data/with_noize
!mv split_*wav data/with_noize

雑音なしの wav も同様に処理します。

!sox frame_20200624050547.wav split_.wav trim 0 1.2 : newfile : restart
!mkdir -p data/without_noize
!mv split_*wav data/without_noize

これらを JSON 化して発話ありのデータに追加します。

%%bash
dataset="jetbot_asr.json"
for idx in `seq -f "%03.0f" 1 50`
do
  wav=data/with_noize/split_$idx.wav
  duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'`
  record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"void\"}"
  echo $record >> $dataset
  wav=data/without_noize/split_$idx.wav
  duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'`
  record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"void\"}"
  echo $record >> $dataset
done

実際にデータを作ってみると、いくつか duration が取れていないデータが出来てしまいました。

!cat jetbot_asr.json | grep -e "duration\": ," 

# {"audio_filepath": "data/20200625050543/segment_19.wav", "duration": , "command": "turbo_boost"}
# {"audio_filepath": "data/20200625050543/segment_20.wav", "duration": , "command": "forward"}
# {"audio_filepath": "data/20200625050543/segment_21.wav", "duration": , "command": "right"}
# {"audio_filepath": "data/20200625050543/segment_22.wav", "duration": , "command": "hard_left"}
# {"audio_filepath": "data/20200625050543/segment_23.wav", "duration": , "command": "hard_left"}
# {"audio_filepath": "data/20200625050543/segment_24.wav", "duration": , "command": "hard_left"}
# {"audio_filepath": "data/20200625050543/segment_25.wav", "duration": , "command": "right"}
# {"audio_filepath": "data/20200625050543/segment_26.wav", "duration": , "command": "hard_right"}

これくらいの量なら大して影響なさそうなので、フィルタして落としてしまいます。

!cat jetbot_asr.json | grep -v "duration\": ," > jetbot_asr.json.filtered

中身をシャッフルして、

%%bash
get_seeded_random()
{
  seed="$1"
  openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \
    </dev/zero 2>/dev/null
}

sort --random-source=<(get_seeded_random 42) -R jetbot_asr.json.filtered > jetbot_asr.json.shuf

こんな感じになりました。

!head -5 jetbot_asr.json.shuf

# {"audio_filepath": "data/20200625050543/segment_18.wav", "duration": 1.007000, "command": "turbo_boost"}
# {"audio_filepath": "data/20200625050543/segment_1.wav", "duration": 0.842000, "command": "stop"}
# {"audio_filepath": "data/20200625050543/segment_13.wav", "duration": 0.906000, "command": "left"}
# {"audio_filepath": "data/20200625050543/segment_9.wav", "duration": 0.847000, "command": "left"}
# {"audio_filepath": "data/20200625050543/segment_5.wav", "duration": 0.845000, "command": "forward"}

実際の作業はアレコレやり直したり、苦手なコマンドのデータを追加したりと紆余曲折があって最終的には 1787 件のデータになってしまいました。。。

!wc -l *.json

# 1787 jetbot_asr.json
# 1787 total

最後に 8:1:1 で分割して、学習、検証、テストセットとします。

!head -1428 jetbot_asr.json.shuf > jetbot_asr_train.json
!tail -356 jetbot_asr.json.shuf | head -178 > jetbot_asr_dev.json
!tail -178 jetbot_asr.json.shuf > jetbot_asr_test.json
!wc -l *.json

# 1787 jetbot_asr.json
#  178 jetbot_asr_dev.json
#  178 jetbot_asr_test.json
# 1428 jetbot_asr_train.json
# 3571 total

作成したデータセットの GCS に保存しておきましょう。

from google.colab import auth
auth.authenticate_user()

!gsutil cp -r data gs://somewhere/jetbot_asr/
!gsutil cp *.josn data gs://somewhere/jetbot_asr/data/

ようやく音声分類データセットができたので、次はモデルの学習です。

3. 音声コマンド分類モデルの学習

ここからは前章で作成したデータセットを使って音声コマンド分類のモデルの学習を行います。 新たに Colab でノートブックを作って、ランタイムのアクセラレータは "GPU” にして下さい。

まずは必要なライブラリをインストールしていきます。

!apt-get update
!apt-get install sox swig pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev \
  libsndfile1-dev python-setuptools libboost-all-dev python-dev cmake

NeMo のコードも取得してインストールします。 詳しくは後述しますが、 JetBot に手軽にインストールできる PyTorch のバージョンとも絡むので少し前のコミットを使います。

!git clone https://github.com/NVIDIA/NeMo
!cd NeMo && git checkout 331fc4b

前回と同様の問題があるので一部修正をしてしまって。。。

!cp NeMo/nemo/collections/asr/parts/jasper.py jasper.py
!cat jasper.py |  sed -e '158s#/#//#g' > jasper.py.mod
!cp jasper.py.mod NeMo/nemo/collections/asr/parts/jasper.py 

!diff jasper.py NeMo/nemo/collections/asr/parts/jasper.py 
# 158c158
# <         ) / self.conv.stride[0] + 1
# ---
# >         ) // self.conv.stride[0] + 1

pip でインストールします。

!cd NeMo; pip install .
!cd NeMo; pip install .[asr]

前章で GCS に退避したデータセットも取ってきましょう。

from google.colab import auth
auth.authenticate_user()
!gsutil cp -r gs://somewhere/jetbot_asr/data/ .

コマンド分類のラベル(9クラス)を定義します。

labels = [
  "forward",
  "backword",
  "right",
  "left",
  "hard_right", 
  "hard_left",
  "stop",
  "turbo_boost",
  "void"
]

ここで、訓練データの各クラスの分布を見てみましょう(アレコレすったもんだした後のホントの最終形なので、前章で wc -l した結果と整合性とれてないかもしれません。。。)。

import json
train_set = []
with open("data/jetbot_asr_train.json", "r") as f:
  lines = f.readlines()
  for line in lines:
    train_set.append(json.loads(line))

counts = [0] * len(labels)

for example in train_set:
  counts[labels.index(example["command"])] += 1

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
plt.style.use('seaborn-whitegrid')
%matplotlib inline

f, ax1 = plt.subplots(figsize=(12, 3), dpi=90)
fp = FontProperties(size=7)
ax1.bar(labels, counts, width=0.5)

プロットするとこんな感じです。私の滑舌のせいかもしれませんが「ひだり」が意外と苦手でした。

label_histgram

学習の実行

ここからは学習の実行です。学習自体はあっという間に終わった気がします。

import nemo
import nemo.collections.asr as nemo_asr

ファクトリを作って、

nf = nemo.core.NeuralModuleFactory(log_dir='./log', create_tb_writer=True)
tb_writer = nf.tb_writer

モデルの定義をロードします。前回は 15x5 のモデルを使いましたが、音声コマンド認識なので 3x1 の小規模なモデルで十分です。

from ruamel.yaml import YAML
yaml = YAML(typ="safe")
with open("./NeMo/examples/asr/configs/quartznet_speech_commands_3x1_v1.yaml") as f:
    config = yaml.load(f)
model_def = config["JasperEncoder"]

必要なクラスをインポートします。今回は分類問題なので使用するクラスが前回と一部、異なります。

from functools import partial
from nemo.collections.asr.helpers import process_evaluation_batch, process_evaluation_epoch
from nemo.collections.asr.helpers import monitor_classification_training_progress
from nemo.collections.asr.helpers import process_classification_evaluation_batch
from nemo.collections.asr.helpers import process_classification_evaluation_epoch

訓練サンプルの数を確認しておきましょう(やっぱり増えてますね。。。)。

num_train_samples = !(wc -l data/jetbot_asr_train.json | cut -f 1 -d " " )
num_train_samples = int(num_train_samples[0])
num_train_samples

# 1837

エポック数とバッチサイズはこんな感じにしました。

num_epochs       = 50
batch_size       = 32 

トータルステップ数を確認しておきます。

total_steps = num_train_samples * num_epochs // batch_size
total_steps 

# 2870

学習の最初の 10% をウォームアップにします。

warmup_steps = total_steps // 10
warmup_steps

# 287

その他のハイパーパラメータはこんな感じです。

batches_per_step = 1
dropout          = 0.2
lr               = 0.02
weight_decay     = 0.0001
sample_rate      = 16000
n_fft            = 512
output_dir       = "./output"
epochs_per_eval  = 1

今回も前回と同じくドロップアウトを適用してみます。

for i, layer in enumerate(model_def["jasper"]):
  layer['dropout'] = dropout

ここからはデータ入力レイヤの定義です。データレイヤのクラスには nemo_asr.AudioToSpeechLabelDataLayer を指定します。

train_dataset = "data/jetbot_asr_train.json"
val_dataset = "data/jetbot_asr_dev.json"

with open(train_dataset) as f:
  lines = f.readlines()
  num_train_samples = len(lines)
  steps_per_epoch = num_train_samples // batch_size

data_layer = nemo_asr.AudioToSpeechLabelDataLayer(
    manifest_filepath=train_dataset,
    labels=labels,
    sample_rate=sample_rate,
    batch_size=batch_size,
    shuffle=True
)

# [NeMo I 2020-07-23 07:19:13 collections:240] Filtered duration for loading collection is 0.000000.
# [NeMo I 2020-07-23 07:19:13 collections:243] # 1837 files loaded accounting to # 9 labels
# [NeMo I 2020-07-23 07:19:13 data_layer:961] # of classes :9  

data_layer_val = nemo_asr.AudioToSpeechLabelDataLayer(
    manifest_filepath=val_dataset,
    labels=labels,
    sample_rate=sample_rate,
    batch_size=batch_size,
    shuffle=False
)

# [NeMo I 2020-07-23 07:19:14 collections:240] Filtered duration for loading collection is 0.000000.
# [NeMo I 2020-07-23 07:19:14 collections:243] # 206 files loaded accounting to # 9 labels
# [NeMo I 2020-07-23 07:19:14 data_layer:961] # of classes :9

次に前処理の定義です。

data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
    n_fft=n_fft, sample_rate=sample_rate, features=model_def["feat_in"], stft_conv=True)

# [NeMo I 2020-07-23 07:19:17 features:144] PADDING: 16
# [NeMo I 2020-07-23 07:19:17 features:152] STFT using conv

spec_augment = nemo_asr.SpectrogramAugmentation(**config["SpectrogramAugmentation"])

エンコーダ、デコーダ、ロスを定義します。今回はコマンド分類なのでデコーダには nemo_asr.JasperDecoderForClassification を使います。 ロスも nemo_asr.CrossEntropyLossNM になります。

encoder = nemo_asr.JasperEncoder(**model_def)
decoder = nemo_asr.JasperDecoderForClassification(feat_in=model_def["jasper"][-1]['filters'], num_classes=len(labels),
    **config['JasperDecoderForClassification'],
)
ce_loss = nemo_asr.CrossEntropyLossNM()

学習グラフと検証グラフを定義して、

audio_signal, audio_signal_len, commands, command_len = data_layer()
processed_signal, processed_signal_len = data_preprocessor(input_signal=audio_signal, length=audio_signal_len)
processed_signal = spec_augment(input_spec=processed_signal)

encoded, encoded_len = encoder(audio_signal=processed_signal, length=processed_signal_len)
decoded = decoder(encoder_output=encoded)

loss = ce_loss(logits=decoded, labels=commands)


audio_signal_v, audio_signal_len_v, commands_v, command_len_v = data_layer_val()
processed_signal_v, processed_signal_len_v = data_preprocessor(input_signal=audio_signal_v, length=audio_signal_len_v)

encoded_v, encoded_len_v = encoder(audio_signal=processed_signal_v, length=processed_signal_len_v)
decoded_v = decoder(encoder_output=encoded_v)

loss_v = ce_loss(logits=decoded_v, labels=commands_v)

進捗、チェックポイント、検証の各種コールバック関数を定義します。

logger_cb = nemo.core.SimpleLossLoggerCallback(
    tb_writer=tb_writer,
    tensors=[loss, decoded, commands],
    print_func=partial(monitor_classification_training_progress, eval_metric=None),
    get_tb_values=lambda x: [("loss", x[0])],)

ckpt_cb =  nemo.core.CheckpointCallback(folder=output_dir, step_freq=steps_per_epoch)

eval_cb = nemo.core.EvaluatorCallback(
    eval_tensors=[loss_v, decoded_v, commands_v],
    user_iter_callback=partial(process_classification_evaluation_batch, top_k=1),
    user_epochs_done_callback=partial(process_classification_evaluation_epoch, eval_metric=1, tag="DEV"),
    eval_step=steps_per_epoch * epochs_per_eval,
    tb_writer=tb_writer,
)
callbacks = [logger_cb, ckpt_cb, eval_cb]

ようやくですが、学習ループを回します。

optimization_params={"num_epochs": num_epochs, "lr": lr, "weight_decay": weight_decay, "momentum": 0.95, "betas": (0.98, 0.5), "grad_norm_clip": None}
nf.train(tensors_to_optimize=[loss], callbacks=callbacks, optimizer="novograd",
             optimization_params=optimization_params, batches_per_step=batches_per_step)

# [NeMo I 2020-07-23 07:20:02 deprecated_callbacks:195] Starting .....
# [NeMo I 2020-07-23 07:20:02 callbacks:534] Found 2 modules with weights:
...
# [NeMo I 2020-07-23 07:22:20 callbacks:465] Saved checkpoint: ./output/trainer-STEP-2900.pt
# [NeMo I 2020-07-23 07:22:20 deprecated_callbacks:339] Final Evaluation ..............................
# [NeMo I 2020-07-23 07:22:20 helpers:271] ==========>>>>>>Evaluation Loss DEV: 0.388
# [NeMo I 2020-07-23 07:22:20 helpers:273] ==========>>>>>>Evaluation Accuracy Top@1 DEV: 96.1165
# [NeMo I 2020-07-23 07:22:20 deprecated_callbacks:344] Evaluation time: 0.16168832778930664 seconds

検証セットでの正答率が 96% なので悪くなさそうです。

推論の実行

学習済みモデルができたのでテストセットで推論してみましょう。

今度はテストセットでデータレイヤを定義して、

test_dataset = 'data/jetbot_asr_test.json'

import json
test_set = []
with open("data/jetbot_asr_test.json", "r") as f:
  lines = f.readlines()
  for line in lines:
    test_set.append(json.loads(line))

data_layer_test = nemo_asr.AudioToSpeechLabelDataLayer(
    manifest_filepath=test_dataset,
    labels=labels,
    sample_rate=sample_rate,
    batch_size=batch_size,
    shuffle=False
)

# [NeMo I 2020-07-23 07:22:49 collections:240] Filtered duration for loading collection is 0.000000.
# [NeMo I 2020-07-23 07:22:49 collections:243] # 206 files loaded accounting to # 9 labels
# [NeMo I 2020-07-23 07:22:49 data_layer:961] # of classes :9

推論グラフを構築します。

audio_signal_t, audio_signal_len_t, commands_t, command_len_t = data_layer_test()
processed_signal_t, processed_signal_len_t = data_preprocessor(input_signal=audio_signal_t, length=audio_signal_len_t)

encoded_t, encoded_len_t = encoder(audio_signal=processed_signal_t, length=processed_signal_len_t)
decoded_t = decoder(encoder_output=encoded_t)

loss_t = ce_loss(logits=decoded_t, labels=commands_t)

推論の実行はこんな感じで行います。

evaluated_tensors = nf.infer(tensors=[loss_t, decoded_t, commands_t], checkpoint_dir="./output")

# [NeMo I 2020-07-23 07:22:53 actions:1574] Restoring JasperEncoder from ./output/JasperEncoder-STEP-2900.pt
# [NeMo I 2020-07-23 07:22:53 actions:1574] Restoring JasperDecoderForClassification from ./output/JasperDecoderForClassification-STEP-2900.pt
# [NeMo I 2020-07-23 07:22:53 actions:695] Evaluating batch 0 out of 7
# ...
# [NeMo I 2020-07-23 07:22:53 actions:695] Evaluating batch 6 out of 7

精度を確認します。

from nemo.collections.asr.metrics import classification_accuracy

correct_count = 0
total_count = 0

for batch_idx, (logits, labels) in enumerate(zip(evaluated_tensors[1], evaluated_tensors[2])):
    acc = classification_accuracy(
        logits=logits,
        targets=labels,
        top_k=[1]
    )
    acc = acc[0]
    correct_count += int(acc * logits.size(0))
    total_count += logits.size(0)

print(f"Total correct / Total count : {correct_count} / {total_count}")
print(f"Final accuracy : {correct_count / float(total_count)}")

# Total correct / Total count : 199 / 206
# Final accuracy : 0.9660194174757282

こちらも 96.6% の正答率です。 JetBot を操作するには十分な感じですが、もう少し詳しく見てみましょう6

import torch
import librosa
import json
import IPython.display as ipd

class ReverseMapLabel:
    def __init__(self, data_layer: nemo_asr.AudioToSpeechLabelDataLayer):
        self.label2id = dict(data_layer._dataset.label2id)
        self.id2label = dict(data_layer._dataset.id2label)

    def __call__(self, pred_idx, label_idx):
        return self.id2label[pred_idx], self.id2label[label_idx]

sample_idx = 0
incorrect_preds = []
rev_map = ReverseMapLabel(data_layer_test)

# Remember, evaluated_tensor = (loss, logits, labels)
for batch_idx, (logits, labels) in enumerate(zip(evaluated_tensors[1], evaluated_tensors[2])):
    probs = torch.softmax(logits, dim=-1)
    probas, preds = torch.max(probs, dim=-1)

    incorrect_ids = (preds != labels).nonzero()
    for idx in incorrect_ids:
        proba = float(probas[idx][0])
        pred = int(preds[idx][0])
        label = int(labels[idx][0])
        idx = int(idx[0]) + sample_idx

        incorrect_preds.append((idx, *rev_map(pred, label), proba))

    sample_idx += labels.size(0)

print(f"Num test samples : {total_count}")
print(f"Num errors : {len(incorrect_preds)}")

incorrect_preds = sorted(incorrect_preds, key=lambda x: x[-1], reverse=False)

# Num test samples : 206
# Num errors : 7

7つ程、不正解がありました(出力は [id, pred, label, prob]です)。

for incorrect_sample in incorrect_preds:
    print(str(incorrect_sample))
# (28, 'turbo_boost', 'left', 0.47237205505371094)
# (62, 'turbo_boost', 'hard_right', 0.48559606075286865)
# (124, 'right', 'forward', 0.9396188259124756)
# (18, 'hard_right', 'backword', 0.9929646253585815)
# (126, 'stop', 'backword', 0.9997004270553589)
# (147, 'left', 'turbo_boost', 0.9997300505638123)
# (106, 'stop', 'turbo_boost', 0.9999876022338867)

試しに幾つか確認してみました。

check_incorrect

画像なので再生できませんが、id=28 は「たーぼぶーすと」、id=106 は「ていし」と発話しています。。。 どうもラベルの方が間違ってるようですね。無音部分で分割した wav とラベルの整合性チェックは結構粗く行ったので、取りこぼしがあったのかもしれません。 とはいえ、概ね正しくコマンドを認識できているようですので、学習結果を GCS に保存しておきましょう。

!tar zcvf output.tar.gz output
!gsutil cp output.tar.gz gs://somewhere/jetbot_asr/

ようやくですが(3回目?)、このモデルを JetBot にデプロイして走行試験をしてみましょう。

4. JetBot での走行試験

ここからは、お手元に JetBot がある想定です。いまから自分で作る気概のある人はオブジェクトの広場の「 JetBot を動かしてみよう 」の記事7を参考にして下さい。

JetBot のセットアップ

JetBot の OS イメージは JetBot の Wiki 8 から入手できます。ただし NeMo を動かすには PyTorch のバージョンが古過ぎるので、そのあたりを入れ替えていく必要があります。

まずは上記のサイトから jetbot_image_v0p4p0.zip を入手し、Echer 9 を使って microSD カードに書き込んでください。イメージを書き込んだ microSD カードを JetBot のスロットに挿入し、

microsd_slot

HDMI のスロットにモニター、USB のスロットにマウス、キーボードを接続します。準備ができたらUSBバッテリーにケーブルを接続して JetBot を起動します。 ここからは普通の Ubuntu の操作とあまり変わりありません。

さて、執筆時点で最新の NeMo では PyTorch 1.6 以上ということになっているのですが、先ほど microSD に焼き込んだ jetbot_image_v0p4p0.zip は JetPack 4.3 がベースになっています。こちら10のサイトで確認したところ、 JetPack 4.3 向けの PyTorch は 1.4 が最新のようですので、PyTorch をこちらに入れ替え、NeMo のバージョンを少し落とすことにしました。

ここからは、 JetBot のデスクトップを操作して作業します。 まずは WiFi に接続して下さい。走行時は PC のマイクで拾った音声を WiFi 経由で JetBot に送信する形になるので、起動時に WiFi に自動接続する設定にしておきましょう。

次に Terminal を開いてセットアップ作業をしていきます。ちなみに、ここからの作業は USBバッテリー駆動で低電力モードで動いているせいか、かなり時間がかかりました。。。まずは、 PyTorch を 1.4 に入れ替えましょう。

jetbot@jetson-4-3:~$ wget https://nvidia.box.com/shared/static/ncgzus5o23uck9i5oth2n8n06k340l6k.whl -O torch-1.4.0-cp36-cp36m-linux_aarch64.whl
jetbot@jetson-4-3:~$ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev 
jetbot@jetson-4-3:~$ sudo pip3 install Cython
jetbot@jetson-4-3:~$ sudo pip3 install numpy torch-1.4.0-cp36-cp36m-linux_aarch64.whl

つづいて NeMo のインストールです。走行時はモニタ、キーボード、マウスを外して WiFi 経由で JetBot 上で動作する Jupyter Lab に接続して作業するので、Jupyter Lab のデフォルトのフォルダ(/home/jetbot/Notebooks)で作業します。

まずは コミット 331fc4b をチェックアウトします。ちゃんと PyTorch 1.4 がサポートされたバージョンです。

jetbot@jetson-4-3:~$ cd /home/jetbot/Notebooks
jetbot@jetson-4-3:~/Notebooks$ sudo apt-get install sox swig pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev \
  libsndfile1-dev python-setuptools libboost-all-dev python-dev cmake
jetbot@jetson-4-3:~/Notebooks$ git clone https://github.com/NVIDIA/NeMo
jetbot@jetson-4-3:~/Notebooks$ cd NeMo
jetbot@jetson-4-3:~/Notebooks/NeMo$ git checkout 331fc4b
jetbot@jetson-4-3:~/Notebooks/NeMo$ grep -n -A 5 Requirement README.rst
# 81:**Requirements**
# 82-
# 83-1) Python 3.6 or 3.7
# 84-2) PyTorch 1.4.* with GPU support
# 85-3) (optional, for best performance) NVIDIA APEX. Install from here: https://github.com/NVIDIA/apex

さらに必要なライブラリと NeMo をインストールします11

jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 uninstall enum34
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install liblapack-dev
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install gfortran
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install protobuf-compiler libprotobuf-dev
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install numpy==1.19.0
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install .

続いて asr モジュールをインストールしたかったのですが、libffi や llvm-config がないと怒られるのでインストールします。

jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install libffi-dev
jetbot@jetson-4-3:~/Notebooks/NeMo$ wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get -y install clang-7 lldb-7 lld-7
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get -y install libllvm-7-ocaml-dev libllvm7 llvm-7 llvm-7-dev \
  llvm-7-doc llvm-7-examples llvm-7-runtime
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo ln -s /usr/lib/llvm-7/bin/llvm-config /usr/bin/llvm-config

改めて asr モジュールをインストール

jetbot@jetson-4-3:~/Notebooks/NeMo$ export LLVM_CONFIG=/usr/bin/llvm-config
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install .[asr]

NeMo がインストールできたので、一つ上のディレクトリに動いておきます。あと他にも入れておくものがありました。

jetbot@jetson-4-3:~/Notebooks/NeMo$ cd ..
jetbot@jetson-4-3:~/Notebooks$ sudo pip install configargparse
jetbot@jetson-4-3:~/Notebooks$ sudo pip install samplerate

分類対象のクラスをテキストに出力しておきます。モデルが出力するテンソルのインデックスに対応する行をこのファイルから拾ってクラスの文字列表現にするので、順番は重要です。

jetbot@jetson-4-3:~/Notebooks$ cat << EOF > labels_9class.txt
forward
backword
right
left
hard_right
hard_left
stop
turbo_boost
void
EOF

次に前章で保存した学習済みモデルを GCS からダウンロードします12

jetbot@jetson-4-3:~/Notebooks$ gsutil cp gs://somewhere/jetbot_asr/output.tar.gz .
jetbot@jetson-4-3:~/Notebooks$ tar zxvf output.tar.gz

WSS 接続に使う鍵と証明書も改めて作ります。

jetbot@jetson-4-3:~/Notebooks$ mkdir ssl
jetbot@jetson-4-3:~/Notebooks$ openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "./ssl/cert.key" -out "./ssl/cert.pem" -batch

最後にストリーミングで音声認識をして JetBot を制御するサーバのロジックです。長いので Jupyter Lab でノートブックを開いてセルから出力しました( JetBot からは https://localhost:8888/lab でアクセスできます)。デフォルトのフォルダにノートブックを作ったので、 jetbot_asr_server.py/home/jetbot/Notebooks に出力される想定です。

%%bash 
cat <<EOF > jetbot_asr_server.py
#!/usr/bin/env python3
import sys
import os
import math
import configargparse
import tornado.ioloop
import tornado.web
import tornado.websocket
import wave
import numpy as np
import datetime
import nemo
import nemo.collections.asr as nemo_asr
from ruamel.yaml import YAML
from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType
import torch
import samplerate
import wave
import json

from jetbot import Robot

"""
class Robot:

  def forward(self, speed):
    print("Robot: forward : %s" % (speed))

  def backward(self, speed):
    print("Robot: backward : %s" % (speed))

  def set_motors(self, left, right):
    print("Robot: set_motors : %s, %s" % (left, right))

  def stop(self):
    print("Robot: stop")
"""
# How to make server key and certificate.
# openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "cert.key" -out "cert.pem" -batch

def log(message):
    print(message, flush=True)

def get_parser(parser=None, required=True):
  if parser is None:
    parser = configargparse.ArgumentParser(
      description='',
      config_file_parser_class=configargparse.YAMLConfigFileParser,
      formatter_class=configargparse.ArgumentDefaultsHelpFormatter)

  parser.add('--config', is_config_file=True, help='config file path')
  parser.add_argument('--labels', default="./labels_9class.txt", type=str, help='List of label(class).')
  parser.add_argument('--yaml_config', default="./NeMo/examples/asr/configs/quartznet_speech_commands_3x1_v1.yaml",
                      type=str, help='YAML config.')
  parser.add_argument('--encoder_ckpt', default="./output/JasperEncoder-STEP-2900.pt",
                      type=str, help="path to encoder checkpoint.")
  parser.add_argument('--decoder_ckpt', default="./output/JasperDecoderForClassification-STEP-2900.pt",
                      type=str, help="path to decoder checkpoint.")
  parser.add_argument('--nf_log_dir', default="./log", type=str, help="log directory for nefaul factory.")
  parser.add_argument('--sample_rate', default=16000, type=int, help="Server side sample rate.")
  parser.add_argument('--browser_sample_rate', default=16000, type=int, help="Client side sample rate.")
  parser.add_argument('--cert_file', default="./ssl/cert.pem", type=str, help="SSL cert file.")
  parser.add_argument('--key_file', default="./ssl/cert.key", type=str, help="SSL key file.")
  parser.add_argument('--websocket_port', default=8889, type=int, help="WebSocket port.")
  parser.add_argument('--frame_len', default=0.2, type=float, help="frame length (sec)")
  parser.add_argument('--frame_overlap', default=0.5, type=float, help="frame overlap length before and after.(sec)")
  parser.add_argument('--message_len', default=2048, type=int, help="message size of websocket upload.")
  parser.add_argument('--window_stride', default=0.01, type=float, help="window stride of fft.")
  parser.add_argument('--dump_wav', default=0, type=int, help="whether to dump frame as wav.")
  parser.add_argument('--dump_npy', default=0, type=int, help="whether to dump frame as npy.")
  parser.add_argument('--confidence_thresh', default=0.80, type=float, help="confidence threshold of detection.")
  return parser

def calc_frame_size(args):
  log("frame_len(sec) = %f" % (args.frame_len))
  log("frame_overlap(sec) = %f" % (args.frame_overlap))
  log("browser_sample_rate = %d" % (args.browser_sample_rate))
  log("message_len = %d" % (args.message_len))
  log("sample_rate = %d" % (args.sample_rate))
  n_frame_len = int(args.frame_len * args.browser_sample_rate)
  if n_frame_len % args.message_len != 0:
    n_frame_len = (n_frame_len // args.message_len + 1) * args.message_len
  n_frame_overlap = int(args.frame_overlap * args.browser_sample_rate)
  log("n_frame_len = %d" % (n_frame_len))
  log("n_frame_overlap = %d" % (n_frame_overlap))

  return n_frame_len, n_frame_overlap

def load_labels(label):
  log("Loading labels.")
  with open(label, "r") as f:
      lines = f.readlines()
      labels = [line.strip() for line in lines]
  return labels

def load_model_def(yaml_config):
  log("Loading 3x1 model definition.")
  yaml = YAML(typ="safe")
  with open(yaml_config) as f:
    config = yaml.load(f)
  model_def = config["JasperEncoder"]
  return model_def

def build_neural_factory(model_def, labels, encoder_ckpt, decoder_ckpt, nf_log_dir, sample_rate, window_stride, n_fft=512):

  log("Creating NeuralModuleFactory.")
  nf = nemo.core.NeuralModuleFactory(log_dir=nf_log_dir)

  data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(dither=0.0, pad_to=0,
    n_fft=n_fft, sample_rate=sample_rate, window_stride=window_stride, features=model_def["feat_in"], stft_conv=True)

  log("Creating AudioDataLayer.")
  class AudioDataLayer(DataLayerNM):

    @property
    def output_ports(self):
      return {
        'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
        'a_sig_length': NeuralType(tuple('B'), LengthsType()),
      }

    def __init__(self, sample_rate):
      super().__init__()
      self._sample_rate = sample_rate
      self.has_next = True

    def __iter__(self):
      return self

    def __next__(self):
      if not self.has_next:
        raise StopIteration
      self.has_next = False
      return torch.as_tensor(self.signal, dtype=torch.float32), torch.as_tensor(self.signal_shape, dtype=torch.int64)

    def set_signal(self, signal):
      self.signal = np.reshape(signal, [1, -1])
      self.signal_shape = np.expand_dims(self.signal.size, 0).astype(np.int64)
      self.has_next = True

    def __len__(self):
      return 1

    @property
    def dataset(self):
      return None

    @property
    def data_iterator(self):
      return self

  data_layer = AudioDataLayer(sample_rate=sample_rate)

  log("Creating encoder, decoder, greedy_decoder.")
  encoder = nemo_asr.JasperEncoder(**model_def)
  decoder = nemo_asr.JasperDecoderForClassification(feat_in=model_def["jasper"][-1]['filters'], 
              num_classes=len(labels), return_logits=True, pooling_type='avg')
  ce_loss = nemo_asr.CrossEntropyLossNM()

  log("Restoring parameters from the checkpoints.")
  encoder.restore_from(encoder_ckpt)
  decoder.restore_from(decoder_ckpt)

  log("Building DAG.")
  audio_signal, audio_signal_len = data_layer()
  processed_signal, processed_signal_len = data_preprocessor(input_signal=audio_signal, length=audio_signal_len)

  encoded, encoded_len = encoder(audio_signal=processed_signal, length=processed_signal_len)
  decoded = decoder(encoder_output=encoded)

  log("Setting up infer_signal function.")
  def infer_signal(self, signal):
    data_layer.set_signal(signal)
    tensors = self.infer(tensors=[decoded], verbose=False)
    logits = tensors[0][0]
    probs = torch.softmax(logits, dim=-1)
    return probs 

  nf.infer_signal = infer_signal.__get__(nf)

  return nf

class FrameASR:
  def __init__(self, nf, model_def, labels, n_frame_len, n_frame_overlap,
               sample_rate, browser_sample_rate, confidence_thresh,
               dump_wav, dump_npy):

    log("Initialize FrameASR.")
    self.nf = nf
    self.labels = labels
    self.sample_rate = sample_rate
    self.browser_sample_rate = browser_sample_rate
    self.n_frame_len = n_frame_len
    self.n_frame_overlap = n_frame_overlap
    self.buffer = np.zeros(shape=2 * self.n_frame_overlap + self.n_frame_len, dtype=np.float32)
    self.confidence_thresh = confidence_thresh
    self.dump_wav = dump_wav
    self.dump_npy = dump_npy
    self.debug()
    self.reset()

  def debug(self):
    log("labels              = %s" % (self.labels))
    log("sample_rate         = %d" % (self.sample_rate))
    log("browser_sample_rate = %d" % (self.browser_sample_rate))
    log("n_frame_len         = %d" % (self.n_frame_len))
    log("n_frame_overlap     = %d" % (self.n_frame_overlap))
    log("asr buffer size     = %s" % (self.buffer.shape))
    log("confidence_thresh   = %s" % (self.confidence_thresh))

  def dump_as_wav(self, frame, timestamp, cls):
    arr = (frame * 32767).astype(np.int16) # 32767 is max value of 16 bit int.
    filename = 'frame_%s_predict_as_%s.wav' % (timestamp, cls)
    log("wrting wav file : %s" % filename)
    with wave.open(filename, 'wb') as wf:
      wf.setnchannels(1)
      wf.setsampwidth(2) # 2bytes(16bit precision)
      wf.setframerate(16000)
      wf.writeframes(arr.tobytes('C'))

  def _decode(self, frame):
    #log("len(frame) = %d" % (len(frame)))
    assert len(frame)==self.n_frame_len
    # append new frame to buffer and scroll forward 1 frame length.
    self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]
    self.buffer[-self.n_frame_len:] = frame
    if self.sample_rate == self.browser_sample_rate :
      buffer = self.buffer
    else:
      rate = self.sample_rate/self.browser_sample_rate
      buffer = samplerate.resample(self.buffer, rate, 'sinc_best')
    probs = self.nf.infer_signal(buffer).cpu().numpy()[0]

    #self._debug(probs)
    idx = np.argmax(probs)

    if self.labels[idx] != "void" and probs[idx] >= self.confidence_thresh:
      timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
      if self.dump_npy:
        np.save('frame_%s_predict_as_%s' % (timestamp, self.labels[idx]), resampled)
      if self.dump_wav:
        self.dump_as_wav(resampled, timestamp, self.labels[idx])

    return idx, self.labels[idx], probs[idx]

  def _debug(self, probs):
    predict=""
    for i, cls in enumerate(self.labels):
      predict += "%s=%5.2f, " % (cls, probs[i])
    predict = predict[:-1] 
    log(predict)

  def detect(self, frame=None):
    if frame is None:
      frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)
    if len(frame) < self.n_frame_len:
      frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')
    return self._decode(frame)

  def reset(self):
    self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)

class WebSocketHandler(tornado.websocket.WebSocketHandler):

  def initialize(self, n_frame_len, message_len, browser_sample_rate,
                   asr, confidence_thresh,  robot):
    log("Initialize WebSocket handler.")
    self.message_len = message_len
    self.n_frame_len = n_frame_len
    self.browser_sample_rate = browser_sample_rate
    self.n_messages_per_frame = n_frame_len // message_len
    log("browser_sample_rate : %d" % (browser_sample_rate))
    log("n_frame_len : %d" % (self.n_frame_len))
    log("n_messages_per_frame : %d" % (self.n_messages_per_frame))
    self.asr = asr
    self.confidence_thresh = confidence_thresh
    self.robot = robot
    self.pos = 0

  def open(self):
    self.buffer = []
    log("audio socket opened")

  def on_message(self, message):
    message = np.frombuffer(message, dtype='float32')
    #log("on message : size=%d, buffer size=%d" % (len(message), len(self.buffer)))
    assert len(message)==self.message_len
    self.buffer.append(message)
    if len(self.buffer) >= self.n_messages_per_frame:
      self.detect()

  def detect(self):
    frame = np.array(self.buffer).flatten()
    #log("frame length = %s" % (frame.shape))

    # Invoke ASR
    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    idx, cls, prob = self.asr.detect(frame)

    if cls != "void" and prob >= self.confidence_thresh:
      log("detected : %s(%d) = %5.2f" % (cls, idx, prob))
      self.write_message("%s : command=%s, probability=%5.2f\n" 
        % (timestamp, cls, prob))
      self.run_command(cls)

    self.buffer.clear()

  def run_command(self, command):
    if command == 'forward':
      #self.robot.forward(0.3)
      self.robot.set_motors(0.3, 0.28)
    elif command == 'right':
      self.robot.set_motors(0.3, 0.2)
    elif command == 'left':
      self.robot.set_motors(0.2, 0.3)
    elif command == 'hard_right':
      self.robot.set_motors(0.3, 0.1)
    elif command == 'hard_left':
      self.robot.set_motors(0.1, 0.3)
    elif command == 'stop':
      self.robot.stop()
    elif command == 'backword':
      #self.robot.backward(0.3)
      self.robot.set_motors(-0.3, -0.28)
    elif command == 'turbo_boost':
      self.robot.forward(1.0)

  def on_close(self):
    self.detect()
    log("audio socket closed")

  def check_origin(self, origin):
    return True

class ControlPageHandler(tornado.web.RequestHandler):

  def initialize(self, browser_sample_rate):
    log("Initialize control page handler.")
    self.browser_sample_rate = browser_sample_rate

  def get(self):
    host=self.request.host
    html = """
<html>
<div id="control" style="padding: 10px">
  <font color="#ffffff">
    <button id="start_transcribe"
            style="border: 0px; padding: 10px; border-radius: 10px; background-color: #00bfff; margin: 10px">
        START
    </button>
    <button id="stop_transcribe"
            style="border: 0px; padding: 10px; border-radius: 10px; background-color: #ff69b4">
        STOP
    </button>
  </font>
</div>
<div>
  <textarea id="transcribed" cols="120" rows="8"></textarea>
</div>

<script>
  var current_stream = null;
  var context = null;
  var ws = null;
  var textarea = document.getElementById("transcribed")
  var upload = function(stream) {
    current_stream = stream;
    context = new AudioContext();
    context = new AudioContext({ sampleRate: %d });
    ws = new WebSocket('wss://%s/websocket');

    ws.onmessage = function(event) {
      textarea.value = textarea.value + event.data;
      textarea.scrollTop = textarea.scrollHeight;
    };

    console.log(context.sampleRate);

    var input = context.createMediaStreamSource(stream)
    var processor = context.createScriptProcessor(0, 1, 1);

    input.connect(processor);
    processor.connect(context.destination);

    processor.onaudioprocess = function(e) {
      var voice = e.inputBuffer.getChannelData(0);
      ws.send(voice.buffer);
    };
  };

  function start_transcribe(){
    navigator.mediaDevices.getUserMedia({ audio: true, video: false }).then(upload)
  }

  function stop_transcribe(){
    ws.close();
   current_stream.getTracks().forEach(track => track.stop());
    context.close();
  }

  document.getElementById("start_transcribe").onclick = start_transcribe;
  document.getElementById("stop_transcribe").onclick = stop_transcribe;

</script>
</html>
""" % (self. browser_sample_rate, host)
    self.write(html)

  def check_origin(self, origin):
    return True  


def main(cmd_args):
  parser = get_parser()
  args, _ = parser.parse_known_args(cmd_args)

  n_frame_len, n_frame_overlap = calc_frame_size(args)

  labels = load_labels(label=args.labels)

  model_def = load_model_def(yaml_config=args.yaml_config)

  nf = build_neural_factory(model_def, labels,
    encoder_ckpt=args.encoder_ckpt, decoder_ckpt=args.decoder_ckpt,
    nf_log_dir=args.nf_log_dir,
    sample_rate=args.sample_rate,
    window_stride=args.window_stride
  )

  asr = FrameASR(nf, model_def, labels,
                 n_frame_len=n_frame_len, n_frame_overlap=n_frame_overlap,
                 sample_rate=args.sample_rate, 
                 browser_sample_rate=args.browser_sample_rate,
                 confidence_thresh=args.confidence_thresh,
                 dump_wav=args.dump_wav,
                 dump_npy=args.dump_npy)

  log("Starting SSL server on port %d" % args.websocket_port)

  robot = Robot()

  app = tornado.web.Application([(r"/websocket", WebSocketHandler, {
      'n_frame_len': n_frame_len,
      'message_len': args.message_len,
      'browser_sample_rate': args.browser_sample_rate,
      'asr': asr,
      'confidence_thresh': args.confidence_thresh,
      'robot':robot
    }),
    (r"/control", ControlPageHandler, {
      'browser_sample_rate': args.browser_sample_rate})
  ])
  http_server = tornado.httpserver.HTTPServer(app, ssl_options={
    "certfile": args.cert_file,
    "keyfile" : args.key_file,
  })

  http_server.listen(args.websocket_port)
  tornado.ioloop.IOLoop.instance().start()

if __name__ == '__main__':
    main(sys.argv[1:])  
EOF
chmod 755 jetbot_asr_server.py

これで、ひととおりの準備が整いました。いよいよ走行試験です。 JetBot からモニタ、キーボード、マウスを取り外して、 走行させる場所までもっていきましょう。

走行試験

それでは走行試験を始めます。環境の設定としては マイクの付いた Windows PC と JetBot が同じ WiFi につながっているという状況です。 まずは、 JetBot の IP アドレスを確認しましょう。 WiFi に接続できていれば LED に表示がでているはずです。

jetbot_led

Windows PC の Chrome から JetBot 上で稼働している Jupyter Lab に接続します。筆者の環境では https://192.168.0.144:8888/lab ですね。 Jupyter Lab に接続できたら再び Terminal を開きます。

jetbot_terminal

Terminal から音声コマンド認識サーバを起動します。

jetbot@jetson-4-3:~$ cd Notebooks/
jetbot@jetson-4-3:~/Notebooks$ ./jetbot_asr_server.py
# ...
# frame_len(sec) = 0.200000
# frame_overlap(sec) = 0.500000
# browser_sample_rate = 16000
# message_len = 2048
# sample_rate = 16000
# n_frame_len = 4096
# n_frame_overlap = 8000
# Loading labels.
# Loading 3x1 model definition.
# Creating NeuralModuleFactory.
# ...
# Creating AudioDataLayer.
# Creating encoder, decoder, greedy_decoder.
# Restoring parameters from the checkpoints.
# Building DAG.
# Setting up infer_signal function.
# Initialize FrameASR.
# labels              = ['forward', 'backword', 'right', 'left', 'hard_right', 'hard_left', 'stop', 'turbo_boost', 'void']
# sample_rate         = 16000
# browser_sample_rate = 16000
# n_frame_len         = 4096
# n_frame_overlap     = 8000
# asr buffer size     = 20096
# confidence_thresh   = 0.8
# Starting SSL server on port 8889

簡単に動きを説明すると、PC のマイクからサンプリングレート 16KHz で拾った音声を 2048 サンプルを 1 メッセージとして連続的に WebSocket で JetBot に送り続けており、 JetBot 側はそれを約1.2秒分のバッファに貯めていて 2 メッセージ(=4096サンプル)受信する毎に直近の1.2秒分を音声コマンド認識に投入する動きとなります。言い換えると、約1.2秒(8000*2+4096)のウインドウを約0.25秒(=4096/16000)のストライドで動かしながらコマンド認識していることになります13

音声コマンド認識サーバが起動したら Chrome のタブをもう一枚開けて、音声コマンド認識サーバのコントロール画面( https://192.168.0.144:8889/control )を表示します。

録音サーバの時と同様に “START” をクリック、マイクへのアクセスを要求されるので許可して下さい。

あとは PC に向かって、「ぜんしん」、「みぎ」、「こうたい」と発話するとコントロール画面のテキストエリアに認識結果が表示され、JetBot が動き始めます。

jetbot_control

実際の様子はこんな感じになります。だいたい意図したとおりに動かせてますね。

動作中の JetBot の状況は jtop で確認できます。動作中はこんな感じになってました。

jtop

5. おまけ

「後述します」として、まだ書いてない話が残っていましたね。

データセット作成時に JetBot のモータ音を鳴らす

これは JetBot の Jupyter Lab に接続して、/home/jetbot/Notebooks/basic_motion/basic_motion.ipynb を開いてください。 Robot クラスを使って JetBot のモータを制御できるので適当に駆動させながら、録音作業を実施します。

Colab で音声認識サーバを動かす

こちらは3章の学習に使った Colab のノートブックに2章の要領で ngrok をセットアップします。そこで jetbot_asr_server.py を実行して ngrok 経由でアクセスすれば OK です。ただし jetbot_asr_server.py を編集して Robot クラスをコメントになっているダミーコードに入れ替えておいて下さい。

6. おわりに

今回は「Jetson 触ってみたい」という理由で音声コマンド認識に挑戦してみましたが、やはり物理的にモノがあると意外と手間がかかちゃいましたね。 次回は少し前の話になりますが、BERT の事前学習を改良した ELECTRA を試してみたいと思います。 BERT の改良は他にもいろいろあるのですが、 Google のコードなので、Colab の無料 TPU で回せそうですし。いつも文章分類ばっかりなので、固有表現抽出などやってみます。


  1. https://pj.ninjal.ac.jp/corpus_center/csj/ 

  2. 「手元に JetBot なんてないよ!」という方の為にストリーミングでのコマンド認識そのものは Colab で動くようにしておきますのでご安心ください。 

  3. 私もこの記事の内容は自宅の環境で動かして書いています。「公開するのは会社のFWの内側でなく Colab だから。。。」、とかいう話もあるかも知れませんが、GCS や各種 Google API へのアクセス権を乗っ取られたりするとえらいことなので。 

  4. WSS で接続したい理由があったはずなのですが失念してしまいました。単純に接続できなかったからなのですが、接続先が手元の環境だったか Colab だったか。。。 

  5. 元々は先月の音声テキスト認識モデルをストリーミングで動作させるコードを切った貼ったしたモノなのでメソッド名とか不適切ですが見なかったことにしてください。。。 

  6. このコードは https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/02_Speech_Commands.ipynb から借用しています。Apache 2.0 ライセンスですね。 

  7. https://www.ogis-ri.co.jp/otc/hiroba/technical/lets-try-jetbot/ 最近はキットも販売してるみたいです。 

  8. https://github.com/NVIDIA-AI-IOT/jetbot/wiki/Software-Setup 

  9. 筆者は balenaEtcher-Setup-1.5.101.exe を落として使いました。 

  10. https://forums.developer.nvidia.com/t/pytorch-for-jetson-nano-version-1-6-0-now-available/72048 

  11. 一度に入れても大丈夫だと思うのですが、実際の作業は NeMo のインストール時にアレが足らない、コレが無いと怒られながらの作業だったので、その時の手順ママで表記しています。 

  12. しれっと gsutil でコピーした体で書いてますが、 JetBot で gsutils のセットアップをした記憶がありません。。。JetBot から Colab のノートブックを開いて、GCS からノートブックのランタイムにコピーして、ノートブックのダウンロード機能で落としてきたのかも。。。 

  13. nframeoverlap(=8000) という変数名や 8000*2 と2倍しているのが意味不明ですが、音声テキスト認識のサンプルを作ってから、コマンド認識に変更したので直しきれてないんですね。。。