I have a Hive table, from which I create a very large Spark DataFrame of the following shape:
------------------------------------------------------
| Category | Subcategory | Purchase_ID | Product_ID |
|----------+-------------+-------------+------------|
| a        | a_1         | purchase 1  | product 1  |
| a        | a_1         | purchase 1  | product 2  |
| a        | a_1         | purchase 1  | product 3  |
| a        | a_1         | purchase 4  | product 1  |
| a        | a_2         | purchase 5  | product 4  |
| b        | b_1         | purchase 6  | product 5  |
| b        | b_2         | purchase 7  | product 6  |
------------------------------------------------------
Note that this table is very large: tens of millions of purchases per subcategory and 50M+ purchases per category.
My task is the following: for each subcategory, compute the cosine similarity between every pair of products, based on the purchases in which they appear.
My solution so far:
First, I collect all unique 'Subcategory' values from Hive into the driver machine using SQL. Then I loop over the subcategories, and for each one I load its data again and compute the cosine similarities. Computing the cosine similarity for every pair of products requires building an NxN matrix. I assume (please correct me if I am wrong) that loading the whole DataFrame, grouping by subcategory, and computing the NxN matrix for every subcategory at once could cause an out-of-memory error, so I compute the models sequentially, as follows:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, Tokenizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}

val subcategories = hiveContext.sql(s"SELECT DISTINCT Subcategory FROM $table_name")
val subcategory_ids = subcategories.select("Subcategory").collect()
// for each context, sequentially compute models
for ((arr_subcategory_id, index) <- subcategory_ids.zipWithIndex) {
    println("Loading current context")
    val subcategory_id = arr_subcategory_id(0)
    println("subcategory id: ".concat(subcategory_id.toString))
    // Subcategory is a string column, so the literal must be quoted
    val context_data = hiveContext.sql(s"SELECT Purchase_ID, Product_ID FROM $table_name WHERE Subcategory = '$subcategory_id'")
    // UDAF to concatenate column values into a single long string
    class ConcatenatedGroupItems() extends UserDefinedAggregateFunction {
        // Input Data Type Schema
        def inputSchema: StructType = StructType(Array(StructField("item", StringType)))
        // Intermediate Schema
        def bufferSchema = StructType(Array(StructField("items", StringType)))
        // Returned Data Type
        def dataType: DataType = StringType
        // Self-explaining
        def deterministic = true
        // This function is called whenever key changes
        def initialize(buffer: MutableAggregationBuffer) = {
            buffer(0) = "" // initialize to empty string
        }
        // Iterate over each entry of a group
        def update(buffer: MutableAggregationBuffer, input: Row) = {
            var tempString: String = buffer.getString(0)
            // add a space between items unless it is the first element
            if (tempString.length() != 0) {
                tempString = tempString + " "
            }
            buffer(0) = tempString.concat(input.getString(0))
        }
        // Merge two partial aggregates
        def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
            var tempString = buffer1.getString(0)
            // add a space between items unless it is the first element
            if (tempString.length() != 0) {
                tempString = tempString + " "
            }
            buffer1(0) = tempString.concat(buffer2.getString(0))
        }
        // Called after all the entries are exhausted.
        def evaluate(buffer: Row) = {
            buffer.getString(0)
        }
    }
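    // NOTE: if I understand the Spark 1.6+ API correctly, the built-in
    // collect_list + concat_ws could replace this custom UDAF (untested sketch):
    // import org.apache.spark.sql.functions.{concat_ws, collect_list}
    // val sents_alt = context_data.groupBy("Purchase_ID")
    //     .agg(concat_ws(" ", collect_list("Product_ID")).as("items"))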
    // ========================================================================================
    println("Concatenating grouped items")
    val itemConcatenator = new ConcatenatedGroupItems()
    val sents = context_data.groupBy("Purchase_ID").agg(itemConcatenator(context_data.col("Product_ID")).as("items"))
    // ========================================================================================
    println("Tokenizing purchase items")
    val tokenizer = new Tokenizer().setInputCol("items").setOutputCol("words")
    val tokenized = tokenizer.transform(sents)
    // ========================================================================================
    // fit a CountVectorizerModel from the corpus
    println("Creating sparse incidence matrix")
    val cvModel: CountVectorizerModel = new CountVectorizer().setInputCol("words").setOutputCol("features").fit(tokenized)
    val incidence = cvModel.transform(tokenized)
    // ========================================================================================
    // create dataframe of mapping from indices into the item id
    println("Creating vocabulary")
    val vocabulary_rdd = sc.parallelize(cvModel.vocabulary)
    val rows_vocabulary_rdd = vocabulary_rdd.zipWithIndex.map{ case (s, i) => Row(s, i) }
    val vocabulary_field1 = StructField("Product_ID", StringType, true)
    val vocabulary_field2 = StructField("Product_Index", LongType, true)
    val schema_vocabulary = StructType(Seq(vocabulary_field1, vocabulary_field2))
    val df_vocabulary = hiveContext.createDataFrame(rows_vocabulary_rdd, schema_vocabulary)
    // ========================================================================================
    println("Computing similarity matrix")
    val myvectors = incidence.select("features").rdd.map(r => r(0).asInstanceOf[Vector])
    val mat: RowMatrix = new RowMatrix(myvectors)
    val sims = mat.columnSimilarities(0.0)
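    // As I understand the RowMatrix API: a threshold of 0.0 computes every pairwise
    // similarity exactly, while a positive threshold (e.g. mat.columnSimilarities(0.1))
    // switches to the approximate DIMSUM sampling scheme, which is much cheaper when
    // only similarities above the threshold are needed.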
    // ========================================================================================
    // Convert records of the MatrixEntry RDD into Rows
    println("Extracting paired similarities")
    val rowRdd = sims.entries.map{ case MatrixEntry(i, j, v) => Row(i, j, v) }
    // ========================================================================================
    // create dataframe schema
    println("Creating similarity dataframe")
    val field1 = StructField("Product_Index", LongType, true)
    val field2 = StructField("Neighbor_Index", LongType, true)
    val field3 = StructField("Similarity_Score", DoubleType, true)
    val schema_similarities = StructType(Seq(field1, field2, field3))
    // create the dataframe
    val df_similarities = hiveContext.createDataFrame(rowRdd, schema_similarities)
    // ========================================================================================
    println("Register vocabulary and correlations as Spark temp tables")
    df_vocabulary.registerTempTable("df_vocabulary")
    df_similarities.registerTempTable("df_similarities")
    // ========================================================================================
    // map matrix indices back to product ids
    println("Extracting Product_ID")
    val temp_corrs = hiveContext.sql(
        "SELECT T1.Product_ID, T2.Neighbor_ID, T1.Similarity_Score " +
        "FROM " +
        "(SELECT Product_ID, Neighbor_Index, Similarity_Score " +
        "FROM df_similarities LEFT JOIN df_vocabulary " +
        "ON df_similarities.Product_Index = df_vocabulary.Product_Index) AS T1 " +
        "LEFT JOIN " +
        "(SELECT Product_ID AS Neighbor_ID, Product_Index AS Neighbor_Index FROM df_vocabulary) AS T2 " +
        "ON " +
        "T1.Neighbor_Index = T2.Neighbor_Index")
    // ========================================================================================
    // tag the results with the subcategory they were computed for
    val context_corrs = temp_corrs.withColumn("Context_ID", lit(subcategory_id))
    // ========================================================================================
    context_corrs.registerTempTable("my_temp_table_correlations")
    hiveContext.sql(s"INSERT INTO TABLE $table_name_correlations SELECT * FROM my_temp_table_correlations")
    // ========================================================================================
    // clean up environment
    println("Cleaning up temp tables")
    hiveContext.dropTempTable("my_temp_table_correlations")
    hiveContext.dropTempTable("df_similarities")
    hiveContext.dropTempTable("df_vocabulary")
}
Questions:
What is the right way to approach this kind of problem?
I could partition the DataFrame by subcategory first, then group by subcategory and compute the similarities per group, but if the dataset is too large for one of those groups, I would probably still hit an OOM while building its NxN matrix.
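For concreteness, the grouped variant I have in mind looks roughly like the sketch below, where computeSimilaritiesLocally is a hypothetical helper that would build the per-subcategory NxN similarity matrix on a single executor, which is exactly where I fear the OOM:

val all_data = hiveContext.sql(s"SELECT Subcategory, Purchase_ID, Product_ID FROM $table_name")
val per_subcategory_sims = all_data.rdd
    .map(r => (r.getString(0), (r.getString(1), r.getString(2))))
    .groupByKey() // all rows of one subcategory land on one executor
    .mapValues(rows => computeSimilaritiesLocally(rows)) // hypothetical local NxN computation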