Suppose I have the following graph:
scala> v.show()
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|           null|
|BBB|           null|
|QQQ|           null|
|DDD|           null|
|FFF|           null|
|EEE|           null|
|AAA|           null|
|GGG|           null|
+---+---------------+
scala> e.show()
+---+---+---+
| id|src|dst|
+---+---+---+
|  1|CCC|AAA|
|  2|CCC|BBB|
...
+---+---+---+
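I want to run an aggregation that gathers all of the messages sent from the destination vertices to the source vertex (not just sum, first, last, etc.). So the command I want to run is something like: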
g.aggregateMessages.sendToSrc(AM.edge("id")).agg(all(AM.msg).as("downstreamEdges")).show()
except that the function all does not exist (as far as I know). The output would look something like:
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|         [1, 2]|
...
+---+---------------+
I can use the command above with first or last instead of the (nonexistent) all, but they only give me
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              1|
...
+---+---------------+
or
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              2|
...
+---+---------------+
respectively. How can I keep all of the entries? (There could be many, not just 1 and 2 but 1, 2, 23, 45, etc.) Thanks.
Answer 0 (score: 0)
I adapted this answer to come up with the following:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepAllString extends UserDefinedAggregateFunction {
  // Schema of the input column: one string value per row.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // These are the internal fields you keep while computing the aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true), nullable = true) :: Nil
  )

  // This is the output type of the aggregation function.
  override def dataType: DataType = ArrayType(StringType, containsNull = true)

  override def deterministic: Boolean = true

  // This is the initial (empty) value of the buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Seq[String]()

  // This is how to update the buffer given one input row: append the value.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getAs[Seq[String]](0) :+ input.getAs[String](0)

  // This is how to merge two partial buffers: concatenate them.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Seq[String]](0) ++ buffer2.getAs[Seq[String]](0)

  // This is where you output the final value, given the final state of the buffer.
  override def evaluate(buffer: Row): Any = buffer.getAs[Seq[String]](0)
}
My all function from above is then simply:
val all = new KeepAllString()
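To see why this keeps every entry, here is a hypothetical plain-Scala model of the buffer lifecycle (no Spark involved; KeepAllModel is an illustrative name, not a Spark API). Spark starts each partial aggregation from initialize, feeds rows through update, and combines partial results with merge. Because merge concatenates instead of picking one side, nothing is dropped. Spark does not guarantee the order in which partial buffers are merged, so the order of the final array should not be relied on.

```scala
// Hypothetical plain-Scala model of the KeepAllString buffer lifecycle (no Spark needed).
object KeepAllModel {
  // initialize: every partial aggregation starts from an empty buffer.
  val initialize: Seq[String] = Seq.empty

  // update: append one incoming message to the buffer.
  def update(buffer: Seq[String], input: String): Seq[String] = buffer :+ input

  // merge: concatenate two partial buffers, so no message is lost.
  def merge(b1: Seq[String], b2: Seq[String]): Seq[String] = b1 ++ b2

  def main(args: Array[String]): Unit = {
    // Two partitions holding messages for the same vertex.
    val partA = Seq("1", "2").foldLeft(initialize)(update)
    val partB = Seq("23", "45").foldLeft(initialize)(update)
    println(merge(partA, partB)) // List(1, 2, 23, 45)
  }
}
```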
But how can I make it generic, so that for BigDecimal, Timestamp, etc. I can do something like the following?
val allTimestamp = new KeepAll[Timestamp]()
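One possible way to make it generic, sketched here under assumptions: pass the element's Spark DataType to the constructor instead of a Scala type parameter, and keep the buffer as a Seq[Any]. KeepAll and its elementType argument are hypothetical names. For decimals, the DecimalType precision and scale you pass should match the input column. Note that UserDefinedAggregateFunction is deprecated since Spark 3.0 (Aggregator registered with functions.udaf is the replacement), and for this particular problem the built-in collect_list already works for any element type.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical generic variant: the element type is a Spark DataType
// (StringType, TimestampType, DecimalType(p, s), LongType, ...).
class KeepAll(elementType: DataType) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    StructType(StructField("value", elementType) :: Nil)

  override def bufferSchema: StructType =
    StructType(StructField("values", ArrayType(elementType, containsNull = true)) :: Nil)

  override def dataType: DataType = ArrayType(elementType, containsNull = true)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[Any]

  // Append the incoming value, whatever its JVM type (String, Timestamp, BigDecimal, ...).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getSeq[Any](0) :+ input.get(0)

  // Concatenate partial buffers so no value is dropped.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getSeq[Any](0) ++ buffer2.getSeq[Any](0)

  override def evaluate(buffer: Row): Any = buffer.getSeq[Any](0)
}
```

Usage would then look like val allTimestamp = new KeepAll(TimestampType) instead of new KeepAll[Timestamp]().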
Answer 1 (score: 0)
I did it by using the aggregate function collect_set():
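# gx is an existing GraphFrame whose edges have an "id" column
from pyspark.sql import functions as f
from graphframes.lib import AggregateMessages as AM
agg = gx.aggregateMessages(
    f.collect_set(AM.msg).alias("aggMess"),
    sendToSrc=AM.edge("id"),
    sendToDst=None)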
The other option, which keeps duplicates, is collect_list().
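The same approach should carry over to the Scala API from the question, since collect_list and collect_set both exist in org.apache.spark.sql.functions and accept columns of any element type, so no custom UDAF is needed. A sketch, assuming g is a GraphFrame whose edge DataFrame has an id column as in the question; the DownstreamEdges object and its method names are hypothetical.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{collect_list, collect_set}
import org.graphframes.GraphFrame
import org.graphframes.lib.AggregateMessages

object DownstreamEdges {
  private val AM = AggregateMessages

  // Send each edge's id back to its source vertex and keep every id (duplicates included).
  def allDownstreamEdges(g: GraphFrame): DataFrame =
    g.aggregateMessages
      .sendToSrc(AM.edge("id"))
      .agg(collect_list(AM.msg).as("downstreamEdges"))

  // Same, but collect_set drops duplicate ids.
  def distinctDownstreamEdges(g: GraphFrame): DataFrame =
    g.aggregateMessages
      .sendToSrc(AM.edge("id"))
      .agg(collect_set(AM.msg).as("downstreamEdges"))
}
```

In both cases the order of the ids inside each array is not guaranteed; sort the array afterwards (for example with sort_array) if a stable order matters.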