TOP PAGE

Unityで行うシンプルなディープニューラルネットワーク

ディープニューラルネットワーク(DNN:Deep Neural Network)の基本として、 2層ネットワークを定義し、それで MNIST(機械学習で便利な手書き画像データセット)を利用した学習実験を行える作品です。
(作成に当たって、書籍:ゼロから作るDeep Learning(978-4-87311-758-4)のコードを参考にしました。)
なお、バッチ学習には対応していません。
バッチ学習で適切なライブラリを使うと高速でメモリ効率向上を期待できるらしいのですが、 一般的なライブラリを使わずに全部自作なので、単純にできることを重視しました。

以下が実行画面のイメージで、@からCのボタン操作で素材の取得から学習、予測まで簡単に確認できます。
次のように操作できます。
過去に学習したデータに対し、それに続けて学習する場合は、起動直後にCボタンの操作を行い、それからBボタンの操作で学習を継続できます。
(起動直後で、先にBボタンの操作を行うと、学習済みデータは消えて、最初からの学習になります)

Unityプロジェクトとソース(パッケージ)

以下のように作りました。

このプロジェクトを実現しているソースコードは、Canvasにアタッチする後述の TrrainCanvas.cs と、下記のライブラリ的な2つのファイルで出来ています。 TrrainCanvas.csのコードを以下に示します。
MNISTから6万個の画像を入手していますが、その中で下記では先頭から(train_size = 500)個だけで学習しています。
よって、Dの添え字入力で500〜59999の入力範囲は、学習に使われなかった画像の予測判定になります。
ある程度正解しうる予測値にするためには、Learnボタンで,5000回以上の学習が必要でしょう。
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using System;
using System.Net;
using System.IO.Compression;
using System.IO;
using UnityEngine.UI;// RawImag用
using TMPro;

public class TrrainCanvas : MonoBehaviour
{
    [SerializeField] RawImage rawImage; // 検証画像の表示用■■■■■
    public int width = 28;
    public int height = 28;
    private Texture2D texture;// 上記rawImage用テクスチャ
    private Color32[] pixels;  // 上記textureの更新用

    NumArray x_train = null; // [1,60000,28*28]の手書き画像
    NumArray t_train = null; // One-Hot ベクトル表現で、学習目標

    [SerializeField] TMP_InputField inputIdxImg;//using TMPro;■■■■■
    private int idxImg = 0;// 上記描画画面を指定する添え字
    private int idxImgPrev = -1;// 上記描画画面を指定する添え字


    [SerializeField] TextMeshProUGUI textTarget;//using TMPro ■■■■■
    [SerializeField] TextMeshProUGUI textPredict;//using TMPro ■■■■■

    string x_train_file = "x_train.bin";// 学習や検証素材を、前処理してシリアラスした画像ファイル(28*28の6万個)
    string t_train_file = "t_train_a.bin";// 上記の正解ファイルで、One-Hot ベクトル表現に前処理してファイル

    TwoLayerNet twoLayerNet = null;
    static Color32[] defualtPixels = null;

    int train_size = 500;//  60000 個画像の先頭から、何個までを訓練用に使うかのサイズ
    int iters_num = 0; // 勾配法の算出繰り返しの回数指定変数(Learnボタンで設定)
    int iters_cnt = 0;//上記の学習カウント(iters_num>iters_cnt以外が学習中)
    float train_next_time=0;

    static Color32[] GetDefualPixels(int width, int height)// RawImageに初期描画する模様生成
    {
        // 画素配列を初期化
        Color32[] pixels = new Color32[width * height];
        for (int y = 0; y < height; y++)
        {
            for (int x = 0; x < width; x++)
            {
                if ((x + y) % 2 == 0)
                {
                    pixels[y * width + x] = new Color32(255, 255, 255, 255);// Color.white;
                }
                else
                {
                    pixels[y * width + x] = Color.black;// 市松模様
                }
            }
        }
        return pixels;
    }
    void Start()
    {
        UnityEngine.Random.InitState(123);//乱数シード指定(完成後にコメント化の予定)

        // テクスチャを生成
        texture = new Texture2D(width, height, TextureFormat.RGBA32, false);
        texture.filterMode = FilterMode.Point; // ドット絵のように表示する場合
        rawImage.texture = texture;

        defualtPixels = GetDefualPixels(width , height);// 画素配列を初期化
    }

