参考サイト: tensorflow

pythonメニュー

tensorflowの変数保存と復元

tensorflowの変数保存と復元

tensorflow.Variable用の変数が保存対象です。
以下では、counter変数で5回カウントして、最後で保存しています。
# -*- coding: utf-8 -*-
#<meta charaset="utf-8">

import numpy as np
import tensorflow as tf

counter=tf.Variable(0,name="counter") # 初期値ゼロの変数を作った。
count_up=counter+1
count_op=tf.assign(counter, count_up) # 「counter = counter+1」に相当する処理 

counter2=tf.Variable(9999,name="counter2") # 初期値9999の変数を作った。

saver = tf.train.Saver() # 保存の準備

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for n in range(5):
       print(  sess.run(counter2) , sess.run(count_op) )
       #saver.save(sess, "inc_models/model.ckpt") # 保存

    saver.save(sess, "inc_models/model.ckpt") # 保存
実行結果は次の通り
9999 1
9999 2
9999 3
9999 4
9999 5
上記のsaveによる保存で、tensorflowに使っているすべての変数を保存しています。
上記例では、counterとcounter2が記憶対象で、inc_modelsディレクト地内に次のファイルが生成される。
checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
上記コードで、最後行の保存(saver.save)をコメントにして、 代わりに繰り返し内のsaver.saveのコメントを外し、毎回保存させた。
その場合の最後に残るファイルは上記と同じで、上書き保存されているのが分かる。

上記の実行後、次のソースファイルで実行の継続を行えます。
# -*- coding: utf-8 -*-
#<meta charaset="uft-8">

# tasave1.py で行った続きの計算を行わせる目標
import numpy as np
import tensorflow as tf

counter=tf.Variable(0,name="counter") # 初期値ゼロの変数を作った。
count_up=counter+1
count_op=tf.assign(counter, count_up) # 「counter = counter+1」に相当する処理 

counter2=tf.Variable(0,name="counter2") # 初期値0の変数を作った。

saver = tf.train.Saver() # 保存の準備

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, "inc_models/model.ckpt") # これでファイル群から保存時の変数に復元している。 (上記と違うコードはこれだけ)
    for n in range(5):
       print(  sess.run(counter2) , sess.run(count_op) )
    
    saver.save(sess, "inc_models/model.ckpt")

継続処理なので、この実行前に、継続に必要なファイル群を作るプログラムを実行しておく必要があります。
前回のプログラム実行でファイル群を作り、それからこのプログラムを実行した結果を以下に示します。
9999 6
9999 7
9999 8
9999 9
9999 10
前回行ったデータの続きが、実行できています。
このプログラムの最後でtensorflowの変数を保存しているので、もう一度実行すると、続きができます。
9999 11
9999 12
9999 13
9999 14
9999 15
上記のsaver.restoreを使う場合、前回の実行結果が記憶される"inc_models/model.ckpt"のファイルがないと実行エラーです。

上記2つをまとめたプログラム

「inc_models」の保存ファイル用フォルダが無い状態で実行でき、次回はその継続ができるプログラムです。
# -*- coding: utf-8 -*-
#<meta charaset="uft-8">

import os
import numpy as np
import tensorflow as tf

counter=tf.Variable(0,name="counter") # 初期値ゼロの変数を作った。
count_up=counter+1
count_op=tf.assign(counter, count_up) # 「counter = counter+1」に相当する処理 

saver = tf.train.Saver()# 保存の準備

continueFlag = os.path.isfile("./inc_models/model.ckpt.index") # リストア用のファイルがあるか?
if not continueFlag :
   print("最初の実行です。");
else:
   print("前回の継続の実行です。");

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if continueFlag :
       saver.restore(sess, "inc_models/model.ckpt") # 前回の継続用のリストア
    
    for n in range(5):
       print( sess.run(count_op) )
    
    saver.save(sess, "inc_models/model.ckpt") # 次に継続できるように、保存する。

