Spark aggregateMessages: summing the values of all nodes in a tree


I have tree data like this:

    (A) --> (B) --> (D)
     \
      \--> (C)

Each node has a value. I want to aggregate a total_value for every node. Assume V(i) is the value of node i and T(i) is the total_value of node i, that is, its own value plus the total values of its children.

V(A) = 1
V(B) = 1
V(C) = 1
V(D) = 1

My desired result is:

T(D) = 1
T(B) = T(D) + V(B) = 2
T(C) = 1
T(A) = T(B) + T(C) + V(A) = 4
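The desired totals are plain subtree sums. As a reference point, here is a minimal recursive sketch in plain Python (no Spark) that computes T for the tree above, assuming the tree is given as a child-to-parent map with `None` marking the root. The names `parent`, `value`, and `total` are illustrative, not from the question.

```python
# Tree from the question as a child -> parent map (None marks the root).
parent = {"A": None, "B": "A", "C": "A", "D": "B"}
value = {"A": 1, "B": 1, "C": 1, "D": 1}

# Invert the parent map into a list of children per node.
children = {node: [] for node in parent}
for node, p in parent.items():
    if p is not None:
        children[p].append(node)

def total(node):
    """T(node) = V(node) + sum of T(child) over the node's children."""
    return value[node] + sum(total(child) for child in children[node])

totals = {node: total(node) for node in parent}
print(totals)  # {'A': 4, 'B': 2, 'C': 1, 'D': 1}
```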

But this is my Spark result:

+---+----+
| id|cost|
+---+----+
|1-A|   2|
|1-B|   1|
+---+----+
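A single aggregateMessages call performs one round of message passing: each edge delivers one message, and only vertices that receive at least one message appear in the output. The plain-Python simulation below (an illustration, not the GraphFrames implementation) replays that round over frame 1 with child-to-parent edges. It reproduces the two rows shown: A receives the costs of B and C, B receives the cost of D, the leaves C and D get nothing, and D's cost never reaches A.

```python
from collections import defaultdict

# Frame 1 of the question's data as (node_id, parent_node_id, cost) rows.
rows = [("A", None, 1), ("B", "A", 1), ("C", "A", 1), ("D", "B", 1)]
cost = {node: c for node, _, c in rows}

# One round over child -> parent edges: every child sends its own cost
# to its parent, and each parent sums the messages it receives.
inbox = defaultdict(int)
for node, parent, _ in rows:
    if parent is not None:
        inbox[parent] += cost[node]

print(dict(inbox))  # {'A': 2, 'B': 1}
```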

My code is:

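from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import AggregateMessages as AM

spark = SparkSession.builder.getOrCreate()

# One row per tree node; frame_id separates independent trees.
df = spark.createDataFrame(
    [
        (1, "A", None, 1, 1),
        (1, "B", "A", 2, 1),
        (1, "C", "A", 2, 1),
        (1, "D", "B", 3, 1),
        (2, "A", None, 1, 1),
        (2, "B", "A", 2, 1),
    ],
    ["frame_id", "node_id", "parent_node_id", "depth", "cost"],
)

# Vertices: prefix node ids with frame_id so ids are unique across frames.
node_df = df.selectExpr(
    "concat_ws('-', frame_id, node_id) as id",
    "cost",
    "depth",
    "frame_id",
    "node_id",
    "parent_node_id",
)

# Edges point from child (src) to parent (dst).
edge_df = df.selectExpr(
    "concat_ws('-', frame_id, node_id) as src",
    "concat_ws('-', frame_id, parent_node_id) as dst",
)

g = GraphFrame(node_df, edge_df)

# Each child sends its cost to its parent; the parent sums the messages.
g.aggregateMessages(
    F.sum(AM.msg).alias("cost"),
    sendToDst=AM.src["cost"],
).orderBy("id").show()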

Where am I going wrong?
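Because one call only moves values one hop, reaching T(A) requires repeated rounds in which each child sends its accumulated total rather than its raw cost. Since the data already carries a `depth` column, one plausible approach is to process levels from the deepest up, one round per level. The sketch below shows that loop in plain Python for frame 1. In Spark, each level would presumably become one aggregateMessages call (or a join plus groupBy) whose sums are joined back onto the vertices before the next level; that wiring is an assumption and is not shown here.

```python
# Frame 1 rows as (node_id, parent_node_id, depth, cost), matching the question's columns.
rows = [
    ("A", None, 1, 1),
    ("B", "A", 2, 1),
    ("C", "A", 2, 1),
    ("D", "B", 3, 1),
]

# Each node's running total starts at its own cost.
total = {node: cost for node, _, _, cost in rows}
depth = {node: d for node, _, d, _ in rows}
parent = {node: p for node, p, _, _ in rows}

# Deepest level first; level 1 (the root) has no parent to send to.
# Each child sends its *current total* (not its raw cost) to its parent.
for level in range(max(depth.values()), 1, -1):
    for node in [n for n in total if depth[n] == level]:
        total[parent[node]] += total[node]

print(total)  # {'A': 4, 'B': 2, 'C': 1, 'D': 1}
```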

Update (2023-04-06):

This worked:

g.aggregateMessages(
    F.sum(AM.msg).alias("cost_total"),
    sendToDst=AM.src["cost"] + AM.dst["cost"],
).orderBy("id").show()
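A caution about the updated expression: it still moves values only one hop, and it adds the parent's own cost once per child. For A it computes 2·V(A) + V(B) + V(C), which equals the desired V(A) + V(B) + V(C) + V(D) only when V(A) = V(D), as in this data. The plain-Python comparison below (illustrative helper names, not GraphFrames code) mimics that one-round formula and compares it with true subtree sums, first on the question's tree and then with a hypothetical extra node E under D.

```python
from collections import defaultdict

def one_round_src_plus_dst(parent, value):
    """Mimic sendToDst = src.cost + dst.cost, summed per destination, for one round."""
    out = defaultdict(int)
    for child, p in parent.items():
        if p is not None:
            out[p] += value[child] + value[p]
    return dict(out)

def subtree_totals(parent, value):
    """T(n) = V(n) + sum of T(child), computed recursively."""
    children = defaultdict(list)
    for child, p in parent.items():
        if p is not None:
            children[p].append(child)

    def t(n):
        return value[n] + sum(t(c) for c in children[n])

    return {n: t(n) for n in parent}

# The question's tree: the one-round formula agrees for A and B.
tree = {"A": None, "B": "A", "C": "A", "D": "B"}
ones = {n: 1 for n in tree}
print(one_round_src_plus_dst(tree, ones))  # {'A': 4, 'B': 2}
print(subtree_totals(tree, ones))          # {'A': 4, 'B': 2, 'C': 1, 'D': 1}

# Hypothetical extra level (E under D): the two results diverge.
deeper = {**tree, "E": "D"}
ones_deeper = {n: 1 for n in deeper}
print(one_round_src_plus_dst(deeper, ones_deeper))  # {'A': 4, 'B': 2, 'D': 2}
print(subtree_totals(deeper, ones_deeper))          # {'A': 5, 'B': 3, 'C': 1, 'D': 2, 'E': 1}
```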

But I have many other attributes on every node, like value2 and value3. How can I update all of these attributes within a single aggregateMessages call?
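One general way to handle several attributes in a single pass is to make each message carry all of them together and sum each field on the receiving side. In Spark this would presumably mean sending a struct column (built with `pyspark.sql.functions.struct`) as the message and summing its fields inside the aggregation expression; that has not been checked against a specific GraphFrames version here, so treat it as an assumption. The plain-Python sketch below combines the idea with the depth-by-depth loop. The value2/value3 numbers are made up for illustration.

```python
FIELDS = ("cost", "value2", "value3")

# Hypothetical attributes for frame 1: (node_id, parent_node_id, depth, {field: value}).
rows = [
    ("A", None, 1, {"cost": 1, "value2": 10, "value3": 100}),
    ("B", "A", 2, {"cost": 1, "value2": 20, "value3": 200}),
    ("C", "A", 2, {"cost": 1, "value2": 30, "value3": 300}),
    ("D", "B", 3, {"cost": 1, "value2": 40, "value3": 400}),
]

parent = {n: p for n, p, _, _ in rows}
depth = {n: d for n, _, d, _ in rows}
# Running totals start at each node's own attributes (copied, not shared).
total = {n: dict(attrs) for n, _, _, attrs in rows}

# One pass per depth level, deepest first. Each message carries every field,
# and the parent adds the message to its totals field by field.
for level in range(max(depth.values()), 1, -1):
    for node in [n for n in total if depth[n] == level]:
        msg = total[node]
        for field in FIELDS:
            total[parent[node]][field] += msg[field]

print(total["A"])  # {'cost': 4, 'value2': 100, 'value3': 1000}
print(total["B"])  # {'cost': 2, 'value2': 60, 'value3': 600}
```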