    void FixedUpdate()
    {
        if (this.Learning())//  学習中
        {
            return;
        }
        string strInputIdx = this.inputIdxImg.text;// 6万個の画像のオペレータの入力位置
        try
        {
            this.idxImg = int.Parse(strInputIdx);// 入力文字列から添え字取得
            this.textTarget.text = "One-hot representation of answer";// 正解表示部の初期表
            if (this.x_train != null)// 入力添え字の画像と正解のワンホットの表現の文字列表示
            {
                this.pixels = x_train.getPixels(width, height, idxImg);// 位置から画像を取得
                this.textTarget.text = t_train.getOneOfKstring(idxImg);
            } else
            {
                this.pixels = defualtPixels;// 画素配列を初期化
            }

            if( this.twoLayerNet != null)// // 入力添え字の画像の予測判定結果表示
            {
                if (this.idxImg == this.idxImgPrev) return;
                this.idxImgPrev = this.idxImg;

                NumArray naX = this.x_train.createLineAt(idxImg);

                NumArray naY = twoLayerNet.predict(naX);
                Debug.Log($"naY.shape:{naY.shapeR}");
                naY.shapeR[1] = 1;
                this.textPredict.text = $"{naY.getOneOfKstring(0)}";// ワンホット表現の文字列取得

            } else
            {
                this.textPredict.text = "Prediction results"; // 予測確率分布表示部の初期表
            }
        }
        catch (Exception e) {
            this.pixels = defualtPixels;// 画素配列を初期化
            Debug.Log(e); 
        }

        texture.SetPixels32(pixels);// テクスチャに画素群(pixels)を設定して適用
        texture.Apply();
    }

    public static void MNIST_downloader()
    {
        string url = "https://storage.googleapis.com/cvdf-datasets/mnist/"; // MNIST データベース の URL

        System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch(); // ダウインロード時間確認用
        sw.Start();
        Debug.Log("START: " + sw.Elapsed.TotalSeconds + " seconds");

        string[] download_files = { "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz" };

        foreach (string file in download_files)
        {
            string download_url = url + file;
            using (WebClient client = new WebClient())
            {
                client.DownloadFile(download_url, file);
                Debug.Log(download_url + "をダウインロードしました");
            }
        }
        sw.Stop();
        Debug.Log("END: " + sw.Elapsed.TotalSeconds + " seconds");
        Debug.Log("ダウンロード時間: " + sw.Elapsed.TotalSeconds + " seconds");
    }

    // 学習素材の前処理を行って、x_train_fileとt_train_fileファイルを生成し、
    // this.x_trainと、 this.t_trainを設定
    public void Pretreatment() // ダウンロードファイルの前処理
    {
        try
        {
            mk_x_train_bin();//手書き画像6万件の前処理
            mk_t_train_bin();//手書き画像6万件の答えをOne-Hot ベクトル表現へ前処理
        }
        catch (Exception e)
        {
            Debug.Log($"エラー: {e.Message}");
        }

    }

    void mk_x_train_bin()//手書き画像6万件の前処理
    {
        string filePath = "train-images-idx3-ubyte.gz"; // ファイルパス
        int imageWidth = 28;
        int imageHeight = 28;
        int imageCount = 60000;// この数の画像

        byte[] imagesBytes = null;// train-images-idx3-ubyte.gzの画像部のbyte列記憶域
        try
        {
            using (FileStream fileStream = new FileStream(filePath, FileMode.Open))
            using (GZipStream gzipStream = new GZipStream(fileStream, CompressionMode.Decompress))
            using (BinaryReader reader = new BinaryReader(gzipStream))
            {
                reader.ReadBytes(16);// メタデータを読み飛ばす (16 バイト)

                // 画像データを読み込む
                imagesBytes = reader.ReadBytes(imageWidth * imageHeight * imageCount);
            }
        }
        catch (Exception ex)
        {
            Debug.Log($"エラー: {ex.Message}");
        }

        // イメージデータを前処理して記憶
        this.x_train = NumArray.createByBytes(imagesBytes, 28 * 28);
        this.x_train.mul_scalar(1.0f / 255f);// 画像データ0〜255を、0〜1.0に正規化
        using (FileStream fileStream = new FileStream(x_train_file, FileMode.Create))
        {
            this.x_train.write(fileStream);//"x_train.bin"に前処理した画像(28*28の6万個)データをシリアライズ
        }
    }