"./inc_models/model.ckpt.index"のファイルが無い状態で実行すると、 「最初の実行です。」の表示で、5回カウントアップして終了します。
この終了の直前に、tensorfowの変数を保存しています。
再び、実行する時は"./inc_models/model.ckpt.index"のファイルがあるので、 「前回の継続の実行です。」の表示をして5回カウントアップして終了します。

tensorflowの変数をglobal_stepの番号を指定して保存し、番号を指定して復元する。

前回と同じように、counterを記憶する例で、10回に1回保存します。
ローカル変数_global_stepで、50回のカウントが行われ、global_stepが10,20,30,40,50で変数の保存します。
以下のプログラムでは、コマンドラインで保存時に使ったglobal_stepの番号を指定すると、 その続きの演算ができるようにしました。
# -*- coding: utf-8 -*-
#<meta charaset="utf-8"> tsave3.py

import sys 
import os 
import numpy as np
import tensorflow as tf

start_global_step_value = 0 #この値で過去の指定した履歴より開始できるようにする。
args = sys.argv # コマンドライン引数取得
if len(args) >= 2: #コマンドライン引数がある?
    start_global_step_value=int(args[1]) # 引数を数値に変えて、復元時の番号に利用
else:
    print('最初の実行です。')

if start_global_step_value != 0:
   chptpath = "models/model.ckpt-{}.index".format( start_global_step_value )
   if not os.path.isfile( chptpath ) : 
      print(chptpath + "の対象の保存ファイルが見つかりません。")
      # 【E】
      sys.exit(1)
   else:
      print(chptpath[0:-6] + "の対象の保存ファイルからリストアして継続します。")

# 【A】
counter=tf.Variable(0,name="counter") # 初期値ゼロの変数を作った。
count_up=counter+1
count_op=tf.assign(counter, count_up) # 「counter = counter+1」に相当する処理 

saver = tf.train.Saver()

with tf.Session() as sess:
    _global_step = 0 # ローカル変数
    sess.run(tf.global_variables_initializer())
    if start_global_step_value != 0:
        saver.restore(sess, chptpath[0:-6] )
        _global_step = start_global_step_value
        # 【C】
    
    for _global_step in range(_global_step+1, _global_step + 51):
       v = sess.run(count_op)
       if (_global_step) % 10 == 0: # 10回ごとに表示と保存
          print(v, "save_step file: models/model.ckpt-{}.index".format(_global_step) )
          # 【B】 
          saver.save(sess, "models/model.ckpt" , global_step=_global_step)

上記をtsave3.pyで作成した場合の実行例を示します。
(my2) F:\python_ai\tensortest>python tsave3.py
最初の実行です。
10 save_step file: models/model.ckpt-10.index
20 save_step file: models/model.ckpt-20.index
30 save_step file: models/model.ckpt-30.index
40 save_step file: models/model.ckpt-40.index
50 save_step file: models/model.ckpt-50.index

(my2) F:\python_ai\tensortest>dir models
 ドライブ F のボリューム ラベルは WORK です
 ボリューム シリアル番号は 980C-420E です

 F:\python_ai\tensortest\models のディレクトリ

2018/12/21  00:15    <DIR>          .
2018/12/21  00:15    <DIR>          ..
2018/12/21  00:15               259 checkpoint
2018/12/21  00:15                 4 model.ckpt-10.data-00000-of-00001
2018/12/21  00:15               123 model.ckpt-10.index
2018/12/21  00:15             2,770 model.ckpt-10.meta
2018/12/21  00:15                 4 model.ckpt-20.data-00000-of-00001
2018/12/21  00:15               123 model.ckpt-20.index
2018/12/21  00:15             2,770 model.ckpt-20.meta
2018/12/21  00:15                 4 model.ckpt-30.data-00000-of-00001
2018/12/21  00:15               123 model.ckpt-30.index
2018/12/21  00:15             2,770 model.ckpt-30.meta
2018/12/21  00:15                 4 model.ckpt-40.data-00000-of-00001
2018/12/21  00:15               123 model.ckpt-40.index
2018/12/21  00:15             2,770 model.ckpt-40.meta
2018/12/21  00:15                 4 model.ckpt-50.data-00000-of-00001
2018/12/21  00:15               123 model.ckpt-50.index
2018/12/21  00:15             2,770 model.ckpt-50.meta
              16 個のファイル              14,744 バイト
               2 個のディレクトリ  40,500,297,728 バイトの空き領域

