Suppose I have the following DataFrame:
+-----------------+---------------------+
| document1 | document2 |
+-----------------+---------------------+
| word1 word2 | word2 word3 |
+-----------------+---------------------+
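I need to add a new column called intersection to this DataFrame, holding the intersection similarity between document1 and document2.
How can I process the values in the columns? I defined a function called intersection that takes two strings as input, but I can't apply it to columns. I think I should use a UDF. How can I do that in Java? Note that I am using Spark 2.3.0.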
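The question does not show the body of inter. As a hypothetical illustration only, assuming "intersection similarity" means the Jaccard index (number of shared distinct words divided by number of distinct words overall), a plain-Java helper returning a double, the kind of function a UDF2<String, String, Double> would wrap, could look like this. The class name and the whitespace-splitting rule are assumptions.

```java
import java.util.HashSet;
import java.util.Set;

public class IntersectionSimilarity {

    // Hypothetical helper: splits a document on runs of whitespace into distinct words.
    static Set<String> words(String doc) {
        Set<String> result = new HashSet<>();
        if (doc == null) {
            return result;
        }
        for (String w : doc.trim().split("\\s+")) {
            if (!w.isEmpty()) {
                result.add(w);
            }
        }
        return result;
    }

    // Jaccard index |A intersect B| / |A union B|; returns 0.0 when both documents are empty.
    static double inter(String doc1, String doc2) {
        Set<String> a = words(doc1);
        Set<String> b = words(doc2);
        Set<String> intersection = new HashSet<>(a);
        intersection.retainAll(b);
        Set<String> union = new HashSet<>(a);
        union.addAll(b);
        return union.isEmpty() ? 0.0 : (double) intersection.size() / union.size();
    }

    public static void main(String[] args) {
        // One shared word (word2) out of three distinct words.
        System.out.println(inter("word1 word2", "word2 word3")); // 0.3333333333333333
    }
}
```

If a different definition is intended (for example the raw count of shared words), only the return statement needs to change.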
I tried the following:
SparkSession spark = SparkSession.builder()
        .appName("spark session example")
        .master("local[*]")
        .config("spark.sql.warehouse.dir", "/file:C:/tempWarehouse")
        .config("spark.sql.caseSensitive", "true")
        .getOrCreate();
sqlContext.udf().register("intersection", new UDF2<String, String, Double>() {
    @Override
    public Double call(String arg, String arg2) throws Exception {
        double key = inter(arg, arg2);
        return key;
    }
}, DataTypes.DoubleType);
v.registerTempTable("v_table");
Dataset<Row> df = spark.sql("select v_table.document, v_table.document1, "
        + "intersection(v_table.document, v_table.document1) as RowKey1,"
        + " from v_table");
df.show();
But I get the following exception:
INFO SparkSqlParser: Parsing command: select v_table.document, v_table.document1, intersection(v_table.document, v_table.document1) as RowKey1, from v_table
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`v_table.document`' given input columns: []; line 1 pos 7
If I remove + ", intersection(v.doc1, v.doc2) as RowKey1," from the query, the select works fine.
Any suggestions? Also, how can I do the same thing using only the DataFrame API instead of SQL?
The schema of "v", as printed by v.printSchema();, is:
root
|-- document: string (nullable = true)
|-- document1: string (nullable = true)
Answer (score: 2)
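I think I would approach this differently.
Split the dataset into two working datasets: one for document1 and one for document2. First split each row into words, then explode them. After that, all you have to do is keep the intersection.
Something like: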
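For a single pair of documents, the split/explode/join pipeline amounts to intersecting the two word lists. A plain-Java sketch of that logic (no Spark involved; the class name is hypothetical) is:

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

public class WordIntersection {

    // Split on a single space, like split(col, " ") in the Spark version,
    // keeping the words of doc1 in first-seen order.
    static Set<String> commonWords(String doc1, String doc2) {
        Set<String> common = new LinkedHashSet<>(Arrays.asList(doc1.split(" ")));
        // retainAll keeps only the words that also occur in doc2 (the "join" step).
        common.retainAll(Arrays.asList(doc2.split(" ")));
        return common;
    }

    public static void main(String[] args) {
        System.out.println(commonWords("word1 word2", "word2 word3")); // [word2]
    }
}
```

One difference worth knowing: a set drops duplicates, whereas the Spark join would emit a word twice if it appeared twice in document1.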
Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
ds.show();
// One row per word for each document
Dataset<Row> ds1 = ds.select(functions.explode(functions.split(ds.col("document1"), " ")).as("word"));
Dataset<Row> ds2 = ds.select(functions.explode(functions.split(ds.col("document2"), " ")).as("word"));
// Keep only the words that appear in both documents
Dataset<Row> intersection = ds1.join(ds2, ds1.col("word").equalTo(ds2.col("word")))
        .select(ds1.col("word").as("Common words"));
intersection.show();
Output:
+-----------+-----------+
| document1| document2|
+-----------+-----------+
|word1 word2|word2 word3|
+-----------+-----------+
+------------+
|Common words|
+------------+
| word2|
+------------+
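One caveat: this join matches words across the whole dataset, not row by row. With more than one row, a word from one row's document1 can match a word from another row's document2. The plain-Java sketch below (hypothetical two-row data, no Spark, set-based so duplicates are ignored) contrasts the two results. In Spark you would typically keep a row id, for example from functions.monotonically_increasing_id(), and join on both the id and the word.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class PerRowVsGlobal {

    // Words shared between document1 and document2 of the same row.
    static List<Set<String>> perRow(List<String[]> rows) {
        List<Set<String>> result = new ArrayList<>();
        for (String[] row : rows) {
            Set<String> common = new LinkedHashSet<>(Arrays.asList(row[0].split(" ")));
            common.retainAll(Arrays.asList(row[1].split(" ")));
            result.add(common);
        }
        return result;
    }

    // What the exploded-join approach computes: all document1 words vs all document2 words.
    static Set<String> global(List<String[]> rows) {
        Set<String> left = new LinkedHashSet<>();
        Set<String> right = new LinkedHashSet<>();
        for (String[] row : rows) {
            left.addAll(Arrays.asList(row[0].split(" ")));
            right.addAll(Arrays.asList(row[1].split(" ")));
        }
        left.retainAll(right);
        return left;
    }

    public static void main(String[] args) {
        // Hypothetical input, each element is {document1, document2}.
        List<String[]> rows = Arrays.asList(
                new String[] {"word1 word2", "word2 word3"},
                new String[] {"word4", "word1"});
        System.out.println(perRow(rows)); // [[word2], []]
        System.out.println(global(rows)); // [word1, word2]
    }
}
```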
Anyway, if your goal is "just" to call a custom UDF on two columns, this is how I would do it:
UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
    @Override
    public String call(String arg, String arg2) throws Exception {
        String key = inter(arg, arg2);
        return key;
    }
    private String inter(String arg1, String arg2) {
        // this part of the implementation is just to stay in line with the previous part of this answer
        List<String> list1 = Arrays.asList(arg1.split(" "));
        return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
    }
};
// DataFrame API: wrap the UDF2 and apply it directly to the two columns
UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);
Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"),
        intersect.apply(ds.col("document1"), ds.col("document2")));
ds1.show();
// SQL: expose the data as a view, register the same UDF2 by name, and call it in a query
ds.createOrReplaceTempView("v_table");
spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
Dataset<Row> df = spark.sql("select document1, document2, "
        + "intersect(document1, document2) as RowKey1"
        + " from v_table");
df.show();
+-----------+-----------+-------+
| document1| document2|RowKey1|
+-----------+-----------+-------+
|word1 word2|word2 word3| word2|
+-----------+-----------+-------+
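Note that, unlike the query in the question, this query has no comma before "from". In the question, "... as RowKey1, from v_table" most likely makes the parser treat "from" as one more select expression, so the statement ends up with no FROM clause at all, which would explain "given input columns: []". One way to avoid dangling commas when assembling SQL from Java strings is to join the select expressions; the helper name buildSelect is hypothetical:

```java
import java.util.Arrays;
import java.util.List;

public class SelectBuilder {

    // Hypothetical helper: String.join only puts ", " *between* expressions,
    // so no comma can end up directly before "from".
    static String buildSelect(List<String> expressions, String table) {
        return "select " + String.join(", ", expressions) + " from " + table;
    }

    public static void main(String[] args) {
        String sql = buildSelect(
                Arrays.asList("document", "document1",
                        "intersection(document, document1) as RowKey1"),
                "v_table");
        // select document, document1, intersection(document, document1) as RowKey1 from v_table
        System.out.println(sql);
    }
}
```

The resulting string can then be passed to spark.sql(...) as before.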
Here is the complete program:
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
public class StackOverflowUDF {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("JavaWordCount").master("local").getOrCreate();
        // Single-row test data
        Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
        ds.show();
        // UDF returning the words of arg2 that also appear in arg1, joined by spaces
        UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
            @Override
            public String call(String arg, String arg2) throws Exception {
                String key = inter(arg, arg2);
                return key;
            }
            private String inter(String arg1, String arg2) {
                List<String> list1 = Arrays.asList(arg1.split(" "));
                return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
            }
        };
        // 1) DataFrame API
        UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);
        Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"),
                intersect.apply(ds.col("document1"), ds.col("document2")));
        ds1.show();
        ds1.printSchema();
        // 2) SQL: register the view and the UDF by name, then query
        ds.createOrReplaceTempView("v_table");
        spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
        Dataset<Row> df = spark
                .sql("select document1, document2, " + "intersect(document1, document2) as RowKey1" + " from v_table");
        df.show();
        spark.stop();
    }
}
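One more caveat: Spark passes null to a Java UDF when a column value is null, and inter as written would then throw a NullPointerException on arg1.split(" "). A null-tolerant variant of the same helper, shown as plain Java so it can be tried without Spark (class name hypothetical), could simply return null and let the result column be null for that row:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class NullSafeInter {

    // Same logic as the answer's inter(), but returns null when either input is null
    // instead of throwing a NullPointerException.
    static String inter(String arg1, String arg2) {
        if (arg1 == null || arg2 == null) {
            return null;
        }
        List<String> list1 = Arrays.asList(arg1.split(" "));
        return Stream.of(arg2.split(" "))
                .filter(list1::contains)
                .collect(Collectors.joining(" "));
    }

    public static void main(String[] args) {
        System.out.println(inter("word1 word2", "word2 word3")); // word2
        System.out.println(inter(null, "word2 word3"));          // null
    }
}
```

Returning an empty string instead of null is an equally valid choice, depending on what downstream code expects.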
HTH!