LeetCode #834 Sum of Distances in Tree 树中距离之和

834 Sum of Distances in Tree 树中距离之和

Description:
There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.

Example:

Example 1:

lc-sumdist1

Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: The tree is shown above.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.

Example 2:

lc-sumdist2

Input: n = 1, edges = []
Output: [0]

Example 3:

lc-sumdist3

Input: n = 2, edges = [[1,0]]
Output: [1,1]

Constraints:

1 <= n <= 3 * 10^4
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
ai != bi
The given input represents a valid tree.

题目描述:
给定一个无向、连通的树。树中有 N 个标记为 0...N-1 的节点以及 N-1 条边 。

第 i 条边连接节点 edges[i][0] 和 edges[i][1] 。

返回一个表示节点 i 与其他所有节点距离之和的列表 ans。

示例 :

示例 1:

输入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释:
如下为给定的树的示意图:

  0
 / \
1   2
   /|\
  3 4 5

我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

说明:
1 <= N <= 10000

思路:

树形 DP
用一个数组 nodes 记录子结点数, 每个结点初始化为 1
dp[i] 表示到 i 各结点的距离之和, 每个结点初始化为 0
先用后序遍历, 得到所有的结点的子结点数
在后序遍历的同时, dp[i] 加上子结点的距离 dp[i] += dp[child] + nodes[child], 这样就完成了父结点到所有子结点的距离求和
然后还需要加上父结点到其他结点的距离和, 注意到这个时候已经求出了根结点的真正的距离和
因为已经求出来了子结点数, 所以某个结点的非子结点数是 n - nodes[i] - 1,
如果将父结点走到子结点改为从根结点出发 dp[child] = dp[root] - nodes[i] + (n - nodes[i])
dp[root] - nodes[i] 表示, 如果从根结点出发到当前结点需要少走 nodes[i] 的距离
n - nodes[i] 表示其他结点走到根结点再走到当前结点需要多走 1 步
时间复杂度为 O(n), 空间复杂度为 O(n)

代码:
C++:

class Solution
{
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) 
    {
        vector<vector<int>> graph(n);
        vector<int> dp(n, 0), nodes(n, 1);
        for (const auto& edge : edges) 
        {
            graph[edge.front()].emplace_back(edge.back());
            graph[edge.back()].emplace_back(edge.front());
        }
        post(0, -1, dp, nodes, graph);
        pre(0, -1, dp, nodes, graph, n);
        return dp;
    }
private:
    void post(int root, int parent, vector<int>& dp, vector<int>& nodes, vector<vector<int>>& graph) 
    {
        for (const auto& child : graph[root]) 
        {
            if (child == parent) continue;
            post(child, root, dp, nodes, graph);
            nodes[root] += nodes[child];
            dp[root] += dp[child] + nodes[child];
        }
    }
    
    void pre(int root, int parent, vector<int>& dp, vector<int>& nodes, vector<vector<int>>& graph, int n) {
        for (const auto& child : graph[root]) 
        {
            if (child == parent) continue;
            dp[child] = dp[root] + n - (nodes[child] << 1);
            pre(child, root, dp, nodes, graph, n);
        }
    }
};

Java:

class Solution {
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<List<Integer>> graph = new ArrayList<>();
        for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
        int dp[] = new int[n], nodes[] = new int[n];
        Arrays.fill(nodes, 1);
        for (int[] edge : edges) {
            graph.get(edge[0]).add(edge[1]);
            graph.get(edge[1]).add(edge[0]);
        }
        post(0, -1, dp, nodes, graph);
        pre(0, -1, dp, nodes, graph, n);
        return dp;
    }
    
    private void post(int root, int parent, int[] dp, int[] nodes, List<List<Integer>> graph) {
        for (int child : graph.get(root)) {
            if (child == parent) continue;
            post(child, root, dp, nodes, graph);
            nodes[root] += nodes[child];
            dp[root] += dp[child] + nodes[child];
        }
    }
    
    private void pre(int root, int parent, int[] dp, int[] nodes, List<List<Integer>> graph, int n) {
        for (int child : graph.get(root)) {
            if (child == parent) continue;
            dp[child] = dp[root] + n - (nodes[child] << 1);
            pre(child, root, dp, nodes, graph, n);
        }
    }
}

Python:

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        graph, dp, nodes = [list() for _ in range(n)], [0] * n, [1] * n
        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)
            
        def post(root: int, parent: int) -> None:
            for child in graph[root]:
                if child == parent:
                    continue
                post(child, root)
                nodes[root] += nodes[child]
                dp[root] += dp[child] + nodes[child]
                
        def pre(root: int, parent: int) -> None:
            for child in graph[root]:
                if child == parent:
                    continue
                dp[child] = dp[root] + n - (nodes[child] << 1)
                pre(child, root)
                
        post(0, -1)
        pre(0, -1)
        return dp
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容