TensorFlow: loading only part of a model's variables, with examples

  • 2020-11-25 07:21:30
  • OfStack

A TensorFlow model is saved with a tf.train.Saver object: create it with saver = tf.train.Saver(), then call saver.save() to write the checkpoint, as follows:


import tensorflow as tf

v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")
# With no arguments, the Saver manages every variable in the graph.
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # global_step=1 appends "-1" to the prefix: checkpoint/model_test-1
    saver.save(sess, "checkpoint/model_test", global_step=1)
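
To check which variables a checkpoint actually contains, the saved file can be inspected by name. The following is a minimal sketch, not part of the original example. It assumes TensorFlow 1.14+ or 2.x with the `tensorflow.compat.v1` API available, writes to a temporary directory instead of `checkpoint/`, and uses a hypothetical helper name `save_and_list`.

```python
import os
import tempfile

# Assumption: TF 1.14+ or TF 2.x, using the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def save_and_list(ckpt_dir):
    """Save v1 and v2, then return the checkpoint path and its (name, shape) entries."""
    tf.reset_default_graph()
    v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
    v2 = tf.Variable(tf.zeros([200]), name="v2")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver.save returns the full prefix, ending in "model_test-1".
        path = saver.save(sess, os.path.join(ckpt_dir, "model_test"), global_step=1)
    # list_variables reads the checkpoint file itself; no session is needed.
    return path, tf.train.list_variables(path)


ckpt_path, entries = save_and_list(tempfile.mkdtemp())
for name, shape in entries:
    print(name, shape)
```

The names listed here are the ones a Saver matches against when restoring, so they are worth checking before loading a checkpoint into a different graph.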

Once the model is saved, it can be loaded with saver.restore(). Restoring assigns the saved values to the variables, so the initializer does not need to be run:


import tensorflow as tf

v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
    # No initializer is run here; restore() sets the variable values.
    # init_op = tf.global_variables_initializer()
    # sess.run(init_op)
    saver.restore(sess, "checkpoint/model_test-1")
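
One way to confirm that restore() really replaces initialization is to save, rebuild the graph, restore, and compare values. The sketch below does this under a few assumptions that are not in the original: it uses `tensorflow.compat.v1` (TF 1.14+ or 2.x), a temporary directory, a hypothetical helper `build_graph`, and `tf.train.latest_checkpoint` to look up the newest checkpoint prefix instead of hard-coding it.

```python
import os
import tempfile

import numpy as np
# Assumption: TF 1.14+ or TF 2.x, using the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_graph():
    """Recreate the same two variables in a fresh default graph."""
    tf.reset_default_graph()
    v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
    v2 = tf.Variable(tf.zeros([200]), name="v2")
    return v1, v2


ckpt_dir = tempfile.mkdtemp()  # scratch directory for this sketch

# Save: initialize randomly, remember v1's values, write the checkpoint.
v1, v2 = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_v1 = sess.run(v1)
    saver.save(sess, os.path.join(ckpt_dir, "model_test"), global_step=1)

# Restore: new graph, new session, no initializer.
v1, v2 = build_graph()
saver = tf.train.Saver()
latest = tf.train.latest_checkpoint(ckpt_dir)  # path ending in "model_test-1"
with tf.Session() as sess:
    saver.restore(sess, latest)
    restored_v1 = sess.run(v1)

values_match = np.array_equal(saved_v1, restored_v1)
print(latest, values_match)
```

Because v1 is drawn from a random normal distribution, matching values after the restore indicate that they came from the checkpoint and not from a new initialization.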

When training a neural network, we sometimes need to load only some parameters from a pre-trained model and use them to initialize the current model. For example, if a CNN has 6 layers, we may want to initialize the parameters of its first 5 layers from an existing model. This can also be done with saver.restore().

We previously showed how to save only some variables by passing tf.train.Saver() the list of variables to save. Loading works the same way: pass the variables you want to initialize to a separate tf.train.Saver(), and that saver restores only those variables. In the following code, for example, saver1 initializes variable v1, and saver2 initializes variables v2 and v3:


import tensorflow as tf

v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")
v3 = tf.Variable(tf.zeros([100]), name="v3")
# Each saver manages only the variables it is given.
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2, v3])
with tf.Session() as sess:
    # Note: a checkpoint must contain every variable a saver restores. The
    # checkpoint saved earlier holds only v1 and v2, so restoring v3 needs a
    # checkpoint written from a graph that also defines v3.
    saver1.restore(sess, "checkpoint/model_test-1")
    saver2.restore(sess, "checkpoint/model_test-1")
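
When the new graph has variables that the checkpoint lacks (as with the extra layer in the 6-layer CNN case), one possible approach is to restore only the variables whose names appear in the checkpoint and initialize the rest normally. This is a sketch under stated assumptions, not the article's own method: it uses `tensorflow.compat.v1` (TF 1.14+ or 2.x), a temporary directory, and matches variables by `v.op.name`; the names `to_restore`, `to_init`, and `partial_saver` are illustrative.

```python
import os
import tempfile

import numpy as np
# Assumption: TF 1.14+ or TF 2.x, using the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()  # scratch directory for this sketch

# Step 1: write a checkpoint that contains only v1 and v2.
tf.reset_default_graph()
v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_v1 = sess.run(v1)
    ckpt_path = saver.save(sess, os.path.join(ckpt_dir, "model_test"), global_step=1)

# Step 2: build a new graph that adds v3, which the checkpoint does not have.
tf.reset_default_graph()
v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")
v3 = tf.Variable(tf.zeros([100]), name="v3")

# Split the graph's variables by whether their name exists in the checkpoint.
saved_names = {name for name, _ in tf.train.list_variables(ckpt_path)}
to_restore = [v for v in tf.global_variables() if v.op.name in saved_names]
to_init = [v for v in tf.global_variables() if v.op.name not in saved_names]

partial_saver = tf.train.Saver(to_restore)
with tf.Session() as sess:
    # Variables missing from the checkpoint get their normal initial values.
    sess.run(tf.variables_initializer(to_init))
    # Variables present in the checkpoint get their saved values.
    partial_saver.restore(sess, ckpt_path)
    restored_v1, fresh_v3 = sess.run([v1, v3])

restored_names = sorted(v.op.name for v in to_restore)
init_names = sorted(v.op.name for v in to_init)
print("restored:", restored_names, "initialized:", init_names)
```

Matching by name only works when the new graph gives the shared variables the same names and shapes as the saved model; a mismatch in shape still fails at restore time.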