(my2) F:\python_ai\tensortest>python tsave3.py 30
models/model.ckpt-30の対象の保存ファイルからリストアして継続します。
40 save_step file: models/model.ckpt-40.index
50 save_step file: models/model.ckpt-50.index
60 save_step file: models/model.ckpt-60.index
70 save_step file: models/model.ckpt-70.index
80 save_step file: models/model.ckpt-80.index

(my2) F:\python_ai\tensortest>
上記の実行例で行った最後のコマンド操作「python tsave3.py 30」は、
「model.ckpt-30.index」にファイルより復元して、その続きを行わせるための表現です。

なお、保存時に指定するローカル変数「_global_step」をtensorflowの変数として保存する場合の追加コードを示します。
まず、【A】の箇所に次のコードを追加します。
global_step_var = tf.Variable(0, name='global_step') # ローカル変数を保存用に使うTensorflow変数 
global_step_holder = tf.placeholder(tf.int32)
global_step_op = global_step_var.assign(global_step_holder) # 変数への設定オペレーション
そして、【B】の箇所に次のコードを追加して「_global_step」をtensorflowの変数に設定します。
sess.run(global_step_op, feed_dict={global_step_holder: _global_step}) # ローカル変数をTensorflow変数に設定
最後に【C】の箇所に次のコードを追加して復元した変数をローカル変数に戻します。
_global_step = sess.run(global_step_var) # リストア時のglobal_step取得

なお存在しないチェックポイントの番号を指定した場合、終了するようになっているが、その時の表示情報を追加する。
【E】の箇所に、次のコードを追加します。(tf.train.get_checkpoint_stateで前回の保存情報を取得します。)
      checkpoint_state = tf.train.get_checkpoint_state("./models/") # 保存時のチェックポイント情報取得
      print("checkpoint_state-->>",checkpoint_state) # 最後に保存された checkpointファイル群情報
      print("type(checkpoint_state):",type(checkpoint_state))
      if checkpoint_state:
          print("最後の保存情報:", checkpoint_state.model_checkpoint_path)

この場合、前回の続きで、存在しないチェックポイントの番号を指定した例は次のようになります。
(my2) F:\python_ai\tensortest>python tsave3.py 99
models/model.ckpt-99.indexの対象の保存ファイルが見つかりません。
checkpoint_state-->> model_checkpoint_path: "./models/model.ckpt-80"
all_model_checkpoint_paths: "./models/model.ckpt-40"
all_model_checkpoint_paths: "./models/model.ckpt-50"
all_model_checkpoint_paths: "./models/model.ckpt-60"
all_model_checkpoint_paths: "./models/model.ckpt-70"
all_model_checkpoint_paths: "./models/model.ckpt-80"

type(checkpoint_state): <class 'tensorflow.python.training.checkpoint_state_pb2.CheckpointState'>
最後の保存情報: ./models/model.ckpt-80

(my2) F:\python_ai\tensortest>

global_step指定を、tf.app.flags.FLAGSと連動させて、使いやすくする検討

上記で global_step の使い方を示しましたが、これに tf.app.flags.FLAGSを連動させて、使い勝手の向上を図ります。

tf.app.flags.FLAGSとは?

