我添加一个自定义的ops并使用ResourceMgr创建一个ResourceBase的子类来存储一些状态,以及如何将状态写入检查点文件。
class FeatureTransformMap : public ResourceBase {
public:
FeatureTransformMap(int32_t max_feature_count) : max_feature_count_(max_feature_count), feature_index_(0) {
cout<<"Max feature count is:"<<max_feature_count_<<endl;
}
string DebugString() {return "FeatureTransformMap";}
int32_t GetFeatureIndex(const string& feature) {
{
mutex_lock l(mu_);
feature_index_ += 1;
}
return feature_index_;
}
private:
tensorflow::mutex mu_;
uint32_t feature_index_ GUARDED_BY(mu_);
const uint32_t max_feature_count_;
};
如上面的代码所示,我如何将feature_index_写入检查点文件。
答案 0 :(得分:2)
没有保存tensorflow::ResourceBase
实例的通用方法,但您可以按如下方式实现自己的检查点支持:
定义FeatureTransformMap
类的方法,该方法将地图的状态序列化为一个或多个tensorflow::Tensor
对象,并对其进行反序列化。有关示例,请参阅MutableHashTableOfScalars::ExportValues()
和MutableHashTableOfScalars::ImportValues()
。
实现调用序列化和反序列化方法的新TensorFlow OpKernel
类。有关示例,请参阅LookupTableExportOp
和LookupTableImportOp
。
在Python中,为您的资源实现BaseSaverBuilder.SaveableObject
的子类,其中包括对新操作的调用。有关示例,请参阅MutableDenseHashTable._Saveable
。
在Python中,当您创建资源实例时,请将其添加到tf.GraphKeys.SAVEABLE_OBJECTS
的集合中。有关示例,请参阅MutableDenseHashTable
here。