将Tensorflow Checkpoint文件更新为1.0

时间:2017-02-22 16:47:36

标签: python tensorflow

我有一个在Tensorflow r0.12中训练的模型,它使用SaverV2创建了检查点文件。我的模型是使用来自rnn_cell的{​​{1}}和rnn_cell.GRUCell的RNN。自从更改为1.0后,根据this answer

,此程序包已移至tensorflow.python.ops core_rnn_cell_impl

我从here运行了tensorflow.contrib.rnn.python.ops文件,将我的文件更新为新版本。但是,自更新以来,我的旧检查点文件不起作用。似乎新tf_update.py实现所需的某些变量不存在或名称不同。

示例错误(有132个这样的错误):

GRUCell

保存/加载完美无缺,直到更新。如何将旧的检查点文件更新为r1.0?

如果重要,我使用的是python2.7,当使用CUDA的仅CPU张量流或张量流时,会发生同样的错误。

1 个答案:

答案 0 :(得分:5)

没有简单的方法可以做到这一点......一种方法是使用get_variable_to_shape_map()

ID      Debit       Credit      A       B       C   
1       1000.00     900.00      0       0       1000.00     
2       450.00      425.00      0       450.00  0   
3       500.00      490.00      500.00  0       0   
4       600.00      599.00      600.00  0       0   
5       748.00      700.00      0       748.00  0   


Now if we sum the credit it will be = 3114,  
What I have to do here is whatever total credit I have it has to start from top (A+B+C) - 3114  
So It will make C = 0 and my new credit will be 3114-1000=2114,   
Then in my id=2 it will do the same thing (A+B+C) - 2114  
so now B will be 0 and my new credit will be 2114-450=1664

将为您提供已保存检查点中形状的变量名称列表。然后......创建一个从旧名称映射到新名称的字典,即

ID      Debit       Credit      A       B       C   
1       1000.00     900.00      0       0       0.00        
2       450.00      425.00      0       0.00    0   
3       500.00      490.00      0.00    0       0   
4       600.00      599.00      0.00    0       0   
5       748.00      700.00      0       184.00  0   

然后即时启动保护程序并恢复那些变量

  ckpt_reader = tf.train.NewCheckpointReader(filepath)
  ckpt_vars = ckpt_reader.get_variable_to_shape_map()
祝你好运,希望这会有所帮助。