我有一个在Python API中训练过的xgboost模型,名为my_fpd20.model
,现在我想在Scala中使用它执行预测操作,但是当我进行测试时,使用相同功能会得到不同的预测结果。(xgb版本为0.82)
代码如下:
import ml.dmlc.xgboost4j.scala.XGBoost
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.LabeledPoint
object MSFScoreCompute {
def main(args: Array[String]): Unit = {
val model = XGBoost.loadModel("/home/mixxbox/my_fpd20.model")
val t1 = (1007988895058497544L, 0.928856, ",,5436.0,559.0,2169.0,15267.0,360.0,3619.0,3508.0,1412.0,,,,118.0,4.0,,,,,,,511.0,648.0,312.0,57.0,4573.0,3116.0,3530.0,124.0,1774.0,521.0,,,,,625.0,124.0,,246.0,,,,,,,,,,,,,,,,,,,,,,608.0,8020.0,7246.0,6427.0,24328.0,22359.0,14074.0,66638.0,3636.0,,608.0,793.0,9183.0,,,,,2451.0,375.0,1127.0,4621.0,,,2.0,,,126.0,,1610.0,,,,,10469.0,,166.0,,2.0,,,1610.0,4869.0,7150.0,,,,,,9.0,,,4179.0,97.0,147.0,938.0,172.0,228.0,543.0,1226.0,115.0,90.0,163.0,199.0,4073.0,2860.0,598.0,469.0,,,172.0,67.0,18.0,,2725.0,6.0,296.0,435.0,,,,273.0,,,147.0,25.0,1397.0,103.0,,,,,52.0,,,,,,,50.0,159.0,134.0,420.0,143.0,64.0,36.0,228.0,,23.0,837.0,,,1650.0,3781.0,1019.0,40.0,116.0,186.0,13826.0,2783.0,,26.0,,,1394.0,,1056.0,135.0,,632.0,,,,87059.0,3821.0,6121.0,183069.0,284618.0,273332.0,292360.0,1068.0,1547.0,1206.0,32231.0,70824.0,70372.0,84594.0,59782.0,61287.0,64698.0,74435.0,471844.0,2602.0,225.0,14518.0,22737.0,24931.0,16938.0,9972.0,17280.0,14224.0,5.0,,,1379.0,1415.0,,2685.0,2977.0,71.0,16853.0,,,,,141.0,71803.0,88547.0,264464.0,196441.0,280080.0,187640.0,272245.0,259110.0,180342.0,248236.0,297921.0,1755.0,88677.0,90063.0,88413.0,58363.0,96684.0,94109.0,21.0,,13925.0,2829.0,16684.0,16453.0,22001.0,23828.0,13205.0,3314.0,14718.0,38169.0,,,7661.0,13241.0,17582.0,222.0,10880.0,,,5104.407676174497,12.4,12.4,7845.433933358866,4060.872788134977,6.76,,,,,,,,,,,1.0,1.0,8.0,189.0,173.0,146.0,70.0,14.0,25.0,38.0,17.0,10858.0,,,432.0,3698.0,3381.0,942.0,109.0,870.0,8.0,8.0,,,,,,,,2.0,4.0,4.0,3.0,1.0,1.0,5.0,8.0")
val t2 = (1002897541026550024L, 0.969927, ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0")
val t3 = (1005094818872823816L, 0.96601504, ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,5.0,,6.0,,,,,,,1.0,1.0,1.0,2.0,1.0,1.0,2.0,2.0")
val t4 = (1005094818872823816L, 0.96601504, ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0")
val t5 = (1005966203849544456L, 0.96515334, "13.0,577.0,371.0,121.0,132.0,3052.0,210.0,,882.0,1308.0,,,,683.0,2079.0,,,,,,,,,37.0,25.0,4035.0,225.0,530.0,15.0,,17.0,,,,,15.0,,,,,,,,,,,,,,555.0,164.0,1243.0,,,,524.0,,,,,,,,1411.0,,,7094.0,44568.0,,,,,,,,,,,,,109.0,,50.0,187.0,130.0,122.0,2950.0,1253.0,3326.0,316.0,232.0,,1407.0,15413.0,6595.0,881.0,25.0,57.0,8.0,34.0,72.0,1327.0,4110.0,4138.0,2346.0,,57.0,,414.0,92.0,,755.0,8451.0,15674.0,18366.0,1120.0,2486.0,17526.0,21293.0,49350.0,38514.0,28455.0,77143.0,50427.0,27724.0,53642.0,59878.0,1500.0,4926.0,809.0,21877.0,5977.0,20509.0,13501.0,1557.0,100799.0,25869.0,23.0,1500.0,4926.0,12115.0,3.0,10918.0,1744.0,2851.0,23947.0,16003.0,1.0,,,427.0,815.0,,,,633.0,188.0,43.0,,,78.0,15.0,52.0,63.0,,50.0,,,3679.0,285.0,912.0,4852.0,763.0,2193.0,3164.0,4977.0,2425.0,20147.0,304.0,872.0,,,52.0,60.0,104.0,225.0,641.0,,,,,,34043.0,11996.0,7274.0,32689.0,61913.0,49065.0,85510.0,6076.0,622.0,5298.0,12458.0,8957.0,16414.0,29491.0,31258.0,23791.0,24628.0,34797.0,162136.0,1881.0,516.0,4343.0,2442.0,5408.0,2173.0,7204.0,883.0,9821.0,58.0,,,885.0,995.0,431.0,581.0,1566.0,736.0,10097.0,,,,,,104805.0,138960.0,151672.0,94589.0,129486.0,149334.0,213810.0,180026.0,130354.0,152028.0,197758.0,422.0,64783.0,75352.0,57723.0,33850.0,83382.0,57379.0,9822.0,1141.0,12815.0,857.0,886.0,15905.0,9014.0,11621.0,14919.0,1661.0,9428.0,3358.0,50.0,3.0,1112.0,2501.0,4630.0,300.0,1413.0,,,4126.941409266409,4.87,4.87,4740.780104712042,3864.2336822246352,4.62,,,,,,,,,,,2.0,5.0,5.0,23.0,15.0,25.0,6.0,,,10.0,3.0,285.0,13.0,9.0,232.0,1504.0,820.0,297.0,296.0,461.0,9.0,8.0,,,1.0,,,2.0,,2.0,4.0,5.0,5.0,1.0,1.0,6.0,10.0")
val t6 = (1003746555179569416L, 0.97295755, ",90.0,,293.0,148.0,2043.0,56.0,70.0,113.0,482.0,,,,160.0,2460.0,354.0,,,2509.0,29.0,71.0,816.0,255.0,885.0,1072.0,5563.0,4042.0,2431.0,268.0,318.0,653.0,591.0,77.0,151.0,710.0,2675.0,2.0,435.0,157.0,,,,,,,,,,,,,,,,,,,,,,,,1304.0,2147.0,4777.0,5847.0,38116.0,168435.0,,,,513.0,,,,,,,,,1457.0,,1191.0,68.0,122.0,338.0,1893.0,216.0,267.0,972.0,1014.0,119.0,4.0,7607.0,3772.0,,1191.0,,68.0,46.0,189.0,663.0,4619.0,2320.0,668.0,,,,,,85.0,511.0,28238.0,419.0,8897.0,12713.0,35535.0,13884.0,19754.0,16271.0,33518.0,12715.0,34050.0,34536.0,36717.0,27203.0,45865.0,7030.0,10168.0,1305.0,410.0,4968.0,335.0,28300.0,3720.0,13834.0,26539.0,29.0,1967.0,,1503.0,53.0,219.0,1952.0,,9862.0,4314.0,,,81.0,,194.0,,,,,,,,,260.0,416.0,416.0,1086.0,425.0,,,,3082.0,212.0,427.0,2387.0,1776.0,3846.0,9103.0,8933.0,6526.0,13894.0,1150.0,160.0,,,,,,294.0,,15.0,174.0,,,,34210.0,4986.0,1217.0,26738.0,63735.0,62730.0,187933.0,239.0,4590.0,157.0,6212.0,12547.0,23008.0,19881.0,19423.0,17632.0,28744.0,61775.0,172619.0,172.0,,2211.0,3989.0,2641.0,3405.0,7307.0,17613.0,3274.0,1594.0,,,1191.0,1027.0,3030.0,1215.0,653.0,2677.0,6224.0,,,,321.0,326.0,42824.0,61799.0,119313.0,82310.0,150184.0,83475.0,148886.0,131373.0,96711.0,155795.0,189394.0,3944.0,47375.0,50923.0,42097.0,26121.0,71303.0,49468.0,,,9026.0,3180.0,2409.0,4598.0,10544.0,14670.0,6743.0,4249.0,2354.0,14226.0,,,320.0,645.0,1075.0,100.0,38.0,,,4311.241637010675,11.18,11.18,5006.421052631579,2395.94753248642,5.95,,,,,,,5.0,4.0,4.0,4.0,3.0,,382.0,890.0,768.0,147.0,358.0,13.0,101.0,50.0,108.0,15215.0,4.0,3.0,317.0,2486.0,1954.0,400.0,554.0,79.0,6.0,,,,19.0,,,108.0,,3.0,5.0,4.0,4.0,3.0,3.0,6.0,5.0")
Array(t1, t2, t3, t4, t5, t6)
.map(row => {
val features = row._3.split(",").map(row => if ("".equals(row)) Double.NaN else row.toDouble).map(row => if (row.equals(0) || row.equals(-1)) Double.NaN else row)
val labelPoint = LabeledPoint(0, (0 to 326).toArray, features.map(row => row.toFloat))
val result = model.predict(new DMatrix(Iterator(labelPoint)), false, 1000)
Array(row._1.toString, row._2, 1.0 - result(0)(0), row._2 - (1 - result(0)(0))).mkString(",")
}).foreach(row => println("result:++++" + row))
}
}
结果如下,第一列为键,第二列为python预测结果,第三列为Scala预测结果,第四列为diff值。
result:++++1007988895058497544,0.928856,0.9201448783278465,0.008711144023895279
result:++++1002897541026550024,0.969927,0.8605302572250366,0.10939674277496336
result:++++1005094818872823816,0.96601504,0.8911455124616623,0.07486951263717656
result:++++1005094818872823816,0.96601504,0.8605302572250366,0.10548478277496343
result:++++1005966203849544456,0.96515334,0.9438987076282501,0.021254662174072236
result:++++1003746555179569416,0.97295755,0.9416626989841461,0.03129488081817622
我的问题是,我该怎么做才能获得一致的结果? (我想这可能是由于Scala和Python之间的NaN值不同所致,但是如何解决此问题?)
答案 0 :(得分:0)
我已经通过分配DMatrix的“缺失”值解决了这个问题,此“缺失”意味着所分配的值表示缺失值,其构造函数如下:
@throws(classOf[XGBoostError])
def this(data: Array[Float], nrow: Int, ncol: Int, missing: Float) {
this(new JDMatrix(data, nrow, ncol, missing))
}
例如
val ma = new DMatrix(my_features_array, 1, 327, Float.NaN)
val result = model.predict(ma, false)
作为代码,缺失值是NaN,表示如果要素数组中存在NaN,则这些NaN将被视为缺失。 现在的结果如下:
result:++++1007988895058497544,0.928856,0.9288560450077057,-1.5205383285810115E-8
result:++++1010594380924326152,0.96601504,0.9660150445997715,-8.744811541561148E-10
result:++++1002897541026550024,0.969927,0.9699269998818636,-1.292037965505699E-8
result:++++1005094818872823816,0.96601504,0.9660150445997715,-8.744811541561148E-10
result:++++1005966203849544456,0.96515334,0.9651533514261246,3.4750365918156945E-9
result:++++1003746555179569416,0.97295755,0.9729575514793396,-1.4793396507783996E-9