今回も音声認識です。前回は QuartzNet と JSUT データセットで音声→テキスト変換の検証を行いました。今回は問題を音声コマンド認識に簡単化し、自前のデータセットを作成、PCのマイクに話かけてストリーミングでの推論を試します。ついでに JetBot の操縦をしてみましょう。
1. はじめに
今回も 前回 に続いて音声認識です。使用するモデルは前回と同じ QuartzNet ですので、前回を未読の方は 第10回 の記事に目を通して戻ってきて頂けると、より理解がしやすいと思います。前回は音声→テキスト変換の話だったのですが、音声認識モデルを作るとどうしてもマイクに話かけて認識されるかどうか試したくなります。ですが前回の JSUT コーパスは単一女性の声のみですので、私の声が認識できるはずもなく。CSJコーパス1あたりを購入して試しても良いのですが、この連載は読んだ人が試せるようにしたいので問題を音声コマンド認識にすることにしました。
音声コマンド認識なら自分の声を認識するくらいのデータセットは比較的手軽に作れそうです。なのですが、音声コマンド認識ですと認識した結果で「テレビがつく」とか何か動かないとつまらないです。「何かないかな?」と思っていたところ、社内の誰かが作って遊んで飽き…ではなく技術検証作業が完了した JetBot が稼働状態のまま会社のキャビネットの中で眠っていました。
QuartzNet は 15x5 のフルサイズモデルでも Jetson Nano でリアルタイムストリーミングでの処理が可能なので、音声コマンド認識用の小規模モデルなら USBバッテリー駆動でフルパワーがでてない JetBot でも動きそう。
そういう訳で今回は自前でデータセットを作成し、モデルを学習、JetBot にデプロイして操縦という流れで進めていきましょう2。
まずは、音声コマンド分類データセットの作成です。
2. 音声コマンド分類データセットの作成
JetBot を声で操作する為の音声コマンド分類のデータセットを作成します。環境としては マイクの付いた PC、Windows 10、Chrome、Colab です。
大まかな流れとして、以下のような作業になります。
- ラベル生成スクリプトが出力するラベルを読み上げ
- その声を PC のマイクで拾い、WebSocket 経由で Colab 上で起動した録音サーバに飛ばし、 wav ファイルとして保存
- 保存した wav を無音部分で分割し 1. で出力したラベルと突き合わせながら、必要に応じて修正
まず、Colab 上で起動した録音サーバに接続する為、ngrock の準備をします。
ngrock
ngrock は簡単にいうと NAT やファイアーウォールの内側で動作しているサーバをセキュアなトンネルを通して手軽にインターネットに公開するサービスです。 Colab のランタイムで動作しているサーバプロセスに直接アクセスすることはできないので、ngrock でトンネルを作ってあげる訳です。
「ローカルサーバをインターネットに公開できる」というのは「悪意のあるアクセスを NAT やファイウォールの内側に呼び込む可能性を作っている」というのことなので、そこは注意して下さい。この手のサービスの利用を社内ルールで禁止している会社もあろうかと思います。3。
まずは、 https://ngrok.com/ でサインアップを済ませて下さい。無料版で OK です。
サインアップしたら認証トークンの値を確認してください。
認証トークンの値が確認できたら、 Colab でノートブックを開きましょう。音声を拾って wav に書き出すだけなので、ランタイムのアクセラレータは None で構いません。
この章ではノートブックを複数並行で使用するので、このノートブックを “noteboook-1” とします。以後、コード例先頭に “# notebook-?” の行がある場合は当該のノートブックでそのセルを実行するものとします。
それでは、次に ngrok の実行コマンドを取得します。
# notebook-1 !wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip !unzip ngrok-stable-linux-amd64.zip
ngrok コマンドを使って、先ほど確認した認証トークンを設定しておきます。
# notebook-1 !./ngrok authtoken your_auth_token_here # Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml
秘密鍵と証明書の生成
Chrome から録音サーバへの WebSocket 接続には HTTPS(WSS) を使用したので、 以下のようにして秘密鍵と証明書を生成しておきます。 セキュリティ的な強度とかは考慮しておらず、とりあえず WSS で接続したいが為にやっています4。
# notebook-1 !mkdir ssl !openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "./ssl/cert.key" -out "./ssl/cert.pem" -batch
録音サーバ
録音サーバで使用するライブラリをインストールします。
# notebook-1 !pip install configargparse !pip install samplerate
録音サーバのコードは以下のとおりです5。
%%bash # notebook-1 cat <<EOF > wsrecoder.py #!/usr/bin/env python3 import sys import os import math import configargparse import tornado.ioloop import tornado.web import tornado.websocket import wave import numpy as np import datetime import samplerate # How to make server key and certificate. # openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "cert.key" -out "cert.pem" -batch def log(message): print(message, flush=True) def get_parser(parser=None, required=True): if parser is None: parser = configargparse.ArgumentParser( description='', config_file_parser_class=configargparse.YAMLConfigFileParser, formatter_class=configargparse.ArgumentDefaultsHelpFormatter) parser.add('--config', is_config_file=True, help='config file path') parser.add_argument('--sample_rate', default=16000, type=int, help="Server side sample rate.") parser.add_argument('--browser_sample_rate', default=48000, type=int, help="Client side sample rate.") parser.add_argument('--cert_file', default="./ssl/cert.pem", type=str, help="SSL cert file.") parser.add_argument('--key_file', default="./ssl/cert.key", type=str, help="SSL key file.") parser.add_argument('--websocket_port', default=8889, type=int, help="WebSocket port.") parser.add_argument('--frame_len', default=210.0, type=float, help="frame length (sec)") parser.add_argument('--message_len', default=2048, type=int, help="message size of websocket upload.") parser.add_argument('--dump_wav', default=1, type=int, help="whether to dump frame as wav.") parser.add_argument('--dump_npy', default=1, type=int, help="whether to dump frame as npy.") return parser def calc_frame_size(args): log("frame_len(sec) = %f" % (args.frame_len)) log("browser_sample_rate = %d" % (args.browser_sample_rate)) log("message_len = %d" % (args.message_len)) n_browser_frame_len = int(args.frame_len * args.browser_sample_rate) log("Round frame_len times brower_sample_rate(%d) to a multiple of message_len(%d)" % (n_browser_frame_len, args.message_len)) n_browser_frame_len = args.message_len * (n_browser_frame_len // args.message_len) log("n_browser_frame_len = %d" % (n_browser_frame_len)) resample_rate = args.sample_rate / args.browser_sample_rate pseudo_browser_frame = np.zeros(n_browser_frame_len, dtype=np.float32) pseudo_resampled = samplerate.resample(pseudo_browser_frame, resample_rate, 'sinc_best') n_frame_len = pseudo_resampled.shape[0] log("sample_rate = %d" % (args.sample_rate)) log("resample rate = %f" % (resample_rate)) log("n_frame_len = %d" % (n_frame_len)) return n_browser_frame_len, n_frame_len class WebSocketHandler(tornado.websocket.WebSocketHandler): def initialize(self, n_browser_frame_len, n_frame_len, message_len, sample_rate, browser_sample_rate, dump_wav=False, dump_npy=False): log("Initialize WebSocket handler.") self.message_len = message_len self.n_browser_frame_len = n_browser_frame_len self.n_frame_len = n_frame_len self.sample_rate = sample_rate self.browser_sample_rate = browser_sample_rate assert n_browser_frame_len % message_len == 0 self.n_messages_per_frame = n_browser_frame_len // message_len log("sample_rate : %d" % (sample_rate)) log("browser_sample_rate : %d" % (browser_sample_rate)) log("n_frame_len : %d" % (self.n_frame_len)) log("n_browser_frame_len : %d" % (self.n_browser_frame_len)) log("n_messages_per_frame : %d" % (self.n_messages_per_frame)) self.dump_wav = dump_wav self.dump_npy = dump_npy self.pos = 0 def open(self): self.buffer = [] log("audio socket opened") def on_message(self, message): message = np.frombuffer(message, dtype='float32') log("on message : size=%d, buffer size=%d" % (len(message), len(self.buffer))) assert len(message)==self.message_len self.buffer.append(message) if len(self.buffer) >= self.n_messages_per_frame: self.transcribe() def transcribe(self): browser_frame = np.array(self.buffer).flatten() log("browsr frame length = %s" % (browser_frame.shape)) rate = self.sample_rate/self.browser_sample_rate log("resample rate = %f" % (rate)) frame = samplerate.resample(browser_frame, rate, 'sinc_best') log("resampled frame length = %s" % (frame.shape)) timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') if self.dump_npy: np.save("frame_%s" % (timestamp), frame) if self.dump_wav: self.dump_as_wav(frame, timestamp) self.buffer.clear() def dump_as_wav(self, frame, timestamp): arr = (frame * 32767).astype(np.int16) # 32767 is max value of 16 bit int. filename = 'frame_%s.wav' % (timestamp) log("wrting wav file : %s" % filename) with wave.open(filename, 'wb') as wf: wf.setnchannels(1) wf.setsampwidth(2) # 2bytes(16bit precision) wf.setframerate(self.sample_rate) wf.writeframes(arr.tobytes('C')) def on_close(self): self.transcribe() log("audio socket closed") def check_origin(self, origin): return True class ControlPageHandler(tornado.web.RequestHandler): def get(self): host=self.request.host html = """ <html> <div id="control" style="padding: 10px"> <font color="#ffffff"> <button id="start_transcribe" style="border: 0px; padding: 10px; border-radius: 10px; background-color: #00bfff; margin: 10px"> START </button> <button id="stop_transcribe" style="border: 0px; padding: 10px; border-radius: 10px; background-color: #ff69b4"> STOP </button> </font> </div> <script> var current_stream = null; var context = null; var ws = null; var upload = function(stream) { current_stream = stream; context = new AudioContext(); ws = new WebSocket('wss://%s/websocket'); console.log(context.sampleRate); var input = context.createMediaStreamSource(stream) var processor = context.createScriptProcessor(0, 1, 1); input.connect(processor); processor.connect(context.destination); processor.onaudioprocess = function(e) { var voice = e.inputBuffer.getChannelData(0); ws.send(voice.buffer); }; }; function start_transcribe(){ navigator.mediaDevices.getUserMedia({ audio: true, video: false }).then(upload) } function stop_transcribe(){ ws.close(); current_stream.getTracks().forEach(track => track.stop()); context.close(); } document.getElementById("start_transcribe").onclick = start_transcribe; document.getElementById("stop_transcribe").onclick = stop_transcribe; </script> </html> """ % host self.write(html) def check_origin(self, origin): return True def main(cmd_args): parser = get_parser() args, _ = parser.parse_known_args(cmd_args) n_browser_frame_len, n_frame_len = calc_frame_size(args) log("Starting SSL server on port %d" % args.websocket_port) app = tornado.web.Application([(r"/websocket", WebSocketHandler, { 'n_browser_frame_len': n_browser_frame_len, 'n_frame_len': n_frame_len, 'message_len': args.message_len, 'sample_rate': args.sample_rate, 'browser_sample_rate': args.browser_sample_rate, 'dump_wav': args.dump_wav, 'dump_npy': args.dump_npy }), (r"/control", ControlPageHandler) ]) http_server = tornado.httpserver.HTTPServer(app, ssl_options={ "certfile": args.cert_file, "keyfile" : args.key_file, }) http_server.listen(args.websocket_port) tornado.ioloop.IOLoop.instance().start() if __name__ == '__main__': main(sys.argv[1:]) EOF chmod 755 wsrecoder.py
トンネルの生成と録音サーバの起動
一通りの準備ができたので、以下のようにしてトンネルを作ります。セルから出力された URL がトンネルの入り口です (筆者が試した時はときどきエラーになりましたが、その場合も何度か試すとうまくいきました)。
# notebook-1 get_ipython().system_raw('./ngrok http https://localhost:8889 &') ! curl -s https://localhost:4040/api/tunnels | python3 -c \ "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])" # https://761d5a9eff82.ngrok.io
続いて録音サーバを起動します。
# notebook-1 !./wsrecoder.py
録音サーバの起動後、新しくタブを開いて https://761d5a9eff82.ngrok.io にアクセスすると(実際のURLは上記のセル出力の値に直して下さい)以下のコントロール画面が表示されるので、"START", “STOP” で録音の開始/停止ができます。
ラベルの生成と音声の録音
ここまでで音声を録音する準備が整いましたが、音声コマンド認識の学習をするには音声とコマンドのペアが必要になります。 録音サーバを起動したノートブックとは別に、新たに別のノートブックを開いて以下のセルを実行します。このノートブックを “notebook-2” とします。
%%bash # notebook-2 cat << EOF > mkdata.sh #!/bin/bash cnt=0 NUM_SAMPLES=100 NUM_CLASSES=8 CLASSES=("前進" "後退" "停止" "右" "左" "もっと右" "もっと左" "ターボブースト") time while [ \$cnt -lt \$NUM_SAMPLES ] do index=\`expr \$RANDOM \\* \$NUM_CLASSES / 32768\` echo "\$cnt \${CLASSES[\$index]}" echo "\${CLASSES[\$index]}" >> labels.txt cnt=\`expr \$cnt + 1\` sleep 2 done EOF chmod 755 mkdata.sh
先ほどのコントロール画面で “START” をクリックして録音を開始してから(この際、マイク使用の許可を求められるので「許可」をクリックして下さい)、 mkdata.sh を起動すると、以下のようにラベルが 2 秒間隔で出力されるので、「ぜんしん」、「こうたい」、と読み上げて音声を記録していきます。
# notebook-2 !./mkdata.sh # 0 前進 # 1 後退 # 2 もっと左 # 3 右 # ...
mkdata.sh の出力は 100 サンプルで終了するようになっています。終わったら、コントロール画面の “STOP” をクリックして下さい。録音サーバを起動したノートブック(“notebook-1”)で、録音された音声が “frame_20200625052006.wav” のようなファイル名で一塊の wav ファイルに出力されます。
mkdata.sh を実行したノートブック(“notebook-2”)では読み上げたラベルが labels.txt として記録されています。
# notebook-2 !head -4 labels.txt # 前進 # 後退 # もっと左 # 右
Colab の左端のメニューからフォルダのアイコンを選んで、labels.txt をダウンロード&アップロードで “notebook-2” から “notebook-1” に持っていきます。 以降は “notebook-1” で作業を行います。
wav ファイルの切り分け
録音サーバ(“wsrecoder.py”)が動いていると続きの作業ができないので、"wsrecoder.py" を実行したセル左側の停止ボタンで停止させておいて下さい。 ここからは 100 サンプルが一塊になった wav ファイルをラベル単位に切り分けていきます。まずは必要なライブラリをインポートします。
# notebook-1 import numpy as np import IPython.display from pydub import AudioSegment from pydub.silence import split_on_silence
保存した wav からタイムスタンプを切り出します。
# notebook-1 ts=!ls frame*.wav | sed -e 's/frame_//' -e 's/.wav//' ts=ts[0] ts # '20200625052006'
データセット用のディレクトリを作り、"notebook-2" から持ってきたラベルのファイル名に対応する音声のタイムスタンプを加えて、作成したディレクトリに移動します。
# notebook-1 !mkdir -p data !mv labels.txt labels_{ts}.txt !mv *{ts}* data !ls data # frame_20200625052006.npy frame_20200625052006.wav labels_20200625052006.txt
一塊の wav ファイルを無音部分で複数ファイルに分割し、ラベルテキストを読み込みます。
# notebook-1 sound = AudioSegment.from_file("data/frame_%s.wav" % ts , format="wav") chunks = split_on_silence(sound, min_silence_len=500, silence_thresh=-40, keep_silence=400) for i, chunk in enumerate(chunks): chunk.export("data/segment_" + str(i) +".wav", format="wav") with open("data/labels_%s.txt" % ts, "r") as f: lines = f.readlines() labels = [line.strip() for line in lines]
ラベルの先頭 10 サンプルは以下のようになっています。
# notebook-1 labels[0:10] # ['前進', '後退', 'もっと左', '右', '前進', 'ターボブースト', 'もっと左', '右', '左']
先頭のラベルと分割した wav ファイルを読み込みます。
# notebook-1 index=0 print("%2d : %s" % (index, labels[index])) IPython.display.Audio("data/segment_%d.wav" % index) # 0 : 右
ノートブックの画面上では読み込んだ wav の再生コントロールが表示されるので、再生ボタンをクリックしてラベルと音声が合致していることを確認しましょう。
上記のセルを index の値を変えながらズレをチェックしていきます。全件チェックするのは大変なので 20 サンプル毎ぐらいで十分です。 ズレや誤りが見つかったら、その都度手作業で修正していきます。
例えば、出だしの index=0 の wav に咳払い等が入ってしまった場合は、以下のようにして後続の wav ファイルのインデックスを前に詰めます。
# notebook-1 %%bash for i in `seq 1 100` do mv data/segment_$i.wav data/segment_`expr $i - 1`.wav done
また、ラベルが「右」なのに何故か「ひだり!」と発声してしまった場合などはラベルの方を書き換えたりもしました。
確認が終わった分割後の wav は対応するタイムスタンプでサブディレクトリを作って整理しておきます。
!mkdir -p data/{ts} !mv data/segment* data/{ts}
筆者の場合は以上の処理を繰り返して(とりあえず) 8 セット程作成しました。録音と推論の環境に違いがなければ、これくらいで十分だと思います。 実際に JetBot を操縦する場合は意外とモーター音が大きいので、録音作業時の傍らで JetBot のモータを稼働させてモーター音混じりのサンプルも作っておいた方がよいです(というか作る必要がありました。 JetBot のモーター音の鳴らし方は後述します)。
以下のようにして、 NeMo が読み込める JSON 形式に加工します。
%%bash declare -A commands commands=( \ ["前進"]="forward" \ ["後退"]="backword" \ ["右"]="right" \ ["左"]="left" \ ["もっと右"]="hard_right" \ ["もっと左"]="hard_left" \ ["停止"]="stop" \ ["ターボブースト"]="turbo_boost" ) dataset="jetbot_asr.json" rm -f $dataset for labels in data/labels*.txt do ts=`echo $labels | sed -e 's/^.*_//' -e 's/.txt//'` idx=0 for label in `cat $labels` do command=${commands[$label]} wav=data/$ts/segment_$idx.wav duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'` record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"$command\"}" echo $record >> $dataset idx=`expr $idx + 1` done done
できあがったデータはこんな感じになります。今回はコマンド分類なので “command” 属性に正解ラベルを設定する形式になります。
!head -5 jetbot_asr.json # {"audio_filepath": "data/20200625050543/segment_0.wav", "duration": 0.889000, "command": "left"} # {"audio_filepath": "data/20200625050543/segment_1.wav", "duration": 0.842000, "command": "stop"} # {"audio_filepath": "data/20200625050543/segment_2.wav", "duration": 0.978000, "command": "right"} # {"audio_filepath": "data/20200625050543/segment_3.wav", "duration": 0.810000, "command": "turbo_boost"} # {"audio_filepath": "data/20200625050543/segment_4.wav", "duration": 0.801000, "command": "hard_right"}
無音データの作成
次に無音部分のデータを作ります。推論時はマイクが拾う音声を連続的に推論してコマンド分類する訳ですが、 JetBot に何も話かけていない場合は「何もしなくてよい」というクラスに分類してもらう必要があります。
発話せずに上記の“notebook-1"の手順を実行して、雑音ありと雑音なしの 2 パターンを作りました。
!ls *wav # frame_20200624045836.wav frame_20200624050547.wav
雑音ありの wav を固定長で分割して with_noize フォルダに移動します。
!sox frame_20200624045836.wav split_.wav trim 0 1.2 : newfile : restart !mkdir -p data/with_noize !mv split_*wav data/with_noize
雑音なしの wav も同様に処理します。
!sox frame_20200624050547.wav split_.wav trim 0 1.2 : newfile : restart !mkdir -p data/without_noize !mv split_*wav data/without_noize
これらを JSON 化して発話ありのデータに追加します。
%%bash dataset="jetbot_asr.json" for idx in `seq -f "%03.0f" 1 50` do wav=data/with_noize/split_$idx.wav duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'` record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"void\"}" echo $record >> $dataset wav=data/without_noize/split_$idx.wav duration=`sox ${wav} -n stat 2>&1 | awk -F " " '/^Length/{print $3}'` record="{\"audio_filepath\": \"$wav\", \"duration\": $duration, \"command\": \"void\"}" echo $record >> $dataset done
実際にデータを作ってみると、いくつか duration が取れていないデータが出来てしまいました。
!cat jetbot_asr.json | grep -e "duration\": ," # {"audio_filepath": "data/20200625050543/segment_19.wav", "duration": , "command": "turbo_boost"} # {"audio_filepath": "data/20200625050543/segment_20.wav", "duration": , "command": "forward"} # {"audio_filepath": "data/20200625050543/segment_21.wav", "duration": , "command": "right"} # {"audio_filepath": "data/20200625050543/segment_22.wav", "duration": , "command": "hard_left"} # {"audio_filepath": "data/20200625050543/segment_23.wav", "duration": , "command": "hard_left"} # {"audio_filepath": "data/20200625050543/segment_24.wav", "duration": , "command": "hard_left"} # {"audio_filepath": "data/20200625050543/segment_25.wav", "duration": , "command": "right"} # {"audio_filepath": "data/20200625050543/segment_26.wav", "duration": , "command": "hard_right"}
これくらいの量なら大して影響なさそうなので、フィルタして落としてしまいます。
!cat jetbot_asr.json | grep -v "duration\": ," > jetbot_asr.json.filtered
中身をシャッフルして、
%%bash get_seeded_random() { seed="$1" openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \ </dev/zero 2>/dev/null } sort --random-source=<(get_seeded_random 42) -R jetbot_asr.json.filtered > jetbot_asr.json.shuf
こんな感じになりました。
!head -5 jetbot_asr.json.shuf # {"audio_filepath": "data/20200625050543/segment_18.wav", "duration": 1.007000, "command": "turbo_boost"} # {"audio_filepath": "data/20200625050543/segment_1.wav", "duration": 0.842000, "command": "stop"} # {"audio_filepath": "data/20200625050543/segment_13.wav", "duration": 0.906000, "command": "left"} # {"audio_filepath": "data/20200625050543/segment_9.wav", "duration": 0.847000, "command": "left"} # {"audio_filepath": "data/20200625050543/segment_5.wav", "duration": 0.845000, "command": "forward"}
実際の作業はアレコレやり直したり、苦手なコマンドのデータを追加したりと紆余曲折があって最終的には 1787 件のデータになってしまいました。。。
!wc -l *.json # 1787 jetbot_asr.json # 1787 total
最後に 8:1:1 で分割して、学習、検証、テストセットとします。
!head -1428 jetbot_asr.json.shuf > jetbot_asr_train.json !tail -356 jetbot_asr.json.shuf | head -178 > jetbot_asr_dev.json !tail -178 jetbot_asr.json.shuf > jetbot_asr_test.json !wc -l *.json # 1787 jetbot_asr.json # 178 jetbot_asr_dev.json # 178 jetbot_asr_test.json # 1428 jetbot_asr_train.json # 3571 total
作成したデータセットの GCS に保存しておきましょう。
from google.colab import auth auth.authenticate_user() !gsutil cp -r data gs://somewhere/jetbot_asr/ !gsutil cp *.josn data gs://somewhere/jetbot_asr/data/
ようやく音声分類データセットができたので、次はモデルの学習です。
3. 音声コマンド分類モデルの学習
ここからは前章で作成したデータセットを使って音声コマンド分類のモデルの学習を行います。 新たに Colab でノートブックを作って、ランタイムのアクセラレータは "GPU” にして下さい。
まずは必要なライブラリをインストールしていきます。
!apt-get update !apt-get install sox swig pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev \ libsndfile1-dev python-setuptools libboost-all-dev python-dev cmake
NeMo のコードも取得してインストールします。 詳しくは後述しますが、 JetBot に手軽にインストールできる PyTorch のバージョンとも絡むので少し前のコミットを使います。
!git clone https://github.com/NVIDIA/NeMo !cd NeMo && git checkout 331fc4b
前回と同様の問題があるので一部修正をしてしまって。。。
!cp NeMo/nemo/collections/asr/parts/jasper.py jasper.py !cat jasper.py | sed -e '158s#/#//#g' > jasper.py.mod !cp jasper.py.mod NeMo/nemo/collections/asr/parts/jasper.py !diff jasper.py NeMo/nemo/collections/asr/parts/jasper.py # 158c158 # < ) / self.conv.stride[0] + 1 # --- # > ) // self.conv.stride[0] + 1
pip でインストールします。
!cd NeMo; pip install . !cd NeMo; pip install .[asr]
前章で GCS に退避したデータセットも取ってきましょう。
from google.colab import auth auth.authenticate_user() !gsutil cp -r gs://somewhere/jetbot_asr/data/ .
コマンド分類のラベル(9クラス)を定義します。
labels = [ "forward", "backword", "right", "left", "hard_right", "hard_left", "stop", "turbo_boost", "void" ]
ここで、訓練データの各クラスの分布を見てみましょう(アレコレすったもんだした後のホントの最終形なので、前章で wc -l
した結果と整合性とれてないかもしれません。。。)。
import json train_set = [] with open("data/jetbot_asr_train.json", "r") as f: lines = f.readlines() for line in lines: train_set.append(json.loads(line)) counts = [0] * len(labels) for example in train_set: counts[labels.index(example["command"])] += 1 import matplotlib.pyplot as plt import numpy as np from matplotlib.font_manager import FontProperties plt.style.use('seaborn-whitegrid') %matplotlib inline f, ax1 = plt.subplots(figsize=(12, 3), dpi=90) fp = FontProperties(size=7) ax1.bar(labels, counts, width=0.5)
プロットするとこんな感じです。私の滑舌のせいかもしれませんが「ひだり」が意外と苦手でした。
学習の実行
ここからは学習の実行です。学習自体はあっという間に終わった気がします。
import nemo import nemo.collections.asr as nemo_asr
ファクトリを作って、
nf = nemo.core.NeuralModuleFactory(log_dir='./log', create_tb_writer=True) tb_writer = nf.tb_writer
モデルの定義をロードします。前回は 15x5 のモデルを使いましたが、音声コマンド認識なので 3x1 の小規模なモデルで十分です。
from ruamel.yaml import YAML yaml = YAML(typ="safe") with open("./NeMo/examples/asr/configs/quartznet_speech_commands_3x1_v1.yaml") as f: config = yaml.load(f) model_def = config["JasperEncoder"]
必要なクラスをインポートします。今回は分類問題なので使用するクラスが前回と一部、異なります。
from functools import partial from nemo.collections.asr.helpers import process_evaluation_batch, process_evaluation_epoch from nemo.collections.asr.helpers import monitor_classification_training_progress from nemo.collections.asr.helpers import process_classification_evaluation_batch from nemo.collections.asr.helpers import process_classification_evaluation_epoch
訓練サンプルの数を確認しておきましょう(やっぱり増えてますね。。。)。
num_train_samples = !(wc -l data/jetbot_asr_train.json | cut -f 1 -d " " ) num_train_samples = int(num_train_samples[0]) num_train_samples # 1837
エポック数とバッチサイズはこんな感じにしました。
num_epochs = 50 batch_size = 32
トータルステップ数を確認しておきます。
total_steps = num_train_samples * num_epochs // batch_size total_steps # 2870
学習の最初の 10% をウォームアップにします。
warmup_steps = total_steps // 10 warmup_steps # 287
その他のハイパーパラメータはこんな感じです。
batches_per_step = 1 dropout = 0.2 lr = 0.02 weight_decay = 0.0001 sample_rate = 16000 n_fft = 512 output_dir = "./output" epochs_per_eval = 1
今回も前回と同じくドロップアウトを適用してみます。
for i, layer in enumerate(model_def["jasper"]): layer['dropout'] = dropout
ここからはデータ入力レイヤの定義です。データレイヤのクラスには nemo_asr.AudioToSpeechLabelDataLayer
を指定します。
train_dataset = "data/jetbot_asr_train.json" val_dataset = "data/jetbot_asr_dev.json" with open(train_dataset) as f: lines = f.readlines() num_train_samples = len(lines) steps_per_epoch = num_train_samples // batch_size data_layer = nemo_asr.AudioToSpeechLabelDataLayer( manifest_filepath=train_dataset, labels=labels, sample_rate=sample_rate, batch_size=batch_size, shuffle=True ) # [NeMo I 2020-07-23 07:19:13 collections:240] Filtered duration for loading collection is 0.000000. # [NeMo I 2020-07-23 07:19:13 collections:243] # 1837 files loaded accounting to # 9 labels # [NeMo I 2020-07-23 07:19:13 data_layer:961] # of classes :9 data_layer_val = nemo_asr.AudioToSpeechLabelDataLayer( manifest_filepath=val_dataset, labels=labels, sample_rate=sample_rate, batch_size=batch_size, shuffle=False ) # [NeMo I 2020-07-23 07:19:14 collections:240] Filtered duration for loading collection is 0.000000. # [NeMo I 2020-07-23 07:19:14 collections:243] # 206 files loaded accounting to # 9 labels # [NeMo I 2020-07-23 07:19:14 data_layer:961] # of classes :9
次に前処理の定義です。
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( n_fft=n_fft, sample_rate=sample_rate, features=model_def["feat_in"], stft_conv=True) # [NeMo I 2020-07-23 07:19:17 features:144] PADDING: 16 # [NeMo I 2020-07-23 07:19:17 features:152] STFT using conv spec_augment = nemo_asr.SpectrogramAugmentation(**config["SpectrogramAugmentation"])
エンコーダ、デコーダ、ロスを定義します。今回はコマンド分類なのでデコーダには nemo_asr.JasperDecoderForClassification
を使います。
ロスも nemo_asr.CrossEntropyLossNM
になります。
encoder = nemo_asr.JasperEncoder(**model_def) decoder = nemo_asr.JasperDecoderForClassification(feat_in=model_def["jasper"][-1]['filters'], num_classes=len(labels), **config['JasperDecoderForClassification'], ) ce_loss = nemo_asr.CrossEntropyLossNM()
学習グラフと検証グラフを定義して、
audio_signal, audio_signal_len, commands, command_len = data_layer() processed_signal, processed_signal_len = data_preprocessor(input_signal=audio_signal, length=audio_signal_len) processed_signal = spec_augment(input_spec=processed_signal) encoded, encoded_len = encoder(audio_signal=processed_signal, length=processed_signal_len) decoded = decoder(encoder_output=encoded) loss = ce_loss(logits=decoded, labels=commands) audio_signal_v, audio_signal_len_v, commands_v, command_len_v = data_layer_val() processed_signal_v, processed_signal_len_v = data_preprocessor(input_signal=audio_signal_v, length=audio_signal_len_v) encoded_v, encoded_len_v = encoder(audio_signal=processed_signal_v, length=processed_signal_len_v) decoded_v = decoder(encoder_output=encoded_v) loss_v = ce_loss(logits=decoded_v, labels=commands_v)
進捗、チェックポイント、検証の各種コールバック関数を定義します。
logger_cb = nemo.core.SimpleLossLoggerCallback( tb_writer=tb_writer, tensors=[loss, decoded, commands], print_func=partial(monitor_classification_training_progress, eval_metric=None), get_tb_values=lambda x: [("loss", x[0])],) ckpt_cb = nemo.core.CheckpointCallback(folder=output_dir, step_freq=steps_per_epoch) eval_cb = nemo.core.EvaluatorCallback( eval_tensors=[loss_v, decoded_v, commands_v], user_iter_callback=partial(process_classification_evaluation_batch, top_k=1), user_epochs_done_callback=partial(process_classification_evaluation_epoch, eval_metric=1, tag="DEV"), eval_step=steps_per_epoch * epochs_per_eval, tb_writer=tb_writer, ) callbacks = [logger_cb, ckpt_cb, eval_cb]
ようやくですが、学習ループを回します。
optimization_params={"num_epochs": num_epochs, "lr": lr, "weight_decay": weight_decay, "momentum": 0.95, "betas": (0.98, 0.5), "grad_norm_clip": None} nf.train(tensors_to_optimize=[loss], callbacks=callbacks, optimizer="novograd", optimization_params=optimization_params, batches_per_step=batches_per_step) # [NeMo I 2020-07-23 07:20:02 deprecated_callbacks:195] Starting ..... # [NeMo I 2020-07-23 07:20:02 callbacks:534] Found 2 modules with weights: ... # [NeMo I 2020-07-23 07:22:20 callbacks:465] Saved checkpoint: ./output/trainer-STEP-2900.pt # [NeMo I 2020-07-23 07:22:20 deprecated_callbacks:339] Final Evaluation .............................. # [NeMo I 2020-07-23 07:22:20 helpers:271] ==========>>>>>>Evaluation Loss DEV: 0.388 # [NeMo I 2020-07-23 07:22:20 helpers:273] ==========>>>>>>Evaluation Accuracy Top@1 DEV: 96.1165 # [NeMo I 2020-07-23 07:22:20 deprecated_callbacks:344] Evaluation time: 0.16168832778930664 seconds
検証セットでの正答率が 96% なので悪くなさそうです。
推論の実行
学習済みモデルができたのでテストセットで推論してみましょう。
今度はテストセットでデータレイヤを定義して、
test_dataset = 'data/jetbot_asr_test.json' import json test_set = [] with open("data/jetbot_asr_test.json", "r") as f: lines = f.readlines() for line in lines: test_set.append(json.loads(line)) data_layer_test = nemo_asr.AudioToSpeechLabelDataLayer( manifest_filepath=test_dataset, labels=labels, sample_rate=sample_rate, batch_size=batch_size, shuffle=False ) # [NeMo I 2020-07-23 07:22:49 collections:240] Filtered duration for loading collection is 0.000000. # [NeMo I 2020-07-23 07:22:49 collections:243] # 206 files loaded accounting to # 9 labels # [NeMo I 2020-07-23 07:22:49 data_layer:961] # of classes :9
推論グラフを構築します。
audio_signal_t, audio_signal_len_t, commands_t, command_len_t = data_layer_test() processed_signal_t, processed_signal_len_t = data_preprocessor(input_signal=audio_signal_t, length=audio_signal_len_t) encoded_t, encoded_len_t = encoder(audio_signal=processed_signal_t, length=processed_signal_len_t) decoded_t = decoder(encoder_output=encoded_t) loss_t = ce_loss(logits=decoded_t, labels=commands_t)
推論の実行はこんな感じで行います。
evaluated_tensors = nf.infer(tensors=[loss_t, decoded_t, commands_t], checkpoint_dir="./output") # [NeMo I 2020-07-23 07:22:53 actions:1574] Restoring JasperEncoder from ./output/JasperEncoder-STEP-2900.pt # [NeMo I 2020-07-23 07:22:53 actions:1574] Restoring JasperDecoderForClassification from ./output/JasperDecoderForClassification-STEP-2900.pt # [NeMo I 2020-07-23 07:22:53 actions:695] Evaluating batch 0 out of 7 # ... # [NeMo I 2020-07-23 07:22:53 actions:695] Evaluating batch 6 out of 7
精度を確認します。
from nemo.collections.asr.metrics import classification_accuracy correct_count = 0 total_count = 0 for batch_idx, (logits, labels) in enumerate(zip(evaluated_tensors[1], evaluated_tensors[2])): acc = classification_accuracy( logits=logits, targets=labels, top_k=[1] ) acc = acc[0] correct_count += int(acc * logits.size(0)) total_count += logits.size(0) print(f"Total correct / Total count : {correct_count} / {total_count}") print(f"Final accuracy : {correct_count / float(total_count)}") # Total correct / Total count : 199 / 206 # Final accuracy : 0.9660194174757282
こちらも 96.6% の正答率です。 JetBot を操作するには十分な感じですが、もう少し詳しく見てみましょう6。
import torch import librosa import json import IPython.display as ipd class ReverseMapLabel: def __init__(self, data_layer: nemo_asr.AudioToSpeechLabelDataLayer): self.label2id = dict(data_layer._dataset.label2id) self.id2label = dict(data_layer._dataset.id2label) def __call__(self, pred_idx, label_idx): return self.id2label[pred_idx], self.id2label[label_idx] sample_idx = 0 incorrect_preds = [] rev_map = ReverseMapLabel(data_layer_test) # Remember, evaluated_tensor = (loss, logits, labels) for batch_idx, (logits, labels) in enumerate(zip(evaluated_tensors[1], evaluated_tensors[2])): probs = torch.softmax(logits, dim=-1) probas, preds = torch.max(probs, dim=-1) incorrect_ids = (preds != labels).nonzero() for idx in incorrect_ids: proba = float(probas[idx][0]) pred = int(preds[idx][0]) label = int(labels[idx][0]) idx = int(idx[0]) + sample_idx incorrect_preds.append((idx, *rev_map(pred, label), proba)) sample_idx += labels.size(0) print(f"Num test samples : {total_count}") print(f"Num errors : {len(incorrect_preds)}") incorrect_preds = sorted(incorrect_preds, key=lambda x: x[-1], reverse=False) # Num test samples : 206 # Num errors : 7
7つ程、不正解がありました(出力は [id, pred, label, prob]です)。
for incorrect_sample in incorrect_preds: print(str(incorrect_sample)) # (28, 'turbo_boost', 'left', 0.47237205505371094) # (62, 'turbo_boost', 'hard_right', 0.48559606075286865) # (124, 'right', 'forward', 0.9396188259124756) # (18, 'hard_right', 'backword', 0.9929646253585815) # (126, 'stop', 'backword', 0.9997004270553589) # (147, 'left', 'turbo_boost', 0.9997300505638123) # (106, 'stop', 'turbo_boost', 0.9999876022338867)
試しに幾つか確認してみました。
画像なので再生できませんが、id=28 は「たーぼぶーすと」、id=106 は「ていし」と発話しています。。。 どうもラベルの方が間違ってるようですね。無音部分で分割した wav とラベルの整合性チェックは結構粗く行ったので、取りこぼしがあったのかもしれません。 とはいえ、概ね正しくコマンドを認識できているようですので、学習結果を GCS に保存しておきましょう。
!tar zcvf output.tar.gz output !gsutil cp output.tar.gz gs://somewhere/jetbot_asr/
ようやくですが(3回目?)、このモデルを JetBot にデプロイして走行試験をしてみましょう。
4. JetBot での走行試験
ここからは、お手元に JetBot がある想定です。いまから自分で作る気概のある人はオブジェクトの広場の「 JetBot を動かしてみよう 」の記事7を参考にして下さい。
JetBot のセットアップ
JetBot の OS イメージは JetBot の Wiki 8 から入手できます。ただし NeMo を動かすには PyTorch のバージョンが古過ぎるので、そのあたりを入れ替えていく必要があります。
まずは上記のサイトから jetbot_image_v0p4p0.zip
を入手し、Echer 9 を使って microSD カードに書き込んでください。イメージを書き込んだ microSD カードを JetBot のスロットに挿入し、
HDMI のスロットにモニター、USB のスロットにマウス、キーボードを接続します。準備ができたらUSBバッテリーにケーブルを接続して JetBot を起動します。 ここからは普通の Ubuntu の操作とあまり変わりありません。
さて、執筆時点で最新の NeMo では PyTorch 1.6 以上ということになっているのですが、先ほど microSD に焼き込んだ jetbot_image_v0p4p0.zip
は JetPack 4.3 がベースになっています。こちら10のサイトで確認したところ、 JetPack 4.3 向けの PyTorch は 1.4 が最新のようですので、PyTorch をこちらに入れ替え、NeMo のバージョンを少し落とすことにしました。
ここからは、 JetBot のデスクトップを操作して作業します。 まずは WiFi に接続して下さい。走行時は PC のマイクで拾った音声を WiFi 経由で JetBot に送信する形になるので、起動時に WiFi に自動接続する設定にしておきましょう。
次に Terminal を開いてセットアップ作業をしていきます。ちなみに、ここからの作業は USBバッテリー駆動で低電力モードで動いているせいか、かなり時間がかかりました。。。まずは、 PyTorch を 1.4 に入れ替えましょう。
jetbot@jetson-4-3:~$ wget https://nvidia.box.com/shared/static/ncgzus5o23uck9i5oth2n8n06k340l6k.whl -O torch-1.4.0-cp36-cp36m-linux_aarch64.whl jetbot@jetson-4-3:~$ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev jetbot@jetson-4-3:~$ sudo pip3 install Cython jetbot@jetson-4-3:~$ sudo pip3 install numpy torch-1.4.0-cp36-cp36m-linux_aarch64.whl
つづいて NeMo のインストールです。走行時はモニタ、キーボード、マウスを外して WiFi 経由で JetBot 上で動作する Jupyter Lab に接続して作業するので、Jupyter Lab のデフォルトのフォルダ(/home/jetbot/Notebooks)で作業します。
まずは コミット 331fc4b をチェックアウトします。ちゃんと PyTorch 1.4 がサポートされたバージョンです。
jetbot@jetson-4-3:~$ cd /home/jetbot/Notebooks jetbot@jetson-4-3:~/Notebooks$ sudo apt-get install sox swig pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev \ libsndfile1-dev python-setuptools libboost-all-dev python-dev cmake jetbot@jetson-4-3:~/Notebooks$ git clone https://github.com/NVIDIA/NeMo jetbot@jetson-4-3:~/Notebooks$ cd NeMo jetbot@jetson-4-3:~/Notebooks/NeMo$ git checkout 331fc4b jetbot@jetson-4-3:~/Notebooks/NeMo$ grep -n -A 5 Requirement README.rst # 81:**Requirements** # 82- # 83-1) Python 3.6 or 3.7 # 84-2) PyTorch 1.4.* with GPU support # 85-3) (optional, for best performance) NVIDIA APEX. Install from here: https://github.com/NVIDIA/apex
さらに必要なライブラリと NeMo をインストールします11。
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 uninstall enum34 jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install liblapack-dev jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install gfortran jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install protobuf-compiler libprotobuf-dev jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install numpy==1.19.0 jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install .
続いて asr モジュールをインストールしたかったのですが、libffi や llvm-config がないと怒られるのでインストールします。
jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get install libffi-dev jetbot@jetson-4-3:~/Notebooks/NeMo$ wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get -y install clang-7 lldb-7 lld-7 jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo apt-get -y install libllvm-7-ocaml-dev libllvm7 llvm-7 llvm-7-dev \ llvm-7-doc llvm-7-examples llvm-7-runtime jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo ln -s /usr/lib/llvm-7/bin/llvm-config /usr/bin/llvm-config
改めて asr モジュールをインストール
jetbot@jetson-4-3:~/Notebooks/NeMo$ export LLVM_CONFIG=/usr/bin/llvm-config jetbot@jetson-4-3:~/Notebooks/NeMo$ sudo pip3 install .[asr]
NeMo がインストールできたので、一つ上のディレクトリに動いておきます。あと他にも入れておくものがありました。
jetbot@jetson-4-3:~/Notebooks/NeMo$ cd .. jetbot@jetson-4-3:~/Notebooks$ sudo pip install configargparse jetbot@jetson-4-3:~/Notebooks$ sudo pip install samplerate
分類対象のクラスをテキストに出力しておきます。モデルが出力するテンソルのインデックスに対応する行をこのファイルから拾ってクラスの文字列表現にするので、順番は重要です。
jetbot@jetson-4-3:~/Notebooks$ cat << EOF > labels_9class.txt forward backword right left hard_right hard_left stop turbo_boost void EOF
次に前章で保存した学習済みモデルを GCS からダウンロードします12。
jetbot@jetson-4-3:~/Notebooks$ gsutil cp gs://somewhere/jetbot_asr/output.tar.gz . jetbot@jetson-4-3:~/Notebooks$ tar zxvf output.tar.gz
WSS 接続に使う鍵と証明書も改めて作ります。
jetbot@jetson-4-3:~/Notebooks$ mkdir ssl jetbot@jetson-4-3:~/Notebooks$ openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "./ssl/cert.key" -out "./ssl/cert.pem" -batch
最後にストリーミングで音声認識をして JetBot を制御するサーバのロジックです。長いので Jupyter Lab でノートブックを開いてセルから出力しました( JetBot からは https://localhost:8888/lab でアクセスできます)。デフォルトのフォルダにノートブックを作ったので、 jetbot_asr_server.py
が /home/jetbot/Notebooks
に出力される想定です。
%%bash cat <<EOF > jetbot_asr_server.py #!/usr/bin/env python3 import sys import os import math import configargparse import tornado.ioloop import tornado.web import tornado.websocket import wave import numpy as np import datetime import nemo import nemo.collections.asr as nemo_asr from ruamel.yaml import YAML from nemo.backends.pytorch.nm import DataLayerNM from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType import torch import samplerate import wave import json from jetbot import Robot """ class Robot: def forward(self, speed): print("Robot: forward : %s" % (speed)) def backward(self, speed): print("Robot: backward : %s" % (speed)) def set_motors(self, left, right): print("Robot: set_motors : %s, %s" % (left, right)) def stop(self): print("Robot: stop") """ # How to make server key and certificate. # openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout "cert.key" -out "cert.pem" -batch def log(message): print(message, flush=True) def get_parser(parser=None, required=True): if parser is None: parser = configargparse.ArgumentParser( description='', config_file_parser_class=configargparse.YAMLConfigFileParser, formatter_class=configargparse.ArgumentDefaultsHelpFormatter) parser.add('--config', is_config_file=True, help='config file path') parser.add_argument('--labels', default="./labels_9class.txt", type=str, help='List of label(class).') parser.add_argument('--yaml_config', default="./NeMo/examples/asr/configs/quartznet_speech_commands_3x1_v1.yaml", type=str, help='YAML config.') parser.add_argument('--encoder_ckpt', default="./output/JasperEncoder-STEP-2900.pt", type=str, help="path to encoder checkpoint.") parser.add_argument('--decoder_ckpt', default="./output/JasperDecoderForClassification-STEP-2900.pt", type=str, help="path to decoder checkpoint.") parser.add_argument('--nf_log_dir', default="./log", type=str, help="log directory for nefaul factory.") parser.add_argument('--sample_rate', default=16000, type=int, help="Server side sample rate.") parser.add_argument('--browser_sample_rate', default=16000, type=int, help="Client side sample rate.") parser.add_argument('--cert_file', default="./ssl/cert.pem", type=str, help="SSL cert file.") parser.add_argument('--key_file', default="./ssl/cert.key", type=str, help="SSL key file.") parser.add_argument('--websocket_port', default=8889, type=int, help="WebSocket port.") parser.add_argument('--frame_len', default=0.2, type=float, help="frame length (sec)") parser.add_argument('--frame_overlap', default=0.5, type=float, help="frame overlap length before and after.(sec)") parser.add_argument('--message_len', default=2048, type=int, help="message size of websocket upload.") parser.add_argument('--window_stride', default=0.01, type=float, help="window stride of fft.") parser.add_argument('--dump_wav', default=0, type=int, help="whether to dump frame as wav.") parser.add_argument('--dump_npy', default=0, type=int, help="whether to dump frame as npy.") parser.add_argument('--confidence_thresh', default=0.80, type=float, help="confidence threshold of detection.") return parser def calc_frame_size(args): log("frame_len(sec) = %f" % (args.frame_len)) log("frame_overlap(sec) = %f" % (args.frame_overlap)) log("browser_sample_rate = %d" % (args.browser_sample_rate)) log("message_len = %d" % (args.message_len)) log("sample_rate = %d" % (args.sample_rate)) n_frame_len = int(args.frame_len * args.browser_sample_rate) if n_frame_len % args.message_len != 0: n_frame_len = (n_frame_len // args.message_len + 1) * args.message_len n_frame_overlap = int(args.frame_overlap * args.browser_sample_rate) log("n_frame_len = %d" % (n_frame_len)) log("n_frame_overlap = %d" % (n_frame_overlap)) return n_frame_len, n_frame_overlap def load_labels(label): log("Loading labels.") with open(label, "r") as f: lines = f.readlines() labels = [line.strip() for line in lines] return labels def load_model_def(yaml_config): log("Loading 3x1 model definition.") yaml = YAML(typ="safe") with open(yaml_config) as f: config = yaml.load(f) model_def = config["JasperEncoder"] return model_def def build_neural_factory(model_def, labels, encoder_ckpt, decoder_ckpt, nf_log_dir, sample_rate, window_stride, n_fft=512): log("Creating NeuralModuleFactory.") nf = nemo.core.NeuralModuleFactory(log_dir=nf_log_dir) data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(dither=0.0, pad_to=0, n_fft=n_fft, sample_rate=sample_rate, window_stride=window_stride, features=model_def["feat_in"], stft_conv=True) log("Creating AudioDataLayer.") class AudioDataLayer(DataLayerNM): @property def output_ports(self): return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), } def __init__(self, sample_rate): super().__init__() self._sample_rate = sample_rate self.has_next = True def __iter__(self): return self def __next__(self): if not self.has_next: raise StopIteration self.has_next = False return torch.as_tensor(self.signal, dtype=torch.float32), torch.as_tensor(self.signal_shape, dtype=torch.int64) def set_signal(self, signal): self.signal = np.reshape(signal, [1, -1]) self.signal_shape = np.expand_dims(self.signal.size, 0).astype(np.int64) self.has_next = True def __len__(self): return 1 @property def dataset(self): return None @property def data_iterator(self): return self data_layer = AudioDataLayer(sample_rate=sample_rate) log("Creating encoder, decoder, greedy_decoder.") encoder = nemo_asr.JasperEncoder(**model_def) decoder = nemo_asr.JasperDecoderForClassification(feat_in=model_def["jasper"][-1]['filters'], num_classes=len(labels), return_logits=True, pooling_type='avg') ce_loss = nemo_asr.CrossEntropyLossNM() log("Restoring parameters from the checkpoints.") encoder.restore_from(encoder_ckpt) decoder.restore_from(decoder_ckpt) log("Building DAG.") audio_signal, audio_signal_len = data_layer() processed_signal, processed_signal_len = data_preprocessor(input_signal=audio_signal, length=audio_signal_len) encoded, encoded_len = encoder(audio_signal=processed_signal, length=processed_signal_len) decoded = decoder(encoder_output=encoded) log("Setting up infer_signal function.") def infer_signal(self, signal): data_layer.set_signal(signal) tensors = self.infer(tensors=[decoded], verbose=False) logits = tensors[0][0] probs = torch.softmax(logits, dim=-1) return probs nf.infer_signal = infer_signal.__get__(nf) return nf class FrameASR: def __init__(self, nf, model_def, labels, n_frame_len, n_frame_overlap, sample_rate, browser_sample_rate, confidence_thresh, dump_wav, dump_npy): log("Initialize FrameASR.") self.nf = nf self.labels = labels self.sample_rate = sample_rate self.browser_sample_rate = browser_sample_rate self.n_frame_len = n_frame_len self.n_frame_overlap = n_frame_overlap self.buffer = np.zeros(shape=2 * self.n_frame_overlap + self.n_frame_len, dtype=np.float32) self.confidence_thresh = confidence_thresh self.dump_wav = dump_wav self.dump_npy = dump_npy self.debug() self.reset() def debug(self): log("labels = %s" % (self.labels)) log("sample_rate = %d" % (self.sample_rate)) log("browser_sample_rate = %d" % (self.browser_sample_rate)) log("n_frame_len = %d" % (self.n_frame_len)) log("n_frame_overlap = %d" % (self.n_frame_overlap)) log("asr buffer size = %s" % (self.buffer.shape)) log("confidence_thresh = %s" % (self.confidence_thresh)) def dump_as_wav(self, frame, timestamp, cls): arr = (frame * 32767).astype(np.int16) # 32767 is max value of 16 bit int. filename = 'frame_%s_predict_as_%s.wav' % (timestamp, cls) log("wrting wav file : %s" % filename) with wave.open(filename, 'wb') as wf: wf.setnchannels(1) wf.setsampwidth(2) # 2bytes(16bit precision) wf.setframerate(16000) wf.writeframes(arr.tobytes('C')) def _decode(self, frame): #log("len(frame) = %d" % (len(frame))) assert len(frame)==self.n_frame_len # append new frame to buffer and scroll forward 1 frame length. self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:] self.buffer[-self.n_frame_len:] = frame if self.sample_rate == self.browser_sample_rate : buffer = self.buffer else: rate = self.sample_rate/self.browser_sample_rate buffer = samplerate.resample(self.buffer, rate, 'sinc_best') probs = self.nf.infer_signal(buffer).cpu().numpy()[0] #self._debug(probs) idx = np.argmax(probs) if self.labels[idx] != "void" and probs[idx] >= self.confidence_thresh: timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') if self.dump_npy: np.save('frame_%s_predict_as_%s' % (timestamp, self.labels[idx]), resampled) if self.dump_wav: self.dump_as_wav(resampled, timestamp, self.labels[idx]) return idx, self.labels[idx], probs[idx] def _debug(self, probs): predict="" for i, cls in enumerate(self.labels): predict += "%s=%5.2f, " % (cls, probs[i]) predict = predict[:-1] log(predict) def detect(self, frame=None): if frame is None: frame = np.zeros(shape=self.n_frame_len, dtype=np.float32) if len(frame) < self.n_frame_len: frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant') return self._decode(frame) def reset(self): self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32) class WebSocketHandler(tornado.websocket.WebSocketHandler): def initialize(self, n_frame_len, message_len, browser_sample_rate, asr, confidence_thresh, robot): log("Initialize WebSocket handler.") self.message_len = message_len self.n_frame_len = n_frame_len self.browser_sample_rate = browser_sample_rate self.n_messages_per_frame = n_frame_len // message_len log("browser_sample_rate : %d" % (browser_sample_rate)) log("n_frame_len : %d" % (self.n_frame_len)) log("n_messages_per_frame : %d" % (self.n_messages_per_frame)) self.asr = asr self.confidence_thresh = confidence_thresh self.robot = robot self.pos = 0 def open(self): self.buffer = [] log("audio socket opened") def on_message(self, message): message = np.frombuffer(message, dtype='float32') #log("on message : size=%d, buffer size=%d" % (len(message), len(self.buffer))) assert len(message)==self.message_len self.buffer.append(message) if len(self.buffer) >= self.n_messages_per_frame: self.detect() def detect(self): frame = np.array(self.buffer).flatten() #log("frame length = %s" % (frame.shape)) # Invoke ASR timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') idx, cls, prob = self.asr.detect(frame) if cls != "void" and prob >= self.confidence_thresh: log("detected : %s(%d) = %5.2f" % (cls, idx, prob)) self.write_message("%s : command=%s, probability=%5.2f\n" % (timestamp, cls, prob)) self.run_command(cls) self.buffer.clear() def run_command(self, command): if command == 'forward': #self.robot.forward(0.3) self.robot.set_motors(0.3, 0.28) elif command == 'right': self.robot.set_motors(0.3, 0.2) elif command == 'left': self.robot.set_motors(0.2, 0.3) elif command == 'hard_right': self.robot.set_motors(0.3, 0.1) elif command == 'hard_left': self.robot.set_motors(0.1, 0.3) elif command == 'stop': self.robot.stop() elif command == 'backword': #self.robot.backward(0.3) self.robot.set_motors(-0.3, -0.28) elif command == 'turbo_boost': self.robot.forward(1.0) def on_close(self): self.detect() log("audio socket closed") def check_origin(self, origin): return True class ControlPageHandler(tornado.web.RequestHandler): def initialize(self, browser_sample_rate): log("Initialize control page handler.") self.browser_sample_rate = browser_sample_rate def get(self): host=self.request.host html = """ <html> <div id="control" style="padding: 10px"> <font color="#ffffff"> <button id="start_transcribe" style="border: 0px; padding: 10px; border-radius: 10px; background-color: #00bfff; margin: 10px"> START </button> <button id="stop_transcribe" style="border: 0px; padding: 10px; border-radius: 10px; background-color: #ff69b4"> STOP </button> </font> </div> <div> <textarea id="transcribed" cols="120" rows="8"></textarea> </div> <script> var current_stream = null; var context = null; var ws = null; var textarea = document.getElementById("transcribed") var upload = function(stream) { current_stream = stream; context = new AudioContext(); context = new AudioContext({ sampleRate: %d }); ws = new WebSocket('wss://%s/websocket'); ws.onmessage = function(event) { textarea.value = textarea.value + event.data; textarea.scrollTop = textarea.scrollHeight; }; console.log(context.sampleRate); var input = context.createMediaStreamSource(stream) var processor = context.createScriptProcessor(0, 1, 1); input.connect(processor); processor.connect(context.destination); processor.onaudioprocess = function(e) { var voice = e.inputBuffer.getChannelData(0); ws.send(voice.buffer); }; }; function start_transcribe(){ navigator.mediaDevices.getUserMedia({ audio: true, video: false }).then(upload) } function stop_transcribe(){ ws.close(); current_stream.getTracks().forEach(track => track.stop()); context.close(); } document.getElementById("start_transcribe").onclick = start_transcribe; document.getElementById("stop_transcribe").onclick = stop_transcribe; </script> </html> """ % (self. browser_sample_rate, host) self.write(html) def check_origin(self, origin): return True def main(cmd_args): parser = get_parser() args, _ = parser.parse_known_args(cmd_args) n_frame_len, n_frame_overlap = calc_frame_size(args) labels = load_labels(label=args.labels) model_def = load_model_def(yaml_config=args.yaml_config) nf = build_neural_factory(model_def, labels, encoder_ckpt=args.encoder_ckpt, decoder_ckpt=args.decoder_ckpt, nf_log_dir=args.nf_log_dir, sample_rate=args.sample_rate, window_stride=args.window_stride ) asr = FrameASR(nf, model_def, labels, n_frame_len=n_frame_len, n_frame_overlap=n_frame_overlap, sample_rate=args.sample_rate, browser_sample_rate=args.browser_sample_rate, confidence_thresh=args.confidence_thresh, dump_wav=args.dump_wav, dump_npy=args.dump_npy) log("Starting SSL server on port %d" % args.websocket_port) robot = Robot() app = tornado.web.Application([(r"/websocket", WebSocketHandler, { 'n_frame_len': n_frame_len, 'message_len': args.message_len, 'browser_sample_rate': args.browser_sample_rate, 'asr': asr, 'confidence_thresh': args.confidence_thresh, 'robot':robot }), (r"/control", ControlPageHandler, { 'browser_sample_rate': args.browser_sample_rate}) ]) http_server = tornado.httpserver.HTTPServer(app, ssl_options={ "certfile": args.cert_file, "keyfile" : args.key_file, }) http_server.listen(args.websocket_port) tornado.ioloop.IOLoop.instance().start() if __name__ == '__main__': main(sys.argv[1:]) EOF chmod 755 jetbot_asr_server.py
これで、ひととおりの準備が整いました。いよいよ走行試験です。 JetBot からモニタ、キーボード、マウスを取り外して、 走行させる場所までもっていきましょう。
走行試験
それでは走行試験を始めます。環境の設定としては マイクの付いた Windows PC と JetBot が同じ WiFi につながっているという状況です。 まずは、 JetBot の IP アドレスを確認しましょう。 WiFi に接続できていれば LED に表示がでているはずです。
Windows PC の Chrome から JetBot 上で稼働している Jupyter Lab に接続します。筆者の環境では https://192.168.0.144:8888/lab ですね。 Jupyter Lab に接続できたら再び Terminal を開きます。
Terminal から音声コマンド認識サーバを起動します。
jetbot@jetson-4-3:~$ cd Notebooks/ jetbot@jetson-4-3:~/Notebooks$ ./jetbot_asr_server.py # ... # frame_len(sec) = 0.200000 # frame_overlap(sec) = 0.500000 # browser_sample_rate = 16000 # message_len = 2048 # sample_rate = 16000 # n_frame_len = 4096 # n_frame_overlap = 8000 # Loading labels. # Loading 3x1 model definition. # Creating NeuralModuleFactory. # ... # Creating AudioDataLayer. # Creating encoder, decoder, greedy_decoder. # Restoring parameters from the checkpoints. # Building DAG. # Setting up infer_signal function. # Initialize FrameASR. # labels = ['forward', 'backword', 'right', 'left', 'hard_right', 'hard_left', 'stop', 'turbo_boost', 'void'] # sample_rate = 16000 # browser_sample_rate = 16000 # n_frame_len = 4096 # n_frame_overlap = 8000 # asr buffer size = 20096 # confidence_thresh = 0.8 # Starting SSL server on port 8889
簡単に動きを説明すると、PC のマイクからサンプリングレート 16KHz で拾った音声を 2048 サンプルを 1 メッセージとして連続的に WebSocket で JetBot に送り続けており、 JetBot 側はそれを約1.2秒分のバッファに貯めていて 2 メッセージ(=4096サンプル)受信する毎に直近の1.2秒分を音声コマンド認識に投入する動きとなります。言い換えると、約1.2秒(8000*2+4096)のウインドウを約0.25秒(=4096/16000)のストライドで動かしながらコマンド認識していることになります13。
音声コマンド認識サーバが起動したら Chrome のタブをもう一枚開けて、音声コマンド認識サーバのコントロール画面( https://192.168.0.144:8889/control )を表示します。
録音サーバの時と同様に “START” をクリック、マイクへのアクセスを要求されるので許可して下さい。
あとは PC に向かって、「ぜんしん」、「みぎ」、「こうたい」と発話するとコントロール画面のテキストエリアに認識結果が表示され、JetBot が動き始めます。
実際の様子はこんな感じになります。だいたい意図したとおりに動かせてますね。
動作中の JetBot の状況は jtop で確認できます。動作中はこんな感じになってました。
5. おまけ
「後述します」として、まだ書いてない話が残っていましたね。
データセット作成時に JetBot のモータ音を鳴らす
これは JetBot の Jupyter Lab に接続して、/home/jetbot/Notebooks/basic_motion/basic_motion.ipynb
を開いてください。
Robot
クラスを使って JetBot のモータを制御できるので適当に駆動させながら、録音作業を実施します。
Colab で音声認識サーバを動かす
こちらは3章の学習に使った Colab のノートブックに2章の要領で ngrok をセットアップします。そこで jetbot_asr_server.py
を実行して ngrok 経由でアクセスすれば OK です。ただし jetbot_asr_server.py
を編集して Robot
クラスをコメントになっているダミーコードに入れ替えておいて下さい。
6. おわりに
今回は「Jetson 触ってみたい」という理由で音声コマンド認識に挑戦してみましたが、やはり物理的にモノがあると意外と手間がかかちゃいましたね。 次回は少し前の話になりますが、BERT の事前学習を改良した ELECTRA を試してみたいと思います。 BERT の改良は他にもいろいろあるのですが、 Google のコードなので、Colab の無料 TPU で回せそうですし。いつも文章分類ばっかりなので、固有表現抽出などやってみます。
-
「手元に JetBot なんてないよ!」という方の為にストリーミングでのコマンド認識そのものは Colab で動くようにしておきますのでご安心ください。 ↩
-
私もこの記事の内容は自宅の環境で動かして書いています。「公開するのは会社のFWの内側でなく Colab だから。。。」、とかいう話もあるかも知れませんが、GCS や各種 Google API へのアクセス権を乗っ取られたりするとえらいことなので。 ↩
-
WSS で接続したい理由があったはずなのですが失念してしまいました。単純に接続できなかったからなのですが、接続先が手元の環境だったか Colab だったか。。。 ↩
-
元々は先月の音声テキスト認識モデルをストリーミングで動作させるコードを切った貼ったしたモノなのでメソッド名とか不適切ですが見なかったことにしてください。。。 ↩
-
このコードは https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/02_Speech_Commands.ipynb から借用しています。Apache 2.0 ライセンスですね。 ↩
-
https://www.ogis-ri.co.jp/otc/hiroba/technical/lets-try-jetbot/ 最近はキットも販売してるみたいです。 ↩
-
https://github.com/NVIDIA-AI-IOT/jetbot/wiki/Software-Setup ↩
-
筆者は balenaEtcher-Setup-1.5.101.exe を落として使いました。 ↩
-
https://forums.developer.nvidia.com/t/pytorch-for-jetson-nano-version-1-6-0-now-available/72048 ↩
-
一度に入れても大丈夫だと思うのですが、実際の作業は NeMo のインストール時にアレが足らない、コレが無いと怒られながらの作業だったので、その時の手順ママで表記しています。 ↩
-
しれっと gsutil でコピーした体で書いてますが、 JetBot で gsutils のセットアップをした記憶がありません。。。JetBot から Colab のノートブックを開いて、GCS からノートブックのランタイムにコピーして、ノートブックのダウンロード機能で落としてきたのかも。。。 ↩
-
nframeoverlap(=8000) という変数名や 8000*2 と2倍しているのが意味不明ですが、音声テキスト認識のサンプルを作ってから、コマンド認識に変更したので直しきれてないんですね。。。 ↩