Why doesn't the UDF recognize the DataFrame's columns?

Asked: 2019-03-20 14:54:19

Tags: java apache-spark dataframe

Suppose I have the following DataFrame:

+-----------+-----------+
|  document1|  document2|
+-----------+-----------+
|word1 word2|word2 word3|
+-----------+-----------+

I need to add a new column to this DataFrame, called intersection, that holds the intersection similarity between document1 and document2.

How do I work with the values in these columns? I defined a function called intersection that takes two strings as input, but I can't apply it to columns. I think I need to use a UDF. How can I do that in Java? Note that I am using Spark 2.3.0.
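(For illustration only, a Jaccard-style sketch of what such a two-string function could look like; the actual inter() called in the code below is not shown and may differ:)

// Hypothetical stand-in for inter(): the fraction of distinct words
// shared by the two documents (Jaccard similarity).
private static double inter(String doc1, String doc2) {
    java.util.Set<String> words1 = new java.util.HashSet<>(java.util.Arrays.asList(doc1.split(" ")));
    java.util.Set<String> words2 = new java.util.HashSet<>(java.util.Arrays.asList(doc2.split(" ")));
    java.util.Set<String> common = new java.util.HashSet<>(words1);
    common.retainAll(words2);
    java.util.Set<String> union = new java.util.HashSet<>(words1);
    union.addAll(words2);
    return union.isEmpty() ? 0.0 : (double) common.size() / union.size();
}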

I tried the following:

SparkSession spark = SparkSession.builder().appName("spark session example").master("local[*]")
        .config("spark.sql.warehouse.dir", "/file:C:/tempWarehouse")
        .config("spark.sql.caseSensitive", "true")
        .getOrCreate();

sqlContext.udf().register("intersection", new UDF2<String, String, Double>() {
    @Override
    public Double call(String arg, String arg2) throws Exception {
        double key = inter(arg, arg2);
        return key;
    }
}, DataTypes.DoubleType);

v.registerTempTable("v_table");

Dataset<Row> df = spark.sql("select v_table.document, v_table.document1, "
        + "intersection(v_table.document, v_table.document1) as RowKey1,"
        + " from v_table");
df.show();

But I get the following exception:

INFO SparkSqlParser: Parsing command: select v_table.document, v_table.document1, intersection(v_table.document, v_table.document1) as RowKey1, from v_table
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`v_table.document`' given input columns: []; line 1 pos 7

If I remove + ", intersection(v.doc1, v.doc2) as RowKey1," from the query, the select works fine. Any suggestions? Also, how can I do the same thing directly on the DataFrame, rather than through SQL?

The schema of "v" (from v.printSchema();) is:

root
 |-- document: string (nullable = true)
 |-- document1: string (nullable = true)

1 Answer:

Answer 0: (score: 2)

I think I would approach this differently.

Turn the dataset into two working datasets: one for document1 and one for document2. Split each row into words first, then explode them. After that, all you have to do is keep the intersection.

Something like this:

Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
ds.show();

// One row per word for each document column: split on spaces, then explode
Dataset<Row> ds1 = ds.select(functions.explode(functions.split(ds.col("document1"), " ")).as("word"));
Dataset<Row> ds2 = ds.select(functions.explode(functions.split(ds.col("document2"), " ")).as("word"));

// Join the two word lists to keep only the words present in both documents
Dataset<Row> intersection = ds1.join(ds2, ds1.col("word").equalTo(ds2.col("word"))).select(ds1.col("word").as("Common words"));
intersection.show();

Output:

+-----------+-----------+
|  document1|  document2|
+-----------+-----------+
|word1 word2|word2 word3|
+-----------+-----------+
+------------+
|Common words|
+------------+
|       word2|
+------------+

Anyway, if your goal is "just" to call a custom UDF on two columns, here is how I would do it:

1. Create your UDF

UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
    @Override
    public String call(String arg, String arg2) throws Exception {
        String key = inter(arg, arg2);
        return key;
    }

    private String inter(String arg1, String arg2) {
        // this part of the implementation is just to stay in line with the previous part of this answer
        List<String> list1 = Arrays.asList(arg1.split(" "));
        return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
    }
};

2. Register and use it!

Pure Java

UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);      

Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"), intersect.apply(ds.col("document1"), ds.col("document2")));
ds1.show();
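
If you want the new column to carry a specific name in the pure-Java version (the output below shows RowKey1, which comes from the SQL alias), you can alias the resulting Column, for example by replacing the select above with:

Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"),
        intersect.apply(ds.col("document1"), ds.col("document2")).as("RowKey1"));
ds1.show();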

SQL

spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
Dataset<Row> df = spark.sql("select document1, document2, "
                + "intersect(document1, document2) as RowKey1"
                + " from v_table");
df.show();

Output:

+-----------+-----------+-------+
|  document1|  document2|RowKey1|
+-----------+-----------+-------+
|word1 word2|word2 word3|  word2|
+-----------+-----------+-------+

Full code:

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;

public class StackOverflowUDF {
    public static void main(String args[]) {
        SparkSession spark = SparkSession.builder().appName("JavaWordCount").master("local").getOrCreate();

        Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
        ds.show();

        UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
            @Override
            public String call(String arg, String arg2) throws Exception {
                String key = inter(arg, arg2);
                return key;
            }

            private String inter(String arg1, String arg2) {
                List<String> list1 = Arrays.asList(arg1.split(" "));
                return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
            }
        };

        UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);

        Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"),
                intersect.apply(ds.col("document1"), ds.col("document2")));
        ds1.show();
        ds1.printSchema();

        ds.createOrReplaceTempView("v_table");

        spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
        Dataset<Row> df = spark
                .sql("select document1, document2, " + "intersect(document1, document2) as RowKey1" + " from v_table");
        df.show();
        spark.stop();

    }
}

HTH!