如何从scala中的苏打水中导出h2o模型作为MOJO,由EasyPredictModelWrapper

时间:2018-03-27 14:30:29

标签: scala h2o sparkling-water

我的目标是导出一个使用scala(使用苏打水)训练火花的h2o模型,这样我就可以在没有Spark的应用程序中导入它。

因此:

  • 使用scala(文档仅显示r和python的示例)
  • 导出使用苏打水(带火花的h2o)构建的模型
  • 在scala中导入模型(没有spark或h2o集群,只有hex-genmodel包)

因此我使用ModelSerializationSupport导出,MojoModel.load导入

val gbmParams = new GBMParameters()
gbmParams._train = train
gbmParams._response_column = "target"
gbmParams._ntrees = 5
gbmParams._valid = valid
gbmParams._nfolds = 3 
gbmParams._min_rows = 1
gbmParams._distribution = DistributionFamily.multinomial
val gbm = new GBM(gbmParams)
val gbmModel = gbm.trainModel.get
val mojoPath = "./model.zip"
ModelSerializationSupport.exportMOJOModel(gbmModel, new File(mojoPath).toURI, force = true)
val simpleModel = new EasyPredictModelWrapper(MojoModel.load(mojoPath))

失败
error in opening zip file
java.util.zip.ZipException: error in opening zip file
at java.util.zip.ZipFile.open(Native Method)
at java.util.zip.ZipFile.<init>(ZipFile.java:220)
at java.util.zip.ZipFile.<init>(ZipFile.java:150)
at java.util.zip.ZipFile.<init>(ZipFile.java:121)
at hex.genmodel.ZipfileMojoReaderBackend.<init>(ZipfileMojoReaderBackend.java:13)
at hex.genmodel.MojoModel.load(MojoModel.java:33)
...

似乎mojo导出器使用的格式与hex.genmodel中显示的格式相同(显然是一个拉链)

在h2o 2.1.23上运行(2.1.24在构建群集时失败,如https://0xdata.atlassian.net/browse/SW-776上所述)和spark 2.1

- 更新:

使用ModelSerializationSupport类加载它自己的导出也会失败并出现相同的异常:

ModelSerializationSupport.loadMOJOModel(new File(mojoPath).toURI)

H2OModel导出和加载
以H2OModel(因此使用苏打水)加载回来确实有效:

val h2oModelPath = "./model_h2o"
ModelSerializationSupport.exportH2OModel(gbmModel, new File(h2oModelPath).toURI, force = true)
val loadedModel: GBMModel = ModelSerializationSupport.loadH2OModel(new File(h2oModelPath).toURI)

H2OMOJOModel导出和加载
使用H2OMOJOModel将其加载回来(从H2OGBM的实现中复制):

val mojoModel = new H2OMOJOModel(ModelSerializationSupport.getMojoData(gbmModel))
mojoModel.write.overwrite.save(mojoPath)
H2OMOJOModel.load(mojoPath) 

使用MojoModel导入导出H2OGBM
尝试使用常规MojoModel导入失败但是:

val gbm = new H2OGBM(gbmParams)(h2oContext, myspark.sqlContext)
val gbmModel = gbm.trainModel(gbmParams)
val mojoPath = "./models.zip"
gbmModel.write.overwrite.save(mojoPath)
MojoModel.load(mojoPath)

有以下例外:

./models.zip/model.ini (No such file or directory)
java.io.FileNotFoundException: ./models.zip/model.ini (No such file or directory)

1 个答案:

答案 0 :(得分:0)

解决方案实际上是在getMojoModel } Model[_,_,_](接受Array[Byte]ModelSerializationSupport)上解释的

getMojoModel(Model[_,_,_])的实现使用一个字节数组来存储getMojoData(Model[_,_,_]),然后从该字节数组中读取它。

快速测试如下:

val config = new EasyPredictModelWrapper.Config()
config.setModel(ModelSerializationSupport.getMojoModel(gbmModel))
config.setConvertUnknownCategoricalLevelsToNa(true)
val easyPredictModelWrapper = new EasyPredictModelWrapper(config)

因此,现在我们可以自己重现它,但不使用ModelSerializationSupport类(因为它是苏打水的一部分)。

首先将mojo数据存储到文件中:

val path = java.nio.file.Files.createTempFile("model", ".mojo")
path.toFile.deleteOnExit()
path.toString
import java.io.FileOutputStream
val outputStream = new FileOutputStream(path.toFile)
try {
  gbmModel.getMojo.writeTo(outputStream
}
finally if (outputStream != null) outputStream.close()

然后读取字节(在另一个scala应用程序中):

val is = new FileInputStream(path.toFile)
val reader = MojoReaderBackendFactory.createReaderBackend(is, MojoReaderBackendFactory.CachingStrategy.MEMORY)
val mojoModel = ModelMojoReader.readFrom(reader)
val config = new EasyPredictModelWrapper.Config()
config.setModel(mojoModel)
config.setConvertUnknownCategoricalLevelsToNa(true)
val easyPredictModelWrapper = new EasyPredictModelWrapper(config)