如何保存和恢复ResourceBase?

时间:2017-01-07 08:45:39

标签: tensorflow

我添加一个自定义的ops并使用ResourceMgr创建一个ResourceBase的子类来存储一些状态,以及如何将状态写入检查点文件。

class FeatureTransformMap : public ResourceBase {
public:
    FeatureTransformMap(int32_t max_feature_count) : max_feature_count_(max_feature_count), feature_index_(0) {
        cout<<"Max feature count is:"<<max_feature_count_<<endl;
    }
    string DebugString() {return "FeatureTransformMap";}

    int32_t GetFeatureIndex(const string& feature) {
        {
            mutex_lock l(mu_);
            feature_index_ += 1;
        }
        return feature_index_;
    }



private:
    tensorflow::mutex mu_;
    uint32_t feature_index_ GUARDED_BY(mu_);
    const uint32_t max_feature_count_;
};

如上面的代码所示,我如何将feature_index_写入检查点文件。

1 个答案:

答案 0 :(得分:2)

没有保存tensorflow::ResourceBase实例的通用方法,但您可以按如下方式实现自己的检查点支持:

  1. 定义FeatureTransformMap类的方法,该方法将地图的状态序列化为一个或多个tensorflow::Tensor对象,并对其进行反序列化。有关示例,请参阅MutableHashTableOfScalars::ExportValues()MutableHashTableOfScalars::ImportValues()

  2. 实现调用序列化和反序列化方法的新TensorFlow OpKernel类。有关示例,请参阅LookupTableExportOpLookupTableImportOp

  3. 在Python中,为您的资源实现BaseSaverBuilder.SaveableObject的子类,其中包括对新操作的调用。有关示例,请参阅MutableDenseHashTable._Saveable

  4. 在Python中,当您创建资源实例时,请将其添加到tf.GraphKeys.SAVEABLE_OBJECTS的集合中。有关示例,请参阅MutableDenseHashTable here

  5. 的执行方式