-egoist-'s blog

By -egoist-, 3 months ago, In English

Problem: Xor Sum

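Given an integer $$$N$$$, count the pairs $$$(u, v)$$$ with $$$0 \le u, v \le N$$$ for which there exist non-negative integers $$$a, b$$$ such that $$$a \oplus b = u$$$ and $$$a + b = v$$$. Different $$$(a, b)$$$ that produce the same $$$(u, v)$$$ are counted only once, and the answer is taken modulo $$$10^9 + 7$$$.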
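For small $$$N$$$ the definition can be checked directly. The sketch below is a hypothetical brute-force helper `count_pairs_brute` (not part of the original solution): since every pair comes from some $$$(a, b)$$$ with $$$a + b = v \le N$$$, it enumerates all such $$$(a, b)$$$ and counts the distinct results $$$(a \oplus b, a + b)$$$. It takes $$$O(N^2 \log N)$$$ time, so it is only meant for cross-checking the fast solutions on tiny inputs.

```cpp
#include <cassert>
#include <set>
#include <utility>

// Hypothetical brute-force reference, for small n only.
// u = a ^ b never exceeds v = a + b, so enumerating a + b <= n covers every pair
// with 0 <= u, v <= n; the set removes duplicates produced by different (a, b).
long long count_pairs_brute(long long n) {
    std::set<std::pair<long long, long long>> seen;
    for (long long a = 0; a <= n; a++)
        for (long long b = 0; a + b <= n; b++)
            seen.insert({a ^ b, a + b});
    return static_cast<long long>(seen.size());
}
```

For example, for $$$N = 3$$$ the distinct pairs are $$$(0,0), (1,1), (0,2), (2,2), (3,3)$$$, so the count is $$$5$$$.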

The Core Idea: bitwise XOR and arithmetic addition are linked by the identity $$$a + b = (a \oplus b) + 2(a \text{ AND } b)$$$. Since $$$u$$$ and $$$v$$$ are both determined by the same bits $$$(a_i, b_i)$$$, we can count the pairs with a digit DP that processes bits from the most significant to the least significant.

At each bit $$$i$$$ only the bit-sum $$$s_i = a_i + b_i \in \{0, 1, 2\}$$$ matters: $$$(0,1)$$$ and $$$(1,0)$$$ contribute the same to both $$$u$$$ and $$$v$$$. Bit $$$i$$$ of $$$u$$$ is $$$1$$$ exactly when $$$s_i = 1$$$, and $$$v = \sum_i s_i 2^i$$$. Distinct digit strings $$$(s_i)$$$ give distinct pairs $$$(u, v)$$$, because the positions with $$$s_i = 2$$$ are exactly the set bits of $$$(v - u)/2 = a \text{ AND } b$$$; so counting pairs is the same as counting digit strings.

To enforce $$$v \le N$$$, the DP keeps a "buffer" state $$$j$$$: the value of the prefix of $$$N$$$ processed so far minus the value of the corresponding prefix of $$$v$$$. Appending bit $$$N_i$$$ and digit $$$s$$$ gives $$$j' = 2j + N_i - s$$$. Once $$$j < 0$$$ it stays negative (since $$$2j + 1 \le -1$$$), so such states are dropped; once $$$j \ge 2$$$ it stays at least $$$2$$$ (since $$$2 \cdot 2 + 0 - 2 = 2$$$), so it can be capped: $$$next_j = \min(2j + N_i - s, 2)$$$. Finally, $$$a + b \ge a \oplus b$$$ means $$$u \le v$$$, so satisfying $$$v \le N$$$ automatically gives $$$u \le N$$$, and the answer is the total count over the states $$$j \in \{0, 1, 2\}$$$ after the last bit.
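Both facts used above, the identity and the claim that the digit string $$$(s_i)$$$ alone determines $$$(u, v)$$$, can be checked mechanically. A minimal sketch with the hypothetical helper name `bit_sums_match`, assuming $$$a, b < 2^{62}$$$ so that every bit is covered and $$$a + b$$$ fits in an unsigned 64-bit integer:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns true if, for this (a, b),
//   1) a + b == (a ^ b) + 2 * (a & b), and
//   2) rebuilding u and v from the bit-sums s_i = a_i + b_i yields a ^ b and a + b.
// Assumes a, b < 2^62 (bits 0..61 are scanned, and a + b cannot overflow).
bool bit_sums_match(uint64_t a, uint64_t b) {
    if (a + b != (a ^ b) + 2 * (a & b)) return false;
    uint64_t u = 0, v = 0;
    for (int i = 0; i < 62; i++) {
        uint64_t s = ((a >> i) & 1) + ((b >> i) & 1);  // s_i in {0, 1, 2}
        if (s == 1) u |= 1ULL << i;                     // (0,1) and (1,0) both set XOR bit i
        v += s << i;                                    // digit s_i with weight 2^i
    }
    return u == (a ^ b) && v == a + b;
}
```

For instance, `assert(bit_sums_match(5, 3));` holds: the bit-sums of $$$5 = (101)_2$$$ and $$$3 = (011)_2$$$ are $$$(s_2, s_1, s_0) = (1, 1, 2)$$$, giving $$$u = (110)_2 = 6$$$ and $$$v = 4 + 2 + 2 = 8$$$.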

Code:

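#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

const int MOD = 1e9 + 7;

void soln() {
    long long n; cin >> n;   // N may not fit in 32 bits: bits up to 60 are scanned
    // dp[diff] = number of digit prefixes with (prefix of N) - (prefix of v) == diff, capped at 2
    vector<long long> dp(3, 0);
    dp[0] = 1;
    for (int i = 60; i >= 0; i--) {
        vector<long long> dp1(3, 0);
        int ni = (n >> i) & 1;                    // bit i of N
        for (int diff = 0; diff < 3; diff++) {
            if (dp[diff] == 0) continue;
            for (int sum = 0; sum < 3; sum++) {   // sum = a_i + b_i
                int nxt_diff = 2 * diff + ni - sum;
                if (nxt_diff < 0) continue;       // v would exceed N: dead state
                nxt_diff = min(nxt_diff, 2);      // once diff >= 2 it stays >= 2
                dp1[nxt_diff] = (dp1[nxt_diff] + dp[diff]) % MOD;
            }
        }
        dp = dp1;
    }
    cout << (dp[0] + dp[1] + dp[2]) % MOD << endl;
}

int main() {
    soln();
}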
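To see what the states hold, here is a hand trace of the loop above for $$$N = 3 = (11)_2$$$, writing the vector as $$$(dp[0], dp[1], dp[2])$$$ after each bit:

```latex
% Hand trace of the digit DP for N = 3 = (11)_2
\begin{array}{c|c|c}
\text{bit } i & N_i & (dp[0], dp[1], dp[2]) \text{ after bit } i \\ \hline
\text{start} & - & (1, 0, 0) \\
60, \dots, 2 & 0 & (1, 0, 0) \quad (\text{only } s = 0 \text{ keeps } j \ge 0) \\
1 & 1 & (1, 1, 0) \\
0 & 1 & (1, 2, 2)
\end{array}
\qquad \text{answer} = 1 + 2 + 2 = 5
```

The five counted digit strings $$$(s_1, s_0)$$$ are $$$(0,0), (0,1), (1,0), (0,2), (1,1)$$$, i.e. the pairs $$$(0,0), (1,1), (2,2), (0,2), (3,3)$$$. The first one ends with $$$N - v = 3$$$, which the cap stores in $$$dp[2]$$$.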

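Note: a reachable pair also satisfies $$$v \ge u$$$ and $$$v - u = 2(a \text{ AND } b)$$$, so $$$v - u$$$ is even. The DP never has to check this separately: every counted pair is built bit by bit from an actual choice of $$$(a_i, b_i)$$$, so the condition holds automatically.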
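The conditions behind this note can also be written as an explicit test. As a sketch (the helper names `reachable` and `count_pairs_by_rule` are hypothetical): a pair is reachable exactly when $$$v \ge u$$$, $$$v - u$$$ is even, and $$$c = (v - u)/2$$$, which would have to equal $$$a \text{ AND } b$$$, shares no set bit with $$$u = a \oplus b$$$. In that case $$$a = c \mid u$$$, $$$b = c$$$ is a witness. Counting with this test takes $$$O(N^2)$$$ time, so it only serves as an independent check of the DP on small $$$N$$$.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: (u, v) is reachable iff v >= u, v - u is even, and
// c = (v - u) / 2 (the would-be a AND b) has no bit in common with u (the XOR).
// Witness when it holds: a = c | u, b = c, so a ^ b = u and a + b = 2c + u = v.
bool reachable(uint64_t u, uint64_t v) {
    if (v < u || (v - u) % 2 != 0) return false;
    uint64_t c = (v - u) / 2;
    return (c & u) == 0;
}

// Counts pairs with 0 <= u, v <= n by testing every candidate (u > v is never
// reachable, so u only runs up to v). O(n^2): small n only.
long long count_pairs_by_rule(long long n) {
    long long cnt = 0;
    for (long long v = 0; v <= n; v++)
        for (long long u = 0; u <= v; u++)
            if (reachable(u, v)) cnt++;
    return cnt;
}
```

For example, $$$(1, 3)$$$ is rejected: $$$c = 1$$$ overlaps $$$u = 1$$$, and indeed no $$$a, b$$$ have $$$a \oplus b = 1$$$ and $$$a + b = 3$$$.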

Recursive version (memoized on $$$n$$$):

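#include <map>
using namespace std;

const int MOD = 1e9 + 7;
map<long long, int> mp;   // memo: n -> number of pairs with v <= n

int dp(long long n) {
    if (n == 0) return 1;   // only (0, 0)
    if (n == 1) return 2;   // (0, 0), (1, 1); C++ would round (1 - 2) / 2 up to 0
    auto it = mp.find(n);
    if (it != mp.end()) return it->second;
    // split on the lowest bit-sum s in {0, 1, 2}: v = 2v' + s, so v' <= (n - s) / 2
    return mp[n] = (1LL * dp((n - 2) / 2) + dp((n - 1) / 2) + dp(n / 2)) % MOD;
}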
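The recurrence is not derived in the post. One way to see it (a sketch, writing $$$f(n)$$$ for the value returned by dp(n)) is to split on the lowest bit-sum $$$s_0$$$ instead of the highest:

```latex
% f(n) = number of digit strings (s_0, s_1, \dots), s_i \in \{0,1,2\}, with v = \sum_i s_i 2^i \le n;
% these strings are in one-to-one correspondence with the pairs (u, v), v \le n.
% Let v' be the value of the remaining string (s_1, s_2, \dots), so v = 2v' + s_0.
\begin{aligned}
v = 2v' + s_0 \le n &\iff v' \le \left\lfloor \frac{n - s_0}{2} \right\rfloor \qquad (n \ge s_0), \\
f(n) &= f\left(\left\lfloor \frac{n}{2} \right\rfloor\right) + f\left(\left\lfloor \frac{n-1}{2} \right\rfloor\right) + f\left(\left\lfloor \frac{n-2}{2} \right\rfloor\right) \qquad (n \ge 2), \\
f(0) &= 1, \qquad f(1) = 2 .
\end{aligned}
```

For $$$n = 1$$$ the $$$s_0 = 2$$$ term must be absent, but C++ truncates $$$(1 - 2)/2$$$ to $$$0$$$ and would wrongly add $$$f(0)$$$, which is why $$$n = 1$$$ is a base case. At depth $$$d$$$ every argument has the form $$$\lfloor (n - t)/2^d \rfloor$$$ for a $$$t$$$ below $$$2^{d+1}$$$, so there are at most three distinct arguments per level and the memo stays $$$O(\log n)$$$ in size. As a quick check, $$$f(3) = f(1) + f(1) + f(0) = 5$$$.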
