How Do I Restore the Model Training Parameters After Quantization Operators Are Inserted?

quant_add_ops, the list of quantization variables added through the quantize_model API, contains variables that do not exist in the original training checkpoint. As a result, restoring the model training parameters fails with an error reporting that these variables cannot be found. To avoid the error, remove the variables listed in quant_add_ops from the restore list before restoring the model parameters.
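The filtering step used in both cases that follow can be sketched as a small pure-Python helper. This is a minimal sketch, not part of the quantize_model API: the helper name `exclude_quant_vars`, the `FakeVariable` stand-in, and the variable names are hypothetical, and it assumes each variable exposes a TensorFlow-style `name` attribute such as `conv1/weights:0`. It matches variables by name instead of with `value not in quant_add_ops`, because under TensorFlow 2.x behavior the `==` operator on variables may be element-wise rather than an identity check, which can make the `in` test fail.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Mapping


@dataclass(frozen=True)
class FakeVariable:
    """Hypothetical stand-in for tf.Variable; only the name attribute is modeled."""
    name: str  # TensorFlow-style variable name, e.g. "conv1/weights:0"


def exclude_quant_vars(variables_dict: Mapping[str, Any],
                       quant_add_ops: Iterable[Any]) -> Dict[str, Any]:
    """Return a copy of variables_dict without the variables in quant_add_ops.

    Variables are matched by their ``name`` attribute, so the check never
    calls ``==`` on the variable objects themselves.
    """
    excluded_names = {var.name for var in quant_add_ops}
    return {key: var for key, var in variables_dict.items()
            if var.name not in excluded_names}


# Usage with made-up names (the quantization variable name is hypothetical).
weights = FakeVariable("conv1/weights:0")
quant_scale = FakeVariable("conv1/quant_scale:0")
variables_dict = {"conv1/weights": weights, "conv1/quant_scale": quant_scale}

params_need_load = exclude_quant_vars(variables_dict, [quant_scale])
print(sorted(params_need_load))  # ['conv1/weights']
```

Because the helper filters on the dictionary values, it works the same whether the keys are plain variable names or the moving-average names returned for shadow variables; with real TensorFlow variables, the result can be passed to `tf.train.Saver` like `params_need_load` in the steps that follow.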

  1. Restoration of shadow (moving-average) variables
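    import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TensorFlow 2.x)

    # moving_average_decay, quant_add_ops, sess and FLAGS.checkpoint are defined
    # earlier in the training/quantization script.

    # 1. Build variables_dict, which maps each checkpoint name (the shadow-variable
    #    name for moving-averaged variables) to the variable it is restored into.
    variables_ema = tf.train.ExponentialMovingAverage(moving_average_decay)
    variables_dict = variables_ema.variables_to_restore()

    # 2. Create params_need_load, the {name: variable} dictionary of variables to restore.
    params_need_load = dict()

    # 3. Keep every variable except the quantization variables in quant_add_ops,
    #    which do not exist in the checkpoint.
    for key, value in variables_dict.items():
        if value not in quant_add_ops:
            params_need_load[key] = value

    # 4. Restore only the filtered variables from the checkpoint.
    loader = tf.train.Saver(params_need_load)
    loader.restore(sess, FLAGS.checkpoint)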
    
  2. Restoration of non-shadow variables
    import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TensorFlow 2.x)

    # quant_add_ops, sess and FLAGS.checkpoint are defined earlier in the
    # training/quantization script.

    # 1. Build variables_dict, which maps each variable name without its ":0"
    #    output suffix (the name stored in the checkpoint) to the variable.
    variables_global = tf.global_variables()
    variables_dict = dict()
    for var in variables_global:
        variables_dict[var.name[:-2]] = var

    # 2. Create params_need_load, the {name: variable} dictionary of variables to restore.
    params_need_load = dict()

    # 3. Keep every variable except the quantization variables in quant_add_ops,
    #    which do not exist in the checkpoint.
    for key, value in variables_dict.items():
        if value not in quant_add_ops:
            params_need_load[key] = value

    # 4. Restore only the filtered variables from the checkpoint.
    loader = tf.train.Saver(params_need_load)
    loader.restore(sess, FLAGS.checkpoint)
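Before calling `loader.restore`, it can help to confirm that every key left in the restore list actually exists in the checkpoint. TensorFlow's `tf.train.list_variables(ckpt_dir_or_file)` returns `(name, shape)` pairs for a checkpoint, so its names can be compared with the keys of `params_need_load`. The sketch below is a hypothetical helper, not part of AMCT or TensorFlow: `missing_from_checkpoint` and all variable names are made up, and the checkpoint names are hard-coded so the example runs without TensorFlow.

```python
from typing import Any, Iterable, List, Mapping


def missing_from_checkpoint(params_need_load: Mapping[str, Any],
                            checkpoint_names: Iterable[str]) -> List[str]:
    """Return the restore-list keys that the checkpoint does not contain, sorted."""
    available = set(checkpoint_names)
    return sorted(key for key in params_need_load if key not in available)


# Hypothetical names as they might appear in a checkpoint saved before quantization
# (in practice: [name for name, _ in tf.train.list_variables(FLAGS.checkpoint)]).
ckpt_names = ["conv1/weights", "conv1/biases", "global_step"]

# Restore list that still contains a (hypothetical) quantization variable.
before = {"conv1/weights": None, "conv1/biases": None, "conv1/quant_scale": None}
# Restore list after the quant_add_ops variables have been removed.
after = {"conv1/weights": None, "conv1/biases": None}

print(missing_from_checkpoint(before, ckpt_names))  # ['conv1/quant_scale']
print(missing_from_checkpoint(after, ckpt_names))   # []
```

An empty result means the restore list no longer refers to any variable the checkpoint lacks, which is the condition the "variables cannot be found" error depends on.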