I was solving this Problem Link and encountered a bug. I'm getting WA on test 10, but I run stress test locally on my PC and didn't found any counter test cases. Please take a look on my code.
Here are my code.
Test Case Generator
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 100;
int main(int argc, char* argv[]) {
ios::sync_with_stdio(false);
cin.tie(0);
mt19937_64 rnd(atoi(argv[1]));
auto next = [&](ll x) {
return rnd() % x;
};
auto randRange = [&](ll a, ll b) {
return a + rnd() % (b - a + 1);
};
int n = randRange(2, N);
int q = randRange(1, N);
cout << n << " " << q << "\n";
for(int i = 0; i < n; ++i)
cout << randRange(0, 2147483647) << " \n"[i == n - 1];
vector<int> a(1), b(n - 1);
iota(b.begin(), b.end(), 1);
random_shuffle(b.begin(), b.end(), next);
while(!b.empty()) {
int u = a[next((int) a.size())];
int v = b.back();
if(next(2) == 1)
swap(u, v);
cout << u + 1 << " " << v + 1 << "\n";
a.push_back(b.back());
b.pop_back();
}
for(int i = 0; i < q; ++i)
cout << randRange(1, n) << " " << randRange(1, n) << "\n";
return 0;
}
Brute force solution
#include <bits/stdc++.h>
using namespace std;
int main(int argc, char* argv[]) {
ios::sync_with_stdio(false);
cin.tie(0);
int n, q;
cin >> n >> q;
vector<int> a(n);
for(auto& x : a)
cin >> x;
vector<vector<int>> adj(n);
for(int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
--u, --v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<int> parent(n);
vector<int> depth(n);
function<void(int, int)> dfs = [&](int u, int p) {
parent[u] = p;
depth[u] = (u == 0 ? 0 : depth[p] + 1);
for(int& v : adj[u]) {
if(v == p)
continue;
dfs(v, u);
}
};
dfs(0, -1);
while(q--) {
int u, v;
cin >> u >> v;
--u, --v;
if(depth[u] < depth[v])
swap(u, v);
set<int> s;
s.insert(a[u]);
s.insert(a[v]);
while(depth[u] > depth[v]) {
u = parent[u];
s.insert(a[u]);
}
while(u != v) {
u = parent[u];
v = parent[v];
s.insert(a[u]);
s.insert(a[v]);
}
cout << s.size() << "\n";
}
return 0;
}
My submission
#include <bits/stdc++.h>
using namespace std;
int main(int argc, char* argv[]) {
ios::sync_with_stdio(false);
cin.tie(0);
int n, q;
cin >> n >> q;
vector<int> a(n);
for(auto& x : a)
cin >> x;
vector<int> b(a);
sort(b.begin(), b.end());
for(auto& x : a)
x = lower_bound(b.begin(), b.end(), x) - b.begin();
vector<vector<int>> adj(n);
for(int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
--u, --v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<vector<int>> dp(n, vector<int>(20));
vector<int> depth(n);
vector<int> in(n), out(n);
vector<int> order;
order.reserve(n * 2);
function<void(int, int)> dfs = [&](int u, int p) {
dp[u][0] = (u == 0 ? 0 : p);
depth[u] = (u == 0 ? 0 : depth[p] + 1);
in[u] = (int) order.size();
order.push_back(u);
for(int& v : adj[u])
if(v != p)
dfs(v, u);
out[u] = (int) order.size();
order.push_back(u);
};
dfs(0, -1);
for(int x = 1; x < 20; ++x)
for(int i = 0; i < n; ++i)
dp[i][x] = dp[dp[i][x - 1]][x - 1];
auto lift = [&](int u, int step) -> int {
int r = 0;
while(step) {
if(step & 1)
u = dp[u][r];
step >>= 1;
r <<= 1;
}
return u;
};
auto lca = [&](int x, int y) -> int {
if(depth[x] < depth[y])
swap(x, y);
x = lift(x, depth[x] - depth[y]);
if(x == y)
return x;
for(int i = 19; i >= 0; --i) {
int new_x = dp[x][i];
int new_y = dp[y][i];
if(new_x != new_y) {
x = new_x;
y = new_y;
}
}
assert(dp[x][0] == dp[y][0]);
return dp[x][0];
};
vector<tuple<int, int, int, int>> qry(q);
for(int i = 0; i < q; ++i) {
int u, v;
cin >> u >> v;
--u, --v;
if(depth[u] > depth[v])
swap(u, v);
int z = lca(u, v);
if(u == z)
qry[i] = {in[u], in[v], -1, i};
else
qry[i] = {out[u], in[v], z, i};
}
const int m = sqrt(n * 2) + 1;
sort(qry.begin(), qry.end(), [&](const tuple<int, int, int, int>& x, const tuple<int, int, int, int>& y) {
if(get<0>(x) / m == get<0>(y) / m)
return (get<0>(x) / m & 1 ? get<1>(x) > get<1>(y) : get<1>(x) < get<1>(y));
return get<0>(x) < get<0>(y);
});
vector<int> occurrence(n);
vector<int> cnt((int) b.size());
int counter = 0;
auto add = [&](int u) {
++occurrence[u];
if(occurrence[u] % 2 == 0) {
--cnt[a[u]];
if(cnt[a[u]] == 0)
--counter;
} else {
++cnt[a[u]];
if(cnt[a[u]] == 1)
++counter;
}
};
auto remove = [&](int u) {
--occurrence[u];
if(occurrence[u] % 2 == 0) {
--cnt[a[u]];
if(cnt[a[u]] == 0)
--counter;
} else {
++cnt[a[u]];
if(cnt[a[u]] == 1)
++counter;
}
};
vector<int> ans(q);
int l = 0, r = 0;
add(order[0]);
for(int i = 0; i < q; ++i) {
int L, R, z, id;
tie(L, R, z, id) = qry[i];
while(l < L) {
remove(order[l]);
++l;
}
while(l > L) {
--l;
add(order[l]);
}
while(r < R) {
++r;
add(order[r]);
}
while(r > R) {
remove(order[r]);
--r;
}
if(z != -1)
add(z);
ans[id] = counter;
if(z != -1)
remove(z);
}
for(auto& x : ans)
cout << x << "\n";
return 0;
}