Python script to rename variables in a graph - TensorFlow
If you want to build a model that uses a pretrained TensorFlow model (for example, as a feature extractor), you might encounter the following situation: some variable names in the graph of your first model are the same as names in your second model. If you try to restore the weights of the first model, TensorFlow will attempt to update the wrong part of your final graph, which in most cases causes errors.
The usual fix is to rename the conflicting variables. For example:
with tf.variable_scope('generator'):
could become
with tf.variable_scope('generator_Model1'):
To rename the variables, use the script tf_rename.py of this project. Don't forget to update the code according to your needs. For the previous example, we get:
checkpoint_dir = 'checkpoint_dir'
replace_substr1 = 'generator'
replace_substr2 = 'generator_Model1'
prefix = ''
The script then updates each variable name accordingly.
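To make the effect of these settings concrete, here is a minimal sketch of a renaming rule. It assumes that replace_substr1 and replace_substr2 mean a plain substring replacement and that prefix is simply prepended to the result; the actual tf_rename.py may behave differently. The function rename_variable and the variable names are hypothetical and only serve as an illustration.

```python
def rename_variable(name, replace_from, replace_to, prefix=''):
    """Return the new name for one variable (illustrative rule only)."""
    new_name = name
    # Only substitute when a non-empty substring to replace is given.
    if replace_from:
        new_name = new_name.replace(replace_from, replace_to)
    # Prepend the optional prefix; '' leaves the name unchanged.
    return prefix + new_name


# Hypothetical variable names, for illustration only.
names = ['generator/conv1/weights', 'discriminator/conv1/weights']
renamed = [rename_variable(n, 'generator', 'generator_Model1') for n in names]
print(renamed)
```

Under this assumption, only names containing 'generator' change. Note that a plain substring replacement touches every occurrence in a name, and applying it a second time would rename 'generator_Model1' again, since that name still contains 'generator'.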
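The renamed weights can then be restored into the new scope (TensorFlow 1.x API):
import tensorflow as tf

# Collect only the variables that live under the renamed scope
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator_Model1')
weight_initiallizer = tf.train.Saver(var_list)
# Define the initialization operation
init_op = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # Initialize all variables before overwriting the pretrained ones
    sess.run(init_op)
    # Load the pretrained model (FLAGS.checkpoint is the checkpoint path)
    print('Loading weights from the pre-trained model')
    weight_initiallizer.restore(sess, FLAGS.checkpoint)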
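One detail worth keeping in mind when choosing the new scope name: the scope argument of tf.get_collection is matched against the start of each variable name with re.match, so a scope without special characters acts as a prefix filter. The sketch below imitates that filtering in plain Python; filter_by_scope and the variable names are hypothetical, and this is not TensorFlow's own code.

```python
import re


def filter_by_scope(names, scope):
    """Keep the names whose beginning matches the scope pattern."""
    regex = re.compile(scope)
    return [n for n in names if regex.match(n)]


# Hypothetical variable names, for illustration only.
names = [
    'generator/conv1/weights',
    'generator_Model1/conv1/weights',
    'discriminator/fc/bias',
]

loose = filter_by_scope(names, 'generator')          # prefix of both generators
exact = filter_by_scope(names, 'generator_Model1')   # only the renamed scope
strict = filter_by_scope(names, 'generator/')        # trailing slash excludes generator_Model1
print(loose)
print(exact)
print(strict)
```

With these names, the short scope 'generator' selects variables from both models, while 'generator_Model1' or a scope ending in a slash selects a single one.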