TensorFlow の tf.app.flags.FLAGS は、ファイル実行時にパラメタを付与できるようにするものです
tf.app.flags.FLAGSを使うと、デフォルト値やヘルプ画面の説明文などのパラメタを作って記憶できます。
String型用tf.app.flags.DEFINE_strinメソッド、tf.app.flags.DEFINE_integerの例を示します。
(他に tf.app.flags.DEFINE_float、tf.app.flags.DEFINE_boolean等もあります)
パラメタというのは、「オブジェクトのメンバとなるパラメタ」です。
まず、実行の具体例で示します。 それは、次のように「python testFLAGS.py --help」と実行できて、使いか方が確認できて、パラメタが指定できるものです。
(my2) F:\python_ai\tensortest>python testFLAGS.py --help

       USAGE: testFLAGS.py [flags]
flags:

testFLAGS.py:
  --checkpoint_dir: 学習したデータを記憶するディレクトリ
    (default: './savedata')
  --train_numb: トレーニングデータ数
    (default: '100')
    (an integer)

Try --helpfull to get a list of all flags.

(my2) F:\python_ai\tensortest>
パラメタを与えた場合の実行例も、以下で示します。
(my2) F:\python_ai\tensortest>python testFLAGS.py --train_numb 200 --checkpoint_dir ../train
train_numb: 200
checkpoint_dir: ../train
appFLAGS: ['checkpoint_dir', 'h', 'help', 'helpfull', 'helpshort', 'train_numb']

(my2) F:\python_ai\tensortest>
このプログラムコード例を下記に示します。
# -*- coding: utf-8 -*-
#<meta charaset="uft-8"> testFLAGS.py

import tensorflow as tf
tf.app.flags.DEFINE_integer('train_numb', 100, """トレーニングデータ数""")
tf.app.flags.DEFINE_string('checkpoint_dir', './savedata', """学習したデータを記憶するディレクトリ""")
appFLAGS = tf.app.flags.FLAGS # 変数に参照させ扱いやすくしている

def main(argv):
    print("train_numb:", appFLAGS.train_numb)
    print("checkpoint_dir:", appFLAGS.checkpoint_dir)
    print("appFLAGS:" , dir(appFLAGS) )

if __name__ == '__main__':
    tf.app.run()

上記のコードで使っている「 tf.app.flags.DEFINE_string 」の一般的使い方の書式は次の通りです。
tf.app.flags.DEFINE_string('パラメタとなる変数名', 'デフォルト値', """説明文""") 

tf.app.flags.FLAGSを使ってglobal_stepを指定する例

引数がない場合の挙動を次のようにする。
デフォルト保存ディレクトリが無い場合は、新規カウントを行い、在る場合は前回の継続カウントを行う。

引数のglobal_stepがある場合の挙動を次のようにする。
デフォルト保存ディレクトリが無い場合は、新規カウントを行い、在る場合はcheckpointで指定する次からの継続を行う。
このコード例を下記に示します。
# -*- coding: utf-8 -*-
#<meta charaset="uft-8">
import os
import sys
import tensorflow as tf

tf.app.flags.DEFINE_string('checkpoint_dir', "./models/",  "学習したデータを記憶するディレクトリ")
tf.app.flags.DEFINE_integer('checkpoint', -1, "継続用 checkpoint 番号(0は新規実行、-1は前回の最後より継続")

counter=tf.Variable(0,name="counter") # 初期値ゼロの変数を作った。
count_up=counter+1
count_op=tf.assign(counter, count_up) # 「counter = counter+1」に相当する処理 

saver = tf.train.Saver()

