2017
03.16

Tensorflowの学習モデルの利用について

Internship

インターン生2年の村山です。
インターンの仕事の一つである機械学習を使ってワカサギを自動で釣るというイベントが終わりました。

このイベントはワカサギを釣る釣り竿の先端に加速度センサーをつけ加速度の値から釣り竿が静止している時、魚が引いた際に発生する加速度(竿の揺れ具合)の判定を行い、自動で糸の巻き上げを行う仕組みです。
釣り竿の部分はRaspberryPiを、学習の部分は機械学習のライブラリであるTensorflowを使用しました。
RaspberryPiに加速度センサーや釣り竿をくっつけて釣りはおこなうのですが、RaspberryPi上で何万というサンプルデータ数から学習モデルの生成をするのは処理速度的に厳しいです。
したがって、別の高性能なコンピュータで、静止している時、引いてる時の加速度のサンプルデータから学習モデルを作っておき、Tensorflowにあるcheckpointという学習モデルのセーブデータのようなものをRaspberryPiの方で読み込んで新しいパラメータ(加速度の値)を渡し学習の判定結果を得るという形をとりました。
checkpointは作る際にいくつかのファイルが生成されます。

そして学習モデルを使用するにはこれらのファイルをダウンロードして一つのディレクトリにまとめて置いておく必要があります。
また、学習モデルの生成に使った時のtensorflowの変数も実行側で再度宣言する必要があります。

入力データを入れて結果を得たいだけなのに、一つでも上記の変数の値が間違っていたら動かず、また動かない時どこが間違っているのか探すのも大変です。
それに加えて学習モデル(checkpoint)の生成で、学習ステップ数を500から1000に変えた時、生成されるcheckpointのファイルを全て500のから1000に取り替えなくてはいけず、学習モデルの精度を変更するたび結構手間でした。

なので別の方法を探しました。
こう・・ライブラリの中で・・・ファイルは一つだけで・・入力データを渡すだけでシュッと判定の結果が返ってくるようなものを・・・
探してみると、pbファイルを作成して読み込む方法がありました。
この方法だと学習に使った変数の宣言もいらず学習モデルのデータがあるpbファイルを一つ渡し、それを読み込むだけで済むことがわかりました。

まずpbファイルとcheckpointファイルの違いですが、

checkpoint

公式によると、大まかにはtensorの変数をバイナリファイルに保存する。と書いてありました。
重みやバイアスの変数を保存することで学習トレーニングの途中で中断、再開が可能になるということですね。
使うにはsaver=tf.train.Saver()でオブジェクトを生成し、saver.save()で変数を保存します。

pbファイル

まずpbファイルについてですが、これはProtocolBufferというGoogleで開発されている通信や永続化を目的としたシリアライズフォーマットです。複雑なデータ構造をバイナリデータで効率よく表現することができます。
TensorflowにはGraphオブジェクトがあり、ここにはinput、outputなどの学習に使用したオペレーションを表すノードが保持されています。
tf.Graph()でGraphオブジェクトを生成し、as_graph_def()でGraphDefオブジェクトを生成して使います。
GraphDefはProtocolBufferで作られ、そのファイルからコードを生成し、読み込みや値の保存、グラフの操作を可能にします。
なのでpbファイルはこのGraphDefオブジェクトを使うために必要そうです。
またどちらも入力と出力のノードを他のノードと区別するために、nameオプションでinputとoutputとつける必要があります。

参考資料をもとに学習のトレーニングが終わった後、その時に得た重みとバイアスを保存する形をとりました。
W_2とb_2にはトレーニング後の重み(_W)とバイアス(_b)が入っており、新しい入力データの変数を格納するx_2とこのW_2とb_2で計算し、判定を行います。
g_2.as_defaultで書き込む宣言を、write_graphでpbファイルが作成されます。

作成したpbファイルの学習精度が正しいかのテストもしました

読み込むのは探すといくつか方法はあったのですが結果としてうまくいったのはFastGFileでpbファイルを読み込む方法でした。
ParseFromString()で読み込み、tf.import_graph_def()でそのGraphDefオブジェクトをデフォルトのグラフに設定します。
sess.runでグラフを作成する際に用意したinputとoutputという名前をつけた変数を呼び出します。
outputにはすでに学習済みの重みとバイアスが入っているので新しい入力データであるinputの値を入れるだけで結果が出てきます。

学習に使ったサンプルデータの一つからoutputの出力結果はこのようになりました

上の配列がoutputで返ってくる結果です。どのクラスが一番入力データから尤もらしいかの値が入っています。
配列で返ってくるので一番値が大きいインデックスからクラスの判定をします(0だったらA,1だったらB..という感じです)

個人的にはこちらのが変数やcheckpointのファイル作成などが必要なくて好みです。
しかし調べていたらどうやらグラフとcheckpointを一つにまとめるfreeze_graph.pyなるものがありました。こちらも試してみたいです。

Comment

  1. No comments yet.

  1. No trackbacks yet.

You must be logged in to post a comment.