如何在数据库中保留Spark MLlib模型?

时间:2016-12-02 17:29:35

标签: apache-spark-mllib

我有一个MultilayerPerceptronClassificationModel设置和训练(与this教程中的方式相同)现在我想坚持下去,以便在下次需要对某些内容进行分类时重用神经网络数据。该模型具有loadsave方法,可在文件中保留和恢复。但有没有办法在数据库中保存(以及以后加载)模型? (就我而言,它是CassandraDB)。

1 个答案:

答案 0 :(得分:1)

好的,我自己找到了答案。不确定这是最好的解决方案,但它对我来说很好。

MultilayerPerceptronClassificationModel(据我所见,MLlib包的每个模型都实现了Serializable接口。因此可以将其序列化/反序列化为ByteArray

让我们制作一个表格,用于在Cassandra DB中存储模型:

CREATE TABLE models (
  uid TEXT,
  name TEXT,
  model BLOB,

  PRIMARY KEY (uid)
);

现在我们可以将模型写入DB:

def saveModel(model: MultilayerPerceptronClassificationModel) = {
  val baos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(baos)

  oos.writeObject(model)
  oos.flush()
  oos.close()

  sc.parallelize(Seq((model.uid, "my-neural-network-model", baos.toByteArray)))
    .saveToCassandra("mykeyspace", "models", SomeColumns("uid", "name", "model"))
}

并阅读模型:

def loadModel(): MultilayerPerceptronClassificationModel = {
  sc.cassandraTable("mykeyspace", "models").map { r =>
    val bis = new ByteArrayInputStream(r.getBytes("model").array())
    val ois = new ObjectInputStream(bis)

    ois.readObject.asInstanceOf[MultilayerPerceptronClassificationModel]
  }.first()
}