TensorFlow: loading a pre-trained model and saving model instances

  • 2020-11-25 07:20:56
  • OfStack

When using TensorFlow, we need to work with the model files produced by training. Sometimes we may want to take a model that someone else has trained and continue training on top of it. This requires knowing how to manipulate the model data. After reading this article, you will surely come away with something useful!

1 TensorFlow model files

The files saved in the checkpoint_dir directory are structured as follows:


|--checkpoint_dir
| |--checkpoint
| |--MyModel.meta
| |--MyModel.data-00000-of-00001
| |--MyModel.index

1.1 meta file

The MyModel.meta file holds the graph structure. The meta file is in pb (protocol buffer) format and contains variables, ops, collections, and so on.

1.2 ckpt file

The ckpt file is a binary file that holds all the variables: weights, biases, gradients, and so on. Before TensorFlow 0.11, everything was saved in a single .ckpt file. From 0.11 onward, it is saved through two files, such as:


MyModel.data-00000-of-00001
MyModel.index

1.3 checkpoint file

You can also see that the checkpoint_dir directory contains a checkpoint file, which is a text file recording the most recently saved checkpoint file and a list of the other checkpoint files. At inference time, you can modify this file to specify which model to use.
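Since the checkpoint file is plain text, it is easy to inspect or parse by hand. A minimal pure-Python sketch of extracting the most recent checkpoint name (the `model_checkpoint_path` key matches what TensorFlow writes into this file; the helper function name is our own):

```python
# Minimal sketch: parse the plain-text "checkpoint" file to find the
# most recently saved model prefix.
def latest_checkpoint_name(text):
    for line in text.splitlines():
        if line.startswith("model_checkpoint_path:"):
            # the value is quoted, e.g. model_checkpoint_path: "MyModel-1000"
            return line.split('"')[1]
    return None

example = (
    'model_checkpoint_path: "MyModel-1000"\n'
    'all_model_checkpoint_paths: "MyModel-500"\n'
    'all_model_checkpoint_paths: "MyModel-1000"\n'
)
print(latest_checkpoint_name(example))  # MyModel-1000
```

Editing the `model_checkpoint_path` value is exactly how you would point inference at an older checkpoint.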

2 Save the TensorFlow model

TensorFlow provides the tf.train.Saver class to save models. It is worth noting that in TensorFlow, variables live in a Session environment; that is, variable values only exist inside a Session. Therefore, to save the model, you need to pass in the session:


saver = tf.train.Saver()
saver.save(sess,"./checkpoint_dir/MyModel")

Look at a simple example:


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel')

After execution, the model files are created in the checkpoint_dir directory, as follows:


checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

In addition, if you want to save the model after 1000 iterations, you only need to set the global_step parameter:

saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)

The names of the saved model files have -1000 appended, as follows:


checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

In the actual training, we might save the model data once every 1000 iterations, but since the graph is constant, it is not necessary to save the graph every time. We can specify not to save the graph in the following way:


saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False)

Another practical option: if you want to keep a checkpoint every 2 hours and retain only the 5 most recent model files:


tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

Note: TensorFlow keeps only the last five model files by default. If you want to keep more, you can specify this with max_to_keep.
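To make the max_to_keep behavior concrete, here is a small pure-Python sketch (not TensorFlow's actual implementation) of the rotation it performs, keeping only the N most recent checkpoint prefixes:

```python
# Illustrative sketch of max_to_keep-style rotation: each new save is
# appended, and anything older than the last `max_to_keep` entries is dropped.
def rotate_checkpoints(kept, new_prefix, max_to_keep=5):
    kept = kept + [new_prefix]
    return kept[-max_to_keep:]

kept = []
for step in range(1000, 9000, 1000):  # simulate saving every 1000 steps
    kept = rotate_checkpoints(kept, "MyModel-%d" % step)
print(kept)  # only the 5 most recent prefixes survive
```

With keep_checkpoint_every_n_hours set, TensorFlow additionally keeps one checkpoint per interval outside this rotation.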

If we do not pass any arguments to tf.train.Saver(), all variables are saved by default. If you do not want to save all the variables but only a subset of them, you can specify the variables/collections to save: when creating the tf.train.Saver instance, pass the variables you want to save to Saver as a list or a dictionary:


import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)

3 Import the trained model

As we saw in Section 1, TensorFlow saves the graph and the variable data in separate files. Therefore, importing a model also takes two steps: constructing the network graph and loading the parameters.

3.1 Constructing the network graph

A clumsy approach is to write code by hand that builds a graph identical to the original model's. In fact, since we have already saved the graph, there is no need to write the graph-construction code again:


saver=tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')

This line of code loads the saved graph.

3.2 Loading parameters

More importantly, we need the previously trained model parameters (i.e. weights, biases, and so on). As mentioned in Section 2 of this article, variable values depend on a Session, so when loading the parameters, a Session must be constructed first:


with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

At this point, w1 and w2 have been restored into the graph and can be accessed:


import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
    print(sess.run('w1:0'))

After execution, the restored value of w1 (a random normal vector of shape [2]) is printed.

4 Using the restored model

Now that we know how to save and restore models, in many cases we want to use a model that has already been trained, for prediction, fine-tuning, or further training. At this point, we may need to obtain some intermediate tensors from the trained model via graph.get_tensor_by_name('w1:0'). Note that w1:0 is the name of the tensor.

Suppose we have a simple network model, the code is as follows:


import tensorflow as tf


w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias") 

# Define an op to restore later
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a Saver object to save all the variables
saver = tf.train.Saver()

# Execute the op by feeding in data
print(sess.run(w4,feed_dict ={w1:4,w2:8}))
# prints 24.0 ==> (w1+w2)*b1 = (4+8)*2

# Now save the model 
saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)

Next, we use the graph.get_tensor_by_name() method to work with the saved model:


import tensorflow as tf

sess = tf.Session()

# Load the graph structure
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
# Load the parameters
saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

# Retrieve the placeholders by name so we can feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Retrieve the op we want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# prints 60.0 ==> (13+17)*2

Note: when the model is saved, only the value of the variable is saved. The value in placeholder is not saved

If you don't just want to use the trained model but also want to add ops to it, or add layers and train the new model, here's a simple example of how to do it:


import tensorflow as tf

sess = tf.Session()
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Build a new op on top of the restored op
add_on_op = tf.multiply(op_to_restore, 2)

print(sess.run(add_on_op, feed_dict))
# prints 120.0 ==> (13+17)*2*2

If you only want to restore part of the graph and add new ops on top of it for fine-tuning, just use the graph.get_tensor_by_name() method to obtain the ops you need and build the new graph on top of them. Let's look at a simple example. Suppose we need to use the graph of a trained VGG network and modify the last layer so that the number of outputs is 2, in order to fine-tune on new data:


# Sketch: tensor names such as 'fc7:0' depend on how the VGG graph was saved
saver = tf.train.import_meta_graph('vgg.meta')
graph = tf.get_default_graph()

# Access the output of the last fully connected layer
fc7 = graph.get_tensor_by_name('fc7:0')
# Stop gradients from flowing into the layers below
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()

# Replace the last layer: 2 outputs for the new task
num_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now train the new layer on your own data
