修改ckpt_compare.py文件(内容如下面代码块所示)中的checkpoint_path1为异常的模型文件路径,checkpoint_path2为正常的模型文件路径后,在tensorflow 1.15环境中执行如下python ckpt_compare.py脚本后,会按照余弦相似度从低到高输出变量名和余弦相似度。
from tensorflow.python import pywrap_tensorflow import numpy as np checkpoint_path1 = "path1/model-200" checkpoint_path2 = "path2/model-200" reader1 = pywrap_tensorflow.NewCheckpointReader(checkpoint_path1) reader2 = pywrap_tensorflow.NewCheckpointReader(checkpoint_path2) var_to_shape_map = reader1.get_variable_to_shape_map() key_cos = {} for key in var_to_shape_map: tensor1 = reader1.get_tensor(key).reshape(-1) tensor2 = reader2.get_tensor(key).reshape(-1) key_cos[key] = np.dot(tensor1,tensor2)/(np.linalg.norm(tensor1)*np.linalg.norm(tensor2)) key_cos = list(key_cos.items()) key_cos.sort(key = lambda x : x[1]) for key, cos in key_cos: print(key, cos)