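I have a DataFrame that I want to use to make predictions with an existing model. I get an error when I call the model's transform method.
This is how I process the training data.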
My DataFrame schema:
forecast.printSchema()
root
|-- PM10: double (nullable = false)
|-- rain_3h: double (nullable = false)
|-- is_rain: double (nullable = false)
|-- wind_deg: double (nullable = false)
|-- wind_speed: double (nullable = false)
|-- humidity: double (nullable = false)
|-- is_newYear: double (nullable = false)
|-- season: double (nullable = false)
|-- is_rushHour: double (nullable = false)
|-- PM10_average: double (nullable = false)
The first rows:
forecast.show(5)
+----+-------+-------+--------+----------+--------+----------+------+-----------+------------+
|PM10|rain_3h|is_rain|wind_deg|wind_speed|humidity|is_newYear|season|is_rushHour|PM10_average|
+----+-------+-------+--------+----------+--------+----------+------+-----------+------------+
| 1.1| 1.0| 0.0| 15.0048| 7.27| 0.0| 0.0| 0.0| 0.0| 1.2|
| 1.1| 1.0| 0.0| 15.0048| 7.27| 0.0| 0.0| 0.0| 0.0| 1.2|
| 1.1| 1.0| 0.0| 15.0048| 7.27| 0.0| 0.0| 0.0| 0.0| 1.2|
| 1.1| 1.0| 0.0| 15.0048| 7.27| 0.0| 0.0| 0.0| 0.0| 1.2|
| 1.1| 1.0| 0.0| 15.0048| 7.27| 0.0| 0.0| 0.0| 0.0| 1.2|
+----+-------+-------+--------+----------+--------+----------+------+-----------+------------+
only showing top 5 rows
Preparing the features:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["rain_3h", "is_rain", "wind_deg", "wind_speed", "humidity", "is_newYear", "season", "is_rushHour", "PM10_average"],
    outputCol="features")
output = assembler.transform(forecast)
output.createOrReplaceTempView("output")  # registerTempTable is deprecated since Spark 2.0
features = spark.sql("SELECT features, PM10 AS label FROM output")
features.show(5)
+--------------------+-----+
| features|label|
+--------------------+-----+
|(9,[0,2,3,8],[1.0...| 1.1|
|(9,[0,2,3,8],[1.0...| 1.1|
|(9,[0,2,3,8],[1.0...| 1.1|
|(9,[0,2,3,8],[1.0...| 1.1|
|(9,[0,2,3,8],[1.0...| 1.1|
+--------------------+-----+
only showing top 5 rows
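The truncated entries such as (9,[0,2,3,8],[1.0... are Spark's sparse vector notation: the vector size, the indices of the non-zero entries, and then their values. The pure-Python sketch below is only an illustration (to_sparse is a hypothetical helper, not a Spark API); it shows how the first forecast row, taken in inputCols order, ends up with non-zeros at positions 0 (rain_3h), 2 (wind_deg), 3 (wind_speed) and 8 (PM10_average).

```python
# Hypothetical helper: mimics how a dense feature row is displayed in Spark's
# sparse notation (size, [indices of non-zeros], [non-zero values]).
def to_sparse(values):
    indices = [i for i, v in enumerate(values) if v != 0.0]
    return len(values), indices, [values[i] for i in indices]


# First forecast row, in the VectorAssembler inputCols order:
# rain_3h, is_rain, wind_deg, wind_speed, humidity,
# is_newYear, season, is_rushHour, PM10_average
row = [1.0, 0.0, 15.0048, 7.27, 0.0, 0.0, 0.0, 0.0, 1.2]
size, indices, values = to_sparse(row)
print(size, indices, values)  # 9 [0, 2, 3, 8] [1.0, 15.0048, 7.27, 1.2]
```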
Passing the data to the model:
from pyspark.ml import PipelineModel

model = PipelineModel.load(path)  # path of the saved pipeline
predict = model.transform(features)
predict.printSchema()
root
|-- features: vector (nullable = true)
|-- label: double (nullable = false)
|-- indexedFeatures: vector (nullable = true)
|-- prediction: double (nullable = true)
predict.show(5)
This results in the following error:
17/09/16 19:12:25 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/hdp/current/spark2-client/python/pyspark/sql/dataframe.py", line 287, in show
print(self._jdf.showString(n, truncate))
File "/usr/hdp/current/spark2-client/python/lib/py4j-0.10.3-src.zip/py4j/java_gateway.py", line 1133, in __call__
File "/usr/hdp/current/spark2-client/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/usr/hdp/current/spark2-client/python/lib/py4j-0.10.3-src.zip/py4j/protocol.py", line 319, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o235.showString.
: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$11: (vector) => vector)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply5_1$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.TakeOrderedAndProjectExec$$anonfun$executeCollect$1.apply(limit.scala:132)
at org.apache.spark.sql.execution.TakeOrderedAndProjectExec$$anonfun$executeCollect$1.apply(limit.scala:132)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:132)
at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2193)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2546)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2192)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2199)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:1935)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:1934)
at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2576)
at org.apache.spark.sql.Dataset.head(Dataset.scala:1934)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2149)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:239)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:280)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.util.NoSuchElementException: key not found: 1.0
at scala.collection.MapLike$class.default(MapLike.scala:228)
at scala.collection.AbstractMap.default(Map.scala:59)
at scala.collection.MapLike$class.apply(MapLike.scala:141)
at scala.collection.AbstractMap.apply(Map.scala:59)
at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:339)
at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:317)
at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:362)
at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:362)
... 33 more
Answer:
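This happens because your PipelineModel contains a VectorIndexerModel, and features contains a level in one of the columns marked as categorical that was not seen during fitting (here the value 1.0, hence "key not found: 1.0"). VectorIndexer treats a feature as categorical when it has at most maxCategories distinct values (20 by default) in the training data and stores a value-to-index map for it; a value missing from that map cannot be indexed. The error only shows up at predict.show(5), not at model.transform(features), because Spark evaluates transformations lazily. You can easily reproduce the same error as follows: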
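import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._ // for toDF; spark is the active SparkSession
val train = Seq((1L, Vectors.dense(0.0))).toDF("id", "foo") // only 0.0 is seen during fit
val test = Seq((1L, Vectors.dense(1.0))).toDF("id", "foo")  // 1.0 is an unseen level
new VectorIndexer().setInputCol("foo").setOutputCol("bar")
  .fit(train).transform(test).first // fails with key not found: 1.0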
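To make the mechanism concrete, here is a simplified pure-Python model of what a fitted VectorIndexer does. It is a sketch under stated assumptions, not Spark's implementation: fit_indexer, transform_row and find_unseen are hypothetical names, category indices are assigned in sorted order (Spark itself only guarantees that 0.0 maps to index 0), and the maxCategories rule is reduced to counting distinct values per feature. find_unseen shows one way to locate the feature index and value that would trigger the error, using the same data as the Scala reproduction above.

```python
# Hypothetical, simplified model of VectorIndexer (not Spark's implementation).
def fit_indexer(rows, max_categories=20):
    """Return {feature index: {value: category index}} for categorical features."""
    category_maps = {}
    for j in range(len(rows[0])):
        # Sorted order is a simplification; Spark only guarantees 0.0 -> index 0.
        distinct = sorted({row[j] for row in rows})
        # At most max_categories distinct values -> treated as categorical.
        if len(distinct) <= max_categories:
            category_maps[j] = {v: i for i, v in enumerate(distinct)}
    return category_maps


def transform_row(category_maps, row):
    # Categorical values are looked up in the fitted map; an unseen value raises
    # KeyError, the Python analogue of "NoSuchElementException: key not found".
    return [category_maps[j][v] if j in category_maps else v
            for j, v in enumerate(row)]


def find_unseen(category_maps, rows):
    """List (feature index, value) pairs that the fitted maps cannot index."""
    return sorted({(j, row[j]) for row in rows
                   for j, m in category_maps.items() if row[j] not in m})


train = [[0.0]]  # same data as the Scala reproduction
test = [[1.0]]
maps = fit_indexer(train)
print(maps)                     # {0: {0.0: 0}}
print(find_unseen(maps, test))  # [(0, 1.0)]
try:
    transform_row(maps, test[0])
except KeyError as e:
    print("key not found:", e.args[0])  # key not found: 1.0
```

Running the same comparison on a small sample of the real training and forecast feature vectors collected to the driver is one possible way to narrow down which column is responsible.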
As of today (Spark 2.2), Spark does not support handling unseen levels in VectorIndexer (as it does with StringIndexer), but this feature is planned for the future.
Edit:
In Spark 2.3 you can use handleInvalid, for example:
new VectorIndexer()
  .setInputCol("foo").setOutputCol("bar")
  .setHandleInvalid("keep") // Spark 2.3+: unseen values go to an extra category instead of failing
  .fit(train).transform(test).first
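The Spark 2.3 VectorIndexer documentation lists three handleInvalid options: "error" (the default, fail as above), "skip" (filter out rows containing an unseen value) and "keep" (put unseen values into an extra bucket whose index equals the number of known categories). The pure-Python sketch below models those semantics on plain lists; transform_rows is a hypothetical helper, not the Spark API, and the category map is written by hand instead of being fitted.

```python
# Hypothetical, simplified model of VectorIndexer's handleInvalid options
# (Spark 2.3+). category_maps: {feature index: {value: category index}}.
def transform_rows(category_maps, rows, handle_invalid="error"):
    if handle_invalid not in ("error", "skip", "keep"):
        raise ValueError("handle_invalid must be 'error', 'skip' or 'keep'")
    out = []
    for row in rows:
        new_row, drop_row = [], False
        for j, v in enumerate(row):
            if j not in category_maps:
                new_row.append(v)  # continuous feature: passed through unchanged
            elif v in category_maps[j]:
                new_row.append(float(category_maps[j][v]))  # known category
            elif handle_invalid == "keep":
                # Unseen value goes to the extra bucket at index numCategories.
                new_row.append(float(len(category_maps[j])))
            elif handle_invalid == "skip":
                drop_row = True  # filter out the whole row
                break
            else:
                raise KeyError(v)  # "error" (default): fail on unseen values
        if not drop_row:
            out.append(new_row)
    return out


# Feature 0 is categorical with known values 0.0 and 2.0; feature 1 is continuous.
maps = {0: {0.0: 0, 2.0: 1}}
rows = [[2.0, 7.27], [1.0, 15.0]]
print(transform_rows(maps, rows, "keep"))  # [[1.0, 7.27], [2.0, 15.0]]
print(transform_rows(maps, rows, "skip"))  # [[1.0, 7.27]]
```

With "keep" the model still sees the unseen level, just as a separate category it never encountered in training, whereas "skip" silently drops those rows from the predictions; which trade-off fits depends on whether every forecast row needs a prediction.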