// Description
// Submission
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
    int ret = 0;  // number of good leaf pairs found so far
    int d = 0;    // maximum allowed leaf-to-leaf path length

    // Merge two ascending lists into a single ascending list.
    vector<int> merge(const vector<int>& v1, const vector<int>& v2) {
        vector<int> rets(v1.size() + v2.size());
        std::merge(v1.begin(), v1.end(), v2.begin(), v2.end(), rets.begin());
        return rets;
    }

    // Add 1 to every distance in an ascending list (moving the reference
    // point one edge up the tree) and drop distances that can no longer be
    // part of a good pair. A leaf already at distance >= d from the new
    // reference node would need a partner at distance <= 0 from it, which is
    // impossible. Because the list is sorted, those entries form a suffix.
    vector<int> incOne(vector<int> v) {
        for (int& x : v) ++x;
        v.erase(lower_bound(v.begin(), v.end(), d), v.end());
        return v;
    }

    // Returns the ascending list of distances from cur's parent to the leaves
    // in cur's subtree (pruned of distances that cannot contribute). As a side
    // effect, adds to `ret` every good pair whose lowest common ancestor is cur.
    // Precondition: cur != nullptr.
    vector<int> dfs(TreeNode* cur) {
        if (!cur->left && !cur->right) return {1};
        if (!cur->left) return incOne(dfs(cur->right));
        if (!cur->right) return incOne(dfs(cur->left));

        // Distances from cur to the leaves of each side.
        vector<int> left = dfs(cur->left);
        vector<int> right = dfs(cur->right);

        // For each left leaf at distance x, count right leaves at distance
        // <= d - x. Once x >= d no right leaf (distance >= 1) can qualify,
        // and since `left` is ascending we can stop.
        for (int x : left) {
            if (x >= d) break;
            auto it = upper_bound(right.begin(), right.end(), d - x);
            ret += static_cast<int>(it - right.begin());
        }
        return incOne(merge(left, right));
    }

public:
    // Counts unordered pairs of distinct leaves whose shortest path length is
    // <= distance. An empty tree (root == nullptr) has no pairs.
    int countPairs(TreeNode* root, int distance) {
        ret = 0;
        d = distance;
        if (root) dfs(root);
        return ret;
    }
};