TensorFlow Saver: saving and restoring specified tensors, explained with an example

  • 2020-11-18 06:21:39
  • OfStack

In practice, we often encounter situations like this:

1. Pre-train the parameters with a simple model.

2. Import the pre-trained parameters into a complex model, then train the complex model.

This raises a question:

How do we load the pre-trained parameters?

Here is my summary.

For illustration, assume that the simple model has only one convolutional layer and the complex model has two.

The convolutional layer is implemented as follows:


```python
import tensorflow as tf

# PS: The focus of this article is the Saver; the parameters are listed only to make the code easier to read.
# Parameters:
# name: building a convolutional layer takes a lot of code, so it is wrapped in a function,
#       and tf.name_scope is used to prevent variable-name conflicts
# input_data: input data
# width, high: width and height of the convolution window
# deep_before, deep_after: number of channels before and after the convolution
# stride: step size of the sliding convolution window
def make_conv(name, input_data, width, high, deep_before, deep_after, stride, padding_type='SAME'):
    global parameters
    with tf.name_scope(name) as scope:
        weights = tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
                                                  dtype=tf.float32, stddev=0.01),
                              trainable=True, name='weights')
        biases = tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
        conv = tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
        bias = tf.add(conv, biases)
        bias = batch_norm(bias, deep_after, 1)  # batch_norm is the author's own batch-norm function (not shown)
        conv = tf.maximum(0.1 * bias, bias)  # leaky ReLU with slope 0.1
        return conv
```
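Because make_conv wraps everything in tf.name_scope(name), each variable it creates is named '<scope>/<name>', and TensorFlow makes clashing names unique by appending _1, _2, and so on. The toy sketch below only mimics that naming rule in plain Python (ToyGraph and unique_name are hypothetical names, not TensorFlow API) to show where names such as 'simple-conv1/Variable_1' come from when the variables are printed. It assumes the four unnamed variables are created inside the author's batch_norm function.

```python
# Toy illustration of TensorFlow-style scoped, de-duplicated names.
# This is NOT the TensorFlow API; it only mimics the naming rule.


class ToyGraph:
    def __init__(self):
        self.used = set()  # full names already taken in this graph

    def unique_name(self, scope, name="Variable"):
        """Return '<scope>/<name>', adding _1, _2, ... if that name is taken."""
        base = f"{scope}/{name}"
        candidate, i = base, 0
        while candidate in self.used:
            i += 1
            candidate = f"{base}_{i}"
        self.used.add(candidate)
        return candidate


g = ToyGraph()
names = [g.unique_name("simple-conv1", "weights"),
         g.unique_name("simple-conv1", "biases")]
# Four variables created without an explicit name (assumed to come from batch_norm)
names += [g.unique_name("simple-conv1") for _ in range(4)]
for n in names:
    print(n)
```

Building the same layer under the scope 'complex-conv1' would yield 'complex-conv1/weights' instead, which is why a checkpoint written by the simple model does not line up with the complex model.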

The simple pre-training model consists of this single call:

```python
conv1 = make_conv('simple-conv1', images, 3, 3, 3, 32, 1)
```

The complex model has two convolutional layers:

```python
conv1 = make_conv('complex-conv1', images, 3, 3, 3, 32, 1)
pool1 = make_max_pool('layer1-pool1', conv1, 2, 2)  # make_max_pool is the author's pooling helper (not shown)
conv2 = make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)
```

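At this point, it does not work to simply save the pre-training model like this:

```python
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training steps for the simple model ...
    saver.save(sess, 'model.ckpt')
```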

This does not work, because:

1. If you print all the variables in the pre-training model with the following code:

```python
all_v = tf.global_variables()
for v in all_v:
    print(v)
```

you will see that the variables are not named weights and biases, but simple-conv1/weights and simple-conv1/biases, as follows:

```text
<tf.Variable 'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>
<tf.Variable 'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'simple-conv1/Variable:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,) dtype=float32_ref>
```

Similarly, in the complex model the names are complex-conv1/weights and complex-conv1/biases, so they do not match the names in the checkpoint.

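2. The pre-training model has only one convolutional layer, while the complex model has two. By default, a Saver created without arguments saves and restores every global variable in the graph (not only the trainable ones), so when the complex model restores from the model file ('model.ckpt') it also looks for the second layer's variables, cannot find them, and raises an error.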
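The two problems above can be illustrated with a toy model in which a checkpoint is just a Python dict from saved names to values. The helpers save and missing_names are hypothetical stand-ins, not TensorFlow functions; where this sketch merely lists the missing keys, a real Saver.restore typically raises a NotFoundError.

```python
def save(var_list):
    """Toy Saver.save: store {checkpoint_name: value} pairs."""
    return dict(var_list)


def missing_names(checkpoint, names):
    """Names a toy restore would look up but not find in the checkpoint."""
    return [n for n in names if n not in checkpoint]


# A default Saver in the pre-training model keys variables by their graph names.
ckpt = save({
    "simple-conv1/weights": "W_pretrained",
    "simple-conv1/biases": "b_pretrained",
})

# A default Saver in the complex model would look up all of its variables.
complex_names = [
    "complex-conv1/weights", "complex-conv1/biases",  # problem 1: scope prefix differs
    "complex-conv2/weights", "complex-conv2/biases",  # problem 2: never saved at all
]

print(missing_names(ckpt, complex_names))
```

Every name the complex model asks for is missing: the first layer because of the prefix, the second layer because it never existed in the simple model.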

Solutions:

1. In the pre-training model, define a global dictionary:

```python
parm_dict = {}
```

Add the following two lines above "return conv"


```python
parm_dict['complex-conv1/weights'] = weights
parm_dict['complex-conv1/biases'] = biases
```

Then define the Saver as follows:

```python
saver = tf.train.Saver(parm_dict)
```

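The checkpoint saved this way stores the simple model's variables under the names complex-conv1/weights and complex-conv1/biases, which are exactly the names the complex model uses.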
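tf.train.Saver accepts either a list of variables or a dict that maps checkpoint names to variables; the dict form is what lets the simple model write its variables under the complex model's names. The toy sketch below mirrors that renaming with plain Python values (save is a hypothetical stand-in for Saver.save, and the numbers are placeholders, not trained weights).

```python
def save(var_list):
    """Toy Saver.save: each key becomes the name stored in the checkpoint."""
    return {ckpt_name: value for ckpt_name, value in var_list.items()}


# Placeholder values held by the pre-training model's variables
simple_conv1_weights = [0.1, -0.2, 0.3]
simple_conv1_biases = [0.1, 0.1]

# Same idea as parm_dict: checkpoint name -> variable of the simple model
parm_dict = {
    "complex-conv1/weights": simple_conv1_weights,
    "complex-conv1/biases": simple_conv1_biases,
}

ckpt = save(parm_dict)
print(sorted(ckpt))
```

The simple model's graph name, simple-conv1/weights, never appears in the checkpoint; only the keys chosen in the dict do.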

2. In the complex model, define a global list:

```python
parameters = []
```

Add the following line above "return conv"


```python
parameters += [weights, biases]
```

Then add a check so that the second convolutional layer does not add its variables to parameters.

Then define the Saver as follows:

```python
saver = tf.train.Saver(parameters)
```

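This tells the Saver to look only for these weights and biases in the model file and to skip complex-conv1/Variable through complex-conv1/Variable_3 (the extra variables shown in the listing above, there under the simple-conv1 scope) as well as the second layer's variables.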
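The check described above (collect only the first layer's weights and biases) can be sketched in plain Python. RESTORE_SCOPES and register are hypothetical names invented for this illustration, variables are represented by their names only, and the checkpoint is again a toy dict rather than a real model.ckpt file.

```python
parameters = []  # global list, as in the complex model
RESTORE_SCOPES = {"complex-conv1"}  # hypothetical: layers taken from the checkpoint


def register(scope):
    """Mimic the line above `return conv`, guarded so only chosen layers are added."""
    if scope in RESTORE_SCOPES:
        parameters.extend([f"{scope}/weights", f"{scope}/biases"])


for scope in ["complex-conv1", "complex-conv2"]:
    register(scope)

# Toy checkpoint written by the pre-training model through parm_dict
ckpt = {"complex-conv1/weights": "W_pretrained", "complex-conv1/biases": "b_pretrained"}

# A Saver built from `parameters` only looks up these names, so nothing is missing
missing = [n for n in parameters if n not in ckpt]
print(parameters, missing)
```

In the real TF1 code, parameters holds tf.Variable objects, and a Saver built from a list uses each variable's op name (for example complex-conv1/weights, without the ':0' suffix) as its checkpoint key, which is why those names must match the keys written by parm_dict.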

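Finally, load the parameters with:

```python
with tf.Session() as sess:
    # Initialize everything first so complex-conv2 and the skipped variables get values
    sess.run(tf.global_variables_initializer())
    # Then overwrite complex-conv1/weights and complex-conv1/biases with the pre-trained values
    saver.restore(sess, 'model.ckpt')
```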
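Restoring writes only the variables the Saver was given, so every other variable (the second layer and the extra variables created inside make_conv) still needs its normal initializer. In TF1 code this usually means running tf.global_variables_initializer() before saver.restore(...), so the initializer does not overwrite the restored values. A toy sketch of that ordering, with hypothetical helpers initialize and restore_into standing in for the real calls:

```python
def initialize(names):
    """Toy global initializer: every variable gets a fresh value."""
    return {n: "fresh_init" for n in names}


def restore_into(state, checkpoint, var_list):
    """Toy Saver.restore: overwrite only the variables in var_list."""
    for n in var_list:
        state[n] = checkpoint[n]  # a missing name raises KeyError here
    return state


all_vars = ["complex-conv1/weights", "complex-conv1/biases",
            "complex-conv2/weights", "complex-conv2/biases"]
ckpt = {"complex-conv1/weights": "W_pretrained", "complex-conv1/biases": "b_pretrained"}
parameters = ["complex-conv1/weights", "complex-conv1/biases"]

state = initialize(all_vars)                   # 1. initialize everything
state = restore_into(state, ckpt, parameters)  # 2. then restore the pre-trained subset
print(state)
```

Reversing the two steps would let the initializer overwrite the pre-trained values of the first layer.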
