#include <bits/stdc++.h>
using ll = long long;
using std::cin;
using std::cout;
using std::vector;
using std::array;
using std::pair;
using std::ranges::min;
// Reference (slow but straightforward) solution used as the oracle for my_sol.
//
// Counts pairs (i, e), 0 <= i <= e < n, such that every a[p] with i <= p <= e
// agrees with a[i] on each bit b2 in [0, 60) for which 2^b2 > p - i + 1.
// In other words, the element at offset d = p - i must match a[i] on all bits
// above floor(log2(d + 1)). Validity is prefix-closed in e, so for each i the
// function finds the first failing position and counts every end before it.
//
// End positions are grouped into blocks: block b covers ends
// [i - 1 + 2^b, i - 1 + 2^(b+1)), i.e. subarray lengths in [2^b, 2^(b+1)).
// All of them are constrained on the same bits, b + 1 .. 59.
//
// Memory: n + 1 arrays of 60 long longs (prefix bit counts).
long long correct_sol(long long n, const std::vector<long long>& a) {
    // pref[k][b] = how many of a[0 .. k-1] have bit b set.
    std::vector<std::array<long long, 60>> pref(n + 1);
    for (long long i = 0; i < n; ++i) {
        pref[i + 1] = pref[i];
        for (long long b = 0; b < 60; ++b)
            pref[i + 1][b] += (a[i] >> b) & 1;
    }
    long long ans = 0;
    for (long long i = n - 1; i >= 0; --i) {
        // 1LL keeps the shift 64-bit like the surrounding index arithmetic
        // (a plain `1 << b` is an int shift).
        for (long long b = 0; i + (1LL << b) - 1 < n; ++b) {
            // True iff every a[p], p in [L, R), agrees with a[i] on bits b+1 .. 59.
            auto check = [&](long long L, long long R) {
                for (long long b2 = b + 1; b2 < 60; ++b2) {
                    const long long ones = pref[R][b2] - pref[L][b2];
                    const long long want = ((a[i] >> b2) & 1) ? R - L : 0;
                    if (ones != want)
                        return false;
                }
                return true;
            };
            const long long L = i - 1 + (1LL << b);
            const long long R = std::min(n, L + (1LL << b));
            if (check(L, R)) {
                // The whole block is valid; move on to longer subarrays.
                ans += R - L;
                continue;
            }
            // Binary search for the first failing position in [L, R).
            // b >= 1 here (block 0 is a[i] alone and always passes), so
            // lo = L - 1 >= i. Position L - 1 is the last end of the previous
            // block and already passed a stricter bit test there.
            long long lo = L - 1, hi = R - 1;
            while (lo < hi) {
                const long long mid = (lo + hi) / 2;
                if (check(lo, mid + 1))
                    lo = mid + 1;
                else
                    hi = mid;
            }
            // hi is the first failing position: ends L .. hi - 1 are valid.
            ans += hi - L;
            break;
        }
    }
    return ans;
}
// Fast O(60 * n) solution; must return the same count as correct_sol.
//
// Scans i from right to left. For every bit j it maintains the run of
// consecutive elements, starting at a[i], that share a[i]'s value of bit j.
// correct_sol requires the element at offset d = p - i to match a[i] on bit j
// whenever 2^j > d + 1, i.e. on the first 2^j - 1 elements of the subarray.
// If the run for bit j is shorter than that, subarrays starting at i can end
// no later than position i + run - 1. Otherwise bit j imposes no limit.
long long my_sol(long long n, const std::vector<long long>& a) {
    // bits[j].first  : value of bit j in the element processed previously (a[i + 1]).
    // bits[j].second : length of the run of consecutive elements starting at the
    //                  current index that share that value of bit j.
    // The initial {0, 0} lets the first element start or extend a run uniformly.
    std::vector<std::pair<long long, long long>> bits(60);
    long long ans = 0;
    for (long long i = n - 1; i >= 0; --i) {
        long long min_right = n;  // exclusive bound on the end position
        for (long long j = 0; j < 60; ++j) {
            const bool cur_bit = (a[i] >> j) & 1;
            if (cur_bit == bits[j].first)
                ++bits[j].second;
            else
                bits[j] = {cur_bit, 1};
            // The bit must stay constant for 2^j - 1 elements (offsets
            // 0 .. 2^j - 2), hence the "- 1". Comparing against 2^j was the
            // original off-by-one.
            if (bits[j].second < (1LL << j) - 1)
                min_right = std::min(min_right, i + bits[j].second);
        }
        // Valid ends are i .. min_right - 1.
        ans += min_right - i;
    }
    return ans;
}
static std::mt19937_64 RNG(std::chrono::high_resolution_clock::now().time_since_epoch().count());
int main(){
std::freopen("output.txt", "w", stdout);
ll t;
cin >> t;
while (t--){
ll n = RNG() % 200000 + 1;
vector<ll> a(n);
for (ll& x : a)
x = RNG() % (1ll << 60);
ll ans1 = correct_sol(n, a);
ll ans2 = my_sol(n, a);
if (ans1 != ans2){
cout << n << '\n';
for (ll& x : a)
cout << x << ' ';
cout << "\nCorrect: " << ans1 << "\nMine: " << ans2;
return 0;
}
}
}
Fails on test case: 1 7 1 2 3 4 5 8 5
Expected: 13 Got: 9
PS:
Is this your code? It looks like poorly copied code, or a poorly copied idea from the editorial that was then coded up — especially the lines inside the `if` and the DP check.
You can refer to my code here. It uses the same idea as yours, but instead of checking manually I used a rangePattern sum (from this start to this end, the pattern must match for all bits) and DP to make it faster. You could also use SOS DP or a trie.
Good news: I fixed the bug (a missing "-1" — the run-length threshold in `my_sol` must be `(1ll << j) - 1`, not `1ll << j`) and it got Accepted!
PS: Why does my code look like it was copied?
Congratulations! I also got AC with your code just now XD. Apologies for the first PS; let's just say I had seen this pattern, `vector<pair<ll, ll>> bits(60);`, and a few more lines somewhere else (wink). People can have similar code blocks and names, but usually — what are the odds? Still, my bad for misjudging.
No problem.
By the way, in your code you used `#define int long long`, which invokes undefined behaviour.
Read "Reserved macro names" on this website (`int` is a keyword):
"A translation unit that uses any part of the standard library is not allowed to #define or #undef names lexically identical to ... keywords ... Otherwise, the behavior is undefined".