簡単なDNNの基本として2層ネットワークを目指したライブラリ(towlayernet.py)のソースです。
以下ではTwoLayerNetクラスを定義しています。
ここでは、バッチ学習を考慮いていません。
(バッチで適切なライブラリを使うと、高速でメモリ効率の向上も期待できるらしいが、全部自作なので単純さを重要視しました。)
この実装では、Numpyクラスの代わりにNumArrayクラスを利用しています。
またUnityと連係するコードは、こちらで紹介しています。
using System.Collections;
using System.Collections.Generic;
using System.IO;
using UnityEngine;
public class TwoLayerNet
{
NumArray naW1 = new NumArray();// 1層目の重みパラメタ
NumArray naB1 = new NumArray();// 1層目のバイアスパラメタ
NumArray naW2 = new NumArray();// 2層目の重みパラメタ
NumArray naB2 = new NumArray();// 2層目のバイアスパラメタ
NumArray naA1 = null; // predictの計算過程で一時記憶し、gradientで利用
NumArray naZ1 = null; // predictの計算過程で一時記憶し、gradientで利用
NumArray gradsW2 = null;// gradientにる勾配情報の一次記憶用
NumArray gradsB2 = null;
NumArray gradsW1 = null;
NumArray gradsB1 = null;
public float loss;// 学習で進行状態(交差エントロピー誤差)
float learning_rate = 0.1f; // 学習率
// コンストラクタ
public TwoLayerNet(string path = "")
{
using (FileStream fileStream = new FileStream(path, FileMode.Open))
{
naW1.read(fileStream);// 学習済みの重み、バイアス情報を読み取る。
naB1.read(fileStream);
naW2.read(fileStream);
naB2.read(fileStream);
}
}
public TwoLayerNet(int input_size = 784, int hidden_size = 50, int output_size = 10, float weight_init_std = 0.01f)
{
this.naW1 = NumArray.createGaussian(hidden_size, input_size, 1, 0, weight_init_std);
this.naB1 = NumArray.createParam(0, hidden_size);
this.naW2 = NumArray.createGaussian(output_size, hidden_size, 1, 0, weight_init_std);
this.naB2 = NumArray.createParam(0, output_size);
}
// 予測メソッド x のグレー画像(28×28)の判定をソフトマックスを介して返す。
public NumArray predict(NumArray x)
{
//Debug.Log("x:" + x);
//Debug.Log($"naW1:{this.naW1}, naW1.shapeR{naW1.shapeR}");
//Debug.Log("naB1:" + this.naB1);
this.naA1 = NumArray.dot(x, this.naW1);
//Debug.Log("naA1:" + naA1);
this.naA1.add_matrix(this.naB1);
//Debug.Log("naA1:" + naA1);
this.naZ1 = naA1.sigmoid();
//Debug.Log("naZ1:" + naZ1);
NumArray naA2 = NumArray.dot(this.naZ1, this.naW2);
naA2.add_matrix(this.naB2);
//Debug.Log("naA2:" + naA2);
NumArray naY = NumArray.softmax(naA2);
return naY;// 確率分布を返す。
}
// 勾配降下法でシグモイド関数の勾配を求める時に利用するメソッド
NumArray sigmoid_grad(NumArray x)
{ // (1.0 - sigmoid(x)) * sigmoid(x)
NumArray sigmoid_x = x.sigmoid();
NumArray grad = sigmoid_x.deepcopy();
grad.mul_scalar(-1);
grad.add_scalar(1);
grad.mul_matrix(sigmoid_x);
return grad;
}
// x:入力データ, t:教師データ 損失との勾配を誤差逆伝搬法で求め、gradsW2,gradsB2,gradsW1,gradsB1に設定
public void gradient_and_set_W2_B2_W1_B1(NumArray x, NumArray t)
{
//Debug.Log($"x:{x}");
//Debug.Log($"t:{t}");
NumArray naY = predict(x); // 予測値を取得
//Debug.Log($"naY:{naY}, naY.shapeR:{naY.shapeR}");
this.loss = NumArray.cross_entropy_error(naY, t);// 交差エントロピー誤差取得
//Debug.Log($"-----cross_entropy_error:{loss}");
// -------W1,B1, W2,B2 の予測値から、正解に近づける勾配(gradsW2,B2,W1,B1)を算出---------
NumArray naDy = t.deepcopy();
naDy.mul_scalar(-1);
naDy.add_matrix(naY);
//Debug.Log($"naDy:{naDy}, naDy.shapeR:{naDy.shapeR}");
this.gradsW2 = NumArray.dot(this.naZ1.T(), naDy);
//Debug.Log($"gradsW2:{gradsW2}, gradsW2.shapeR:{gradsW2.shapeR}");
this.gradsB2 = NumArray.sum(naDy, 0);
//Debug.Log($"gradsB2:{gradsB2}, gradsB2.shapeR:{gradsB2.shapeR}");
NumArray da1 = NumArray.dot(naDy, this.naW2.T());
//Debug.Log($"da1:{da1}, da1.shapeR:{da1.shapeR}");
NumArray dz1 = this.sigmoid_grad(this.naA1);
dz1.mul_matrix(da1);
//Debug.Log($"dz1:{dz1}, dz1.shapeR:{dz1.shapeR}");
this.gradsW1 = NumArray.dot(x.T(), dz1);
//Debug.Log($"gradsW1:{gradsW1}, gradsW1.shapeR:{gradsW1.shapeR}");
this.gradsB1 = NumArray.sum(dz1, 0);
//Debug.Log($"gradsB1:{gradsB1}, gradsB1.shapeR:{gradsB1.shapeR}");
// -------W1,B1, W2,B2 の重みバイアスを正解に近づける更新----------------
this.gradsW1.mul_scalar( -learning_rate );
this.naW1.add_matrix(this.gradsW1);
// Debug.Log($"naW1:{naW1}, naW1.shapeR:{naW1.shapeR}");
this.gradsB1.mul_scalar(-learning_rate);
this.naB1.add_matrix(this.gradsB1);
// Debug.Log($"naB1:{naB1}, naB1.shapeR:{naB1.shapeR}");
this.gradsW2.mul_scalar(-learning_rate);
this.naW2.add_matrix(this.gradsW2);
// Debug.Log($"naW2:{naW2}, naW1.shapeR:{naW2.shapeR}");
this.gradsB2.mul_scalar(-learning_rate);
this.naB2.add_matrix(this.gradsB2);
// Debug.Log($"naB2:{naB2}, naB2.shapeR:{naB2.shapeR}");
//throw new System.Exception("強制停止");//デバック用
}
// W1,B1, W2,B2 の重みバイアスのパラメタをファイルにシリアライスする。
public void save_params(string path= "weight_bias_params_0.bin")
{
using (FileStream fileStream = new FileStream("test.bin", FileMode.Create))
{
naW1.write(fileStream);// 直列化(Serialize) して保存
naB1.write(fileStream);
naW2.write(fileStream);
naB2.write(fileStream);
}
}
}