我正在尝试创建一个用户定义的聚合函数,我可以从python调用它。我试着按照this问题的答案。 我基本上实现了以下内容(取自here):
package com.blu.bla;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;
public class MySum extends UserDefinedAggregateFunction {
private StructType _inputDataType;
private StructType _bufferSchema;
private DataType _returnDataType;
public MySum() {
List<StructField> inputFields = new ArrayList<StructField>();
inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
_inputDataType = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
_bufferSchema = DataTypes.createStructType(bufferFields);
_returnDataType = DataTypes.DoubleType;
}
@Override public StructType inputSchema() {
return _inputDataType;
}
@Override public StructType bufferSchema() {
return _bufferSchema;
}
@Override public DataType dataType() {
return _returnDataType;
}
@Override public boolean deterministic() {
return true;
}
@Override public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, null);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
if (buffer.isNullAt(0)) {
buffer.update(0, input.getDouble(0));
} else {
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
}
}
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
buffer1.update(0, buffer2.getDouble(0));
} else {
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
}
}
}
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
return null;
} else {
return buffer.getDouble(0);
}
}
}
然后我用所有依赖项编译它并使用--jars myjar.jar运行pyspark
在pyspark我做了:
df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row
def myCol(col):
_f = sc._jvm.com.blu.bla.MySum.apply
return Column(_f(_to_seq(sc,[col], _to_java_column)))
b = df.agg(myCol("A"))
我收到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))
<ipython-input-22-afcb8884e1db> in myCol(col)
4 def myCol(col):
5 _f = sc._jvm.com.blu.bla.MySum.apply
----> 6 return Column(_f(_to_seq(sc,[col], _to_java_column)))
TypeError: 'JavaPackage' object is not callable
我也尝试将--driver-class-path添加到pyspark调用,但结果相同。
还尝试通过java import访问java类:
from py4j.java_gateway import java_import
jvm = sc._gateway.jvm
java_import(jvm, "com.bla.blu.MySum")
def myCol2(col):
_f = jvm.bla.blu.MySum.apply
return Column(_f(_to_seq(sc,[col], _to_java_column)))
还尝试简单地创建类(如建议的here):
a = jvm.com.bla.blu.MySum()
所有人都收到相同的错误消息。
我似乎无法弄清问题是什么。
答案 0 :(得分:4)
因此,似乎主要的问题是,如果给出相对路径,添加jar( - jars,驱动程序类路径,SPARK_CLASSPATH)的所有选项都无法正常工作。这可能是因为ipython中的工作目录存在问题,而不是我运行pyspark的地方。
一旦我将其更改为绝对路径,它就可以正常工作(Haven尚未在集群上对其进行测试,但至少它适用于本地安装)。
此外,我不确定答案here中是否也存在错误,因为该答案使用了scala实现,但是在我需要的java实现中
def myCol(col):
_f = sc._jvm.com.blu.bla.MySum().apply
return Column(_f(_to_seq(sc,[col], _to_java_column)))
这可能不是很有效,因为它每次都会创建_f,而我应该在函数外部定义_f(再次,这需要在集群上进行测试),但至少现在它提供了正确的功能性答案