Hi, I am trying to use a Spark KMeans model to predict cluster numbers. But when I register it and use it in SQL, it gives me a
java.lang.reflect.InvocationTargetException
def findCluster(s: String): Int = {
  model.predict(feautarize(s))
}
I am using the following query:
%sql select findCluster((text)) from tweets
The same function works if I call it directly:
findCluster("hello am vishnu")
Output: 1
Answer 0 (score: 2)
I cannot reproduce the problem using the code you have provided. Assuming model is an
org.apache.spark.mllib.clustering.KMeansModel
here is a step-by-step solution.
First, let's import the required libraries and set the RNG seed:
import scala.util.Random
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
Random.setSeed(0L)
Generate a random training set:
// Generate random training set
val trainData = sc.parallelize((1 to 1000).map { _ =>
  val off = if (Random.nextFloat > 0.5) 0.5 else -0.5
  Vectors.dense(Random.nextFloat + off, Random.nextFloat + off)
})
Run KMeans:
// Train KMeans with 2 clusters
val numClusters = 2
val numIterations = 20
val clusters = KMeans.train(trainData, numClusters, numIterations)
Create the UDF:
// Create broadcast variable with model and prediction function
val model = sc.broadcast(clusters)
def findCluster(v: org.apache.spark.mllib.linalg.Vector): Int = {
  model.value.predict(v)
}
// Register UDF
sqlContext.udf.register("findCluster", findCluster _)
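For intuition, a prediction from a trained k-means model is the index of the cluster center closest to the input point. The following is a minimal plain-Scala sketch of that idea, not MLlib's actual implementation. The names squaredDistance and nearestCenter and the two centers are hypothetical, chosen to sit roughly where the two generated blobs are.

```scala
// Squared Euclidean distance between two points of equal dimension
def squaredDistance(a: Array[Double], b: Array[Double]): Double =
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

// Index of the closest center: conceptually what a k-means prediction returns
def nearestCenter(centers: Seq[Array[Double]], point: Array[Double]): Int =
  centers.indices.minBy(i => squaredDistance(centers(i), point))

// Hypothetical centers (assumption): the training blobs are centered near (0, 0) and (1, 1)
val centers = Seq(Array(0.0, 0.0), Array(1.0, 1.0))

println(nearestCenter(centers, Array(0.9, 1.2)))   // prints 1
println(nearestCenter(centers, Array(-0.1, 0.2)))  // prints 0
```

In a real model, which blob ends up as cluster 0 or 1 depends on initialization, so the index itself carries no meaning beyond identifying a cluster.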
Prepare the test set:
// Create test set
case class Coord(v: org.apache.spark.mllib.linalg.Vector)
val testData = sqlContext.createDataFrame(sc.parallelize((1 to 100).map { _ =>
  val off = if (Random.nextFloat > 0.5) 0.5 else -0.5
  Coord(Vectors.dense(Random.nextFloat + off, Random.nextFloat + off))
}))
// Register test set df
testData.registerTempTable("testData")
// Check if it works
sqlContext.sql("SELECT findCluster(v) FROM testData").take(1)
Result:
res3: Array[org.apache.spark.sql.Row] = Array([1])
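The steps above take a Vector column, while the question passes a String. A rough sketch of how the same broadcast pattern might be adapted to text follows, assuming a Spark 1.x shell where sc and sqlContext are already defined. The question does not show its featurizer, so featurize here is a hypothetical stand-in built on HashingTF, and the toy corpus, the tweetsDemo table, and the findTextCluster name are likewise made up for illustration.

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vector

// Hypothetical featurizer (assumption): hash whitespace-separated tokens into 1000 buckets
val hashingTF = new HashingTF(1000)
def featurize(s: String): Vector =
  hashingTF.transform(s.toLowerCase.split("\\s+").toSeq)

// Toy corpus standing in for the real training text
val texts = sc.parallelize(Seq(
  "hello am vishnu", "hello spark", "kmeans cluster number", "spark kmeans model"))

// Train on the hashed vectors and broadcast the model, as in the answer
val textModel = sc.broadcast(KMeans.train(texts.map(featurize).cache(), 2, 20))

// Register a String => Int UDF that featurizes inside the function
sqlContext.udf.register("findTextCluster",
  (s: String) => textModel.value.predict(featurize(s)))

// Small demo table with a text column
sqlContext.createDataFrame(Seq(Tuple1("hello am vishnu"), Tuple1("spark kmeans")))
  .toDF("text")
  .registerTempTable("tweetsDemo")

val predictions = sqlContext.sql("SELECT findTextCluster(text) FROM tweetsDemo")
  .collect()
  .map(_.getInt(0))
```

The model has to be trained on vectors produced by the same featurizer, with the same dimension, that the UDF applies at query time.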