超言理論

特に益もない日記である

PyBrainで学習したニューラルネットワークを保存しておきたい

最近はPyBrainで一通りのことができるようにいろいろ勉強してます。
今回はタイトル通りですね。


大量のデータからニューラルネットを学習したりすることもあると思います。
こういうとき、学習には時間がかかりますね。
これを毎回やっていては埒があきません。ということで、一度学習したニューラルネットの重みデータを外部に保存して、そしてそれを利用して再度ニューラルネットワークを構築する方法を書いていこうと思います。

まず、学習したニューラルネットを保存するためには、PickleモジュールというPythonのプログラムが必要です。
このモジュールはリストや辞書、その他の複雑な構造を持ったオブジェクトを書き出すことができるようになるプログラムで、特に難しい手続きなしに「開いているファイルオブジェクト」に「指定したオブジェクト」を保存できるようになります。
そして、どうやらこのPickleはデフォルトでPythonに入っているようなので、普通にプログラム内で呼び出せます。
ということで、実際に適当な学習を行うPyBrainのプログラムにPickleを使ったニューラルネットの外部保存を書き足してみます。

from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised import BackpropTrainer
import pickle
import random
import math

NN = buildNetwork(1, 10, 1)

DataSet = SupervisedDataSet(1, 1)
for i in range(0, 1000):
	x = random.random() * math.pi * 2
	DataSet.addSample((x), math.sin(x)+random.random()/10,))

trainer = BackpropTrainer(NN, DataSet)
for i in range(0, 1000):
    trainer.train()

file = open('Trained.Network', 'w')
pickle.dump(NN, file)
file.close()

やっていることは単純で、まず

import pickle

でpickleモジュールを読み出し、

トレーニング済みのニューラルネットワークのデータNNを

file = open('Trained.Network', 'w')
pickle.dump(NN, file)
file.close()

で、読み出したファイル'Trained.Network'に書き込み(ダンプ)しているだけ。
これで学習したデータはすべてファイルに書き出されます。

そして、この書き出したニューラルネットワークのデータを利用するときは、同様にPickleを使って読み出してあげれば良くて、こんな感じになります。

from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised import BackpropTrainer
import pickle
import random
import math

file = open('Trained.Network')
NN = pickle.load(file)
file.close()

TestSet = SupervisedDataSet(1, 1)
for i in range(0, 1000):
	x = random.random() * math.pi * 2
	TestSet.addSample((x), math.sin(x)+random.random()/10,))

Trainer = BackpropTrainer(NN, DataSet, verbose = True)
trainer.testOnData(TestSet, verbose = True)

こっちも単純で、さっき保存したニューラルネットの情報を読み出してきて、Pickleを使って復元しているだけ。

file = open('Trained.Network')
NN = pickle.load(file)
file.close()

Pythonって便利なモジュールそろってて何するにも楽ちんですね!


あ、あと、PyBrainの方にも公式のファイル保存APIがあるようなんですが、まだ公式ドキュメントの方は読み終わってないので、またわかったら書こうと思います。どっちでもそう変わらないとは思いますけど。


Copyright © 2012-2016 Masahiro MIZUKAMI All Rights Reserved.