Tensorflow - 显示和手动修改学习模型的权重并导出以进一步重新学习

时间:2017-10-11 11:46:08

标签: python machine-learning tensorflow

我尝试用Tensorflow做的事情如下:

  1. 考虑我已经学习了神经网络文件:checkpoint,* .meta,* .data和* .index。
  2. 我想提取要显示或处理的学习值(权重,偏差等)到文件/其他工具以供进一步分析。
  3. 我想修改一些学习的值(例如,用0取代一些已经很小的权重,以简化计算)。
  4. 应将修改后的值加载回模型。
  5. 因此,我想获得相同的检查点,* .meta,* .data和* .index文件,但是有一些修改后的值(来自第4步)。
  6. 注意:用于生成初始模型的脚本未知。我在步骤1中的所有内容都是4个列出的文件。

    到目前为止我设法做的是提取图形定义并显示学习值(使用inspect_checkpoint.py)。我发现无法更改模型上的值并将其导出回* .data,* .meta,* .index和checkpoint的集合。通过API后,我没有看到这些操作的明显工具。它甚至可能吗? 致以最诚挚的问候和感谢您的支持!

1 个答案:

答案 0 :(得分:0)

在C ++中,您可以使用CheckpointReaderBundleWriter从/向检查点文件读取/写入张量:

BundleWriter writer(tensorflow::Env::Default(), "out.ckpt");                                                                                                                                        

TF_Status status;                                                                                                                                                                                        
tensorflow::checkpoint::CheckpointReader reader("in.ckpt", &status);

const auto& var_to_shape_map = reader.GetVariableToShapeMap();                                                                                                                                                                                                                                                                                 
for (const auto& elem : var_to_shape_map) {                                                                                                                                                              
  std::unique_ptr<Tensor> weights;                                                                                                                                                                       
  const string& key = elem.first;                                                                                                                                                                        
  reader.GetTensor(key, &weights, &status);  
  auto weights_flat = weights->flat<float>();
  for (int i = 0; i < weights->NumElements(); ++i) {
    // replace with 0 some weights that are already of small value
    if (weights_flat(i) < SMALL_VALUE_THRESHOLD) {
      weights_flat(i) = 0.f;
    }
  }
  writer.Add(key, *weights.get());                                                                                                                                                        
}
writer.Finish();

运行上述代码后,您将获得out.ckpt.dataout.ckpt.index。 您可以使用原始的* meta文件,因为我们仅修改了学习的权重值,并且元信息保持不变。