题目来源:LeetCode 437. Path Sum III(路径总和 III)
给定一棵二叉树和整数 sum,求节点值之和等于 sum 的路径条数。路径必须自上而下(只能从父节点走向子节点),但不必从根节点开始,也不必在叶子节点结束。
我的第一版代码如下:DFS 时用一个 vector 记录所有以当前节点为终点的路径和。这些路径的起点是根到当前节点路径上的每一个节点(含当前节点本身)。然后逐一与 sum 比较。时间复杂度为 O(n·h),其中 h 为树高,最坏情况下为 O(n²)。
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
int res = 0;
int pathSum(TreeNode* root, int sum) {
if (!root)
return 0;
vector<int> nums;
dfs(root, sum, nums);
return res;
}
void dfs(TreeNode *node, const int &sum, vector<int> &nums)
{
nums.push_back(0);
for (int i=0; i<nums.size(); i++) {
nums[i] += node->val;
if (nums[i] == sum)
res++;
}
if (node->left) {
dfs(node->left, sum, nums);
nums.pop_back();
for (int i=0; i<nums.size(); i++)
nums[i] -= node->left->val;
}
if (node->right) {
dfs(node->right, sum, nums);
nums.pop_back();
for (int i=0; i<nums.size(); i++)
nums[i] -= node->right->val;
}
}
};
上面用了一个成员变量 res 来累加结果(严格来说不是全局变量),其实可以不用:让 dfs 直接返回以当前节点为根的子树中满足条件的路径条数,并在返回前自行恢复 vector 的状态。修改如下:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
int pathSum(TreeNode* root, int sum) {
if (!root)
return 0;
vector<int> nums;
return dfs(root, sum, nums);
}
int dfs(TreeNode *node, const int &sum, vector<int> &nums)
{
int res = 0;
if (!node)
return 0;
nums.push_back(0);
for (int i=0; i<nums.size(); i++) {
nums[i] += node->val;
if (nums[i] == sum)
res++;
}
res += dfs(node->left, sum, nums) + dfs(node->right, sum, nums);
nums.pop_back();
for (int i=0; i<nums.size(); i++)
nums[i] -= node->val;
return res;
}
};
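后来看了讨论区,发现可以把 vector 换成一个哈希表(unordered_map),记录当前根到节点路径上每个前缀和(从根节点开始累加)出现的次数。设当前前缀和为 cursum,只需查 cursum - sum 这个前缀和出现了几次,就得到以当前节点为终点、和为 sum 的路径条数;从根节点开始的整条路径由 cursum == sum 单独判断。时间复杂度从 O(n·h) 降为平均 O(n)。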
代码如下:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
int pathSum(TreeNode* root, int sum) {
unordered_map<int, int> map;
return dfs(root, sum, map, 0);
}
int dfs(TreeNode *node, const int &sum, unordered_map<int, int> &map, int pre)
{
if (!node)
return 0;
int res = 0;
int cursum = node->val;
cursum += pre;
res += (cursum == sum) + (map[cursum-sum] ? map[cursum-sum] : 0);
map[cursum]++;
res += dfs(node->left, sum, map, cursum) + dfs(node->right, sum, map, cursum);
map[cursum]--;
return res;
}
};