Convert a float column in a Spark DataFrame to VectorUDT

Date: 2016-11-03 18:34:09

Tags: python apache-spark pyspark

I am trying to use the binary classification metrics from pyspark.ml.evaluation, as shown below:

from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(rawPredictionCol="prediction")
print evaluator.evaluate(predictions)

My predictions dataframe looks like this:

predictions.select('rating', 'prediction').show()
+------+------------+
|rating|  prediction|
+------+------------+
|     1|  0.14829934|
|     1|-0.017862909|
|     1|   0.4951505|
|     1|0.0074382657|
|     1|-0.002562912|
|     1|   0.0208337|
|     1| 0.049362548|
|     1|  0.09693333|
|     1|  0.17998546|
|     1| 0.019649783|
|     1| 0.031353004|
|     1|  0.03657037|
|     1|  0.23280995|
|     1| 0.033190556|
|     1|  0.35569906|
|     1| 0.030974165|
|     1|   0.1422375|
|     1|  0.19786166|
|     1|  0.07740938|
|     1|  0.33970386|
+------+------------+
only showing top 20 rows

The data type of each column is as follows:

predictions.printSchema()
root
 |-- rating: integer (nullable = true)
 |-- prediction: float (nullable = true)

Now the ML code above raises an error saying that the prediction column is a Float where a VectorUDT was expected:

/Users/i854319/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814 
    815         for temp_arg in temp_args:

/Users/i854319/spark/python/pyspark/sql/utils.pyc in deco(*a, **kw)
     51                 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
     52             if s.startswith('java.lang.IllegalArgumentException: '):
---> 53                 raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
     54             raise
     55     return deco

IllegalArgumentException: u'requirement failed: Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually FloatType.'

So I wanted to convert the prediction column from float to VectorUDT, as follows:

Applying a schema to the dataframe to convert the float column type to VectorUDT:

from pyspark.sql.types import IntegerType, StructType,StructField

schema = StructType([
    StructField("rating", IntegerType, True),
    StructField("prediction", VectorUDT(), True)
])


predictions_dtype=sqlContext.createDataFrame(prediction,schema)

But now I get this error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-30-8fce6c4bbeb4> in <module>()
      4 
      5 schema = StructType([
----> 6     StructField("rating", IntegerType, True),
      7     StructField("prediction", VectorUDT(), True)
      8 ])

/Users/i854319/spark/python/pyspark/sql/types.pyc in __init__(self, name, dataType, nullable, metadata)
    401         False
    402         """
--> 403         assert isinstance(dataType, DataType), "dataType should be DataType"
    404         if not isinstance(name, str):
    405             name = name.encode('utf-8')

AssertionError: dataType should be DataType
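
As an aside, this AssertionError is separate from the vector issue: StructField expects a DataType instance, but the class IntegerType itself is being passed, and VectorUDT is never imported. A minimal fix for the schema definition itself, assuming Spark 1.x where VectorUDT lives in pyspark.mllib.linalg, would be:

from pyspark.mllib.linalg import VectorUDT
from pyspark.sql.types import IntegerType, StructType, StructField

schema = StructType([
    StructField("rating", IntegerType(), True),    # instantiate the type
    StructField("prediction", VectorUDT(), True)
])

Even with a valid schema, though, createDataFrame will not convert the existing float values into vectors; the values themselves have to be wrapped, as in the answer below.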

Running ML algorithms in the Spark library is taking a lot of time because of many weird errors like these. I even tried MLlib with RDD data, and that gives a ValueError: Null pointer exception.

Please advise.

1 Answer:

Answer 0 (score: 2)

Try:

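A minimal sketch, assuming Spark 1.6 (the traceback shows py4j-0.9), so VectorUDT comes from pyspark.mllib.linalg, matching the class named in the error. The scalar score is wrapped into a dense vector with a UDF; packing it as a two-element [1 - score, score] vector treats it as a positive-class score, which is an assumption about the model's output, so adjust the layout to your semantics:

from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.linalg import Vectors, VectorUDT
from pyspark.sql.functions import col, udf

# Wrap the scalar score into a two-element vector of raw class scores.
# The [1 - score, score] layout is an assumption, not fixed by the question.
as_vector = udf(lambda score: Vectors.dense([1.0 - score, score]), VectorUDT())

predictions_vec = (predictions
    .withColumn("prediction", as_vector(col("prediction")))
    # the evaluator also requires a DoubleType label column
    .withColumn("rating", col("rating").cast("double")))

evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="prediction", labelCol="rating")
print evaluator.evaluate(predictions_vec)

On newer Spark versions (2.x), the evaluator also accepts a plain double rawPredictionCol, in which case casting the prediction column to double is enough.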

Source: Tuning parameters for implicit pyspark.ml ALS matrix factorization model through pyspark.ml CrossValidator