    void mk_t_train_bin()//手書き画像6万件の答えをOne-Hot ベクトル表現へ前処理
    {
        string filePath = "train-labels-idx1-ubyte.gz"; // ファイルパス
        int imageCount = 60000;// この数の画像

        byte[] buf = null;// train-images-idx3-ubyte.gzの画像部のbyte列記憶域
        try
        {
            using (FileStream fileStream = new FileStream(filePath, FileMode.Open))
            using (GZipStream gzipStream = new GZipStream(fileStream, CompressionMode.Decompress))
            using (BinaryReader reader = new BinaryReader(gzipStream))
            {
                reader.ReadBytes(8);// メタデータを読み飛ばす (8 バイト)              
                buf = reader.ReadBytes(imageCount);// 画像の正解データを読み込む
            }
        }
        catch (Exception ex)
        {
            Debug.Log($"エラー: {ex.Message}");
        }

        // イメージデータを前処理して記憶
        this.t_train = NumArray.create(10, buf.Length);//One-Hot ベクトル表現
        for (int i = 0; i < buf.Length; i++)
        {
            this.t_train.FPAt(buf[i], i, 0) = 1.0f;
            //Debug.Log(this.t_train.getOneOfKstring(i));
        }
        using (FileStream fileStream = new FileStream(t_train_file, FileMode.Create))
        {
            this.t_train.write(fileStream);
        }
    }

    public void predict()
    {
        // 予測実行
        try
        {
            if (this.x_train == null)
            {
                this.x_train = new NumArray();
                using (FileStream fileStream = new FileStream(x_train_file, FileMode.Open))
                {
                    this.x_train.read(fileStream);
                }

                this.t_train = new NumArray();
                using (FileStream fileStream = new FileStream(t_train_file, FileMode.Open))
                {
                    this.t_train.read(fileStream);
                }
            }
            // 学習済みデータの読み込み
            if(this.twoLayerNet  == null) this.twoLayerNet = new TwoLayerNet("weight_bias_params_0.bin");
        }
        catch(Exception e)
        {
            this.textPredict.text = e.ToString();
        }
    }

    // 学習スタート実行(-W1,B1,  W2,B2 の重みバイアスを正解に近づける一回の計算)
    public void Learn()
    {
        if (this.iters_cnt < this.iters_num) return;// 既に学習中

        if (this.x_train == null)
        {
            this.x_train = new NumArray();
            using (FileStream fileStream = new FileStream(x_train_file, FileMode.Open))
            {
                this.x_train.read(fileStream);
            }

            this.t_train = new NumArray();
            using (FileStream fileStream = new FileStream(t_train_file, FileMode.Open))
            {
                this.t_train.read(fileStream);
            }
        }

        iters_num += 1000; // この数の学習を指示

        if(this.twoLayerNet == null)
        {
            this.twoLayerNet = new TwoLayerNet(28 * 28, 50, 10, 0.01f);
            //this.twoLayerNet = new TwoLayerNet("weight_bias_params_0.bin");
            this.iters_cnt = 0;
        }
    }

    bool Learning()// 学習の実行中(次のFixupdateタイミング近くまで、勾配降下法で解に近づける)
    {
        if (this.twoLayerNet == null) return false;
        if (this.iters_cnt == this.iters_num) return false;
        if (this.iters_cnt > this.iters_num)
        {
            this.iters_cnt = this.iters_num;// 学習の終了
            return false;
        }

        for ( ; ; )
        {
            if (this.iters_cnt >= iters_num) return false;// Learnボタンによる所定回数の学習が終わった。
            this.iters_cnt++;

            int idx = NumArray.random.Next(0, train_size);// 学習対象の画像取得

            NumArray x_batch = this.x_train.createLineAt(idx);
            x_batch.shapeR[1] = 1;
            NumArray t_batch = this.t_train.createLineAt(idx);
            t_batch.shapeR[1] = 1;
            //Debug.Log($"x_batch:{x_batch}");
            //Debug.Log($"t_batch:{t_batch}");

            this.twoLayerNet.gradient_and_set_W2_B2_W1_B1(x_batch, t_batch);// twoLayerNet内の gradsW2,B2,W1,B1の傾き設定

            if (this.train_next_time < Time.time) break;
        }
        string msg = $"count:{iters_cnt,7}, cross entropy loss:{this.twoLayerNet.loss,10:F6} ";
        Debug.Log(msg);
        this.textPredict.text = msg;
        this.train_next_time = Time.time + Time.fixedDeltaTime * 0.9f;
        return true;
    }

    private void OnApplicationQuit()
    {
        this.twoLayerNet.save_params("weight_bias_params_0.bin");// 学習パラメタの保存
    }
}