def main(argv):
  appFLAGS = tf.app.flags.FLAGS 
  continueFlag=os.path.isdir( appFLAGS.checkpoint_dir ) # 過去に実行して、学習フォルダがある。
  checkpoint_state = tf.train.get_checkpoint_state(appFLAGS.checkpoint_dir) # 保存時のチェックポイント情報取得
  if checkpoint_state:
    print("前回保存した checkpointのファイル群情報", checkpoint_state)
    print("checkpoint.model_checkpoint_path:", checkpoint_state.model_checkpoint_path)

  with tf.Session() as sess:
    if not checkpoint_state or appFLAGS.checkpoint == 0: # 保存時のチェックポイント情報が無い?または引数が0の場合
        sess.run(tf.global_variables_initializer())
        print("新規に最初の実行から始めます。");
        _global_step = 0
    elif appFLAGS.checkpoint == -1:
        checkpoint_path = checkpoint_state.model_checkpoint_path;
        print(checkpoint_path,"からの次からの継続実行です。");
        saver.restore(sess, checkpoint_path) # 変数の復元
        _global_step = int( checkpoint_path[ checkpoint_path.rfind('-')+1:] ) # 文字列末尾の番号を得る。
    else:
        _global_step = appFLAGS.checkpoint # 引数より取得
        checkpoint_path="./models/model.ckpt-" + str(_global_step)
        print(checkpoint_path,"からの次からの継続実行です。");
        saver.restore(sess , checkpoint_path)
    print("--開始--")
    for _global_step in range(_global_step+1, _global_step + 51):
        v = sess.run(count_op)
        if (_global_step) % 10 == 0: # 10回ごとに表示と保存
            print(v, "save_step file: models/model.ckpt-{}.index".format(_global_step) )
            saver.save(sess, "models/model.ckpt" , global_step=_global_step)

if __name__ == '__main__':
    tf.app.run()


上記を「tsave4.py」のファイルで作った場合の実行例を示します。
(my2) F:\python_ai\tensortest>python tsave4.py
新規に最初の実行から始めます。
--開始--
10 save_step file: models/model.ckpt-10.index
20 save_step file: models/model.ckpt-20.index
30 save_step file: models/model.ckpt-30.index
40 save_step file: models/model.ckpt-40.index
50 save_step file: models/model.ckpt-50.index

(my2) F:\python_ai\tensortest>python tsave4.py
前回保存した checkpointのファイル群情報 model_checkpoint_path: "./models/model.ckpt-50"
all_model_checkpoint_paths: "./models/model.ckpt-10"
all_model_checkpoint_paths: "./models/model.ckpt-20"
all_model_checkpoint_paths: "./models/model.ckpt-30"
all_model_checkpoint_paths: "./models/model.ckpt-40"
all_model_checkpoint_paths: "./models/model.ckpt-50"

checkpoint.model_checkpoint_path: ./models/model.ckpt-50
./models/model.ckpt-50 からの次からの継続実行です。
--開始--
60 save_step file: models/model.ckpt-60.index
70 save_step file: models/model.ckpt-70.index
80 save_step file: models/model.ckpt-80.index
90 save_step file: models/model.ckpt-90.index
100 save_step file: models/model.ckpt-100.index

(my2) F:\python_ai\tensortest>python tsave4.py --help

       USAGE: tsave4.py [flags]
flags:

tsave4.py:
  --checkpoint: 継続用 checkpoint 番号(0は新規実行、-1は前回の最後より継続
    (default: '-1')
    (an integer)
  --checkpoint_dir: 学習したデータを記憶するディレクトリ
    (default: './models/')

Try --helpfull to get a list of all flags.

(my2) F:\python_ai\tensortest>python tsave4.py --checkpoint 80
前回保存した checkpointのファイル群情報 model_checkpoint_path: "./models/model.ckpt-100"
all_model_checkpoint_paths: "./models/model.ckpt-60"
all_model_checkpoint_paths: "./models/model.ckpt-70"
all_model_checkpoint_paths: "./models/model.ckpt-80"
all_model_checkpoint_paths: "./models/model.ckpt-90"
all_model_checkpoint_paths: "./models/model.ckpt-100"

checkpoint.model_checkpoint_path: ./models/model.ckpt-100
./models/model.ckpt-80 からの次からの継続実行です。
--開始--
90 save_step file: models/model.ckpt-90.index
100 save_step file: models/model.ckpt-100.index
110 save_step file: models/model.ckpt-110.index
120 save_step file: models/model.ckpt-120.index
130 save_step file: models/model.ckpt-130.index

(my2) F:\python_ai\tensortest>