Incredibly beautiful DP optimization from N^3 to N log^2 N
Difference between en3 and en4, changed 3 character(s)
The task I want to discuss is [problem:739E]. While the official solution is a greedy algorithm sped up enough to pass the time limit, I recently came upon another solution. The main idea is to speed up the obvious dp approach, where we define dp[i][x][y] as the maximum expected number of caught pokemon in the prefix of first i pokemon, if we throw at most x A-pokeballs and at most y B-pokeballs. The computation of each state is O(1), so the complexity of this solution is O(n^3). ↵
[cut]↵
There is no obvious way to speed up this dp, because the transition of states is already done in O(1), and that's where dp optimization techniques usually cut the complexity. It's also useless to use some other definition of dp, since they will all take O(n^3) time to compute. But what we can do is to use the same trick used to solve the task Alien, from IOI 2016, or [problem:674C] in O(n log k) as [user:Radewoosh,2017-01-11] had described on his blog, and completely kick out a dimension from our dp!↵

Kicking out the 3rd dimension:↵
------------------------------↵
By kicking out the 3rd dimension, we're left with dp[i][x]. This is now defined as the highest expected number of caught pokemon in the prefix of i pokemon if we throw at most x A-pokeballs and any number of B-pokeballs. Obviously this will always use the maximum amount of B-pokeballs. But what's really cool is that we can actually try to simulate this last dimension: we define some C as a "cost" we have to pay every time we want to take a B-pokeball. This is essentially adding the functions f(x) = dp[n][a][x] and g(x) = -Cx. The cool thing is, f(x) is con
cavex, i.e. f(x+1) &mdash; f(x) <= f(x) &mdash; f(x-1). This is intuitive because whenever we get a new B-pokeball, we will always throw it at the best possible place. So if we get more and more of them, our expected number of caught pokemon will increase more and more slowly. And why is it useful that f(x) is convex? Well, h(x) = f(x) + g(x) has a non-trivial maximum, that we can find. And if h(x) is maximal, it means that for this C, it's optimal to throw x B-pokeballs. Now it's pretty obvious that we can do a binary search on this C to find one such that it's optimal to throw exactly b B-pokeballs, as given in the input. Inside our binary search we just do the O(n^2) algorithm, and when we finish, do a reconstruction of our solution to see how many B-pokeballs we've used, and use that information to continue binary searching. This gives us complexity O(n^2 log n), which is good enough to get AC. This trick was shown to us at our winter camp, which ended yesterday.↵

<spoiler summary="Code:">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵

const int maxn = 2020;↵
const double eps = 1e-8;↵
int n, a, b, opt[maxn][maxn];↵
double dp[maxn][maxn], pa[maxn], pb[maxn], pab[maxn];↵

int solve(double mid){↵
    for(int i = 1; i <= n; i++){↵
        for(int j = 0; j <= a; j++){↵
            double &d = dp[i][j];↵
            int &o = opt[i][j];↵

            d = dp[i - 1][j];↵
            o = 0;↵

            if(j && d < dp[i - 1][j - 1] + pa[i]){↵
                d = dp[i - 1][j - 1] + pa[i];↵
                o = 1;↵
            }↵

            if(d < dp[i - 1][j] + pb[i] - mid){↵
                d = dp[i - 1][j] + pb[i] - mid;↵
                o = 2;↵
            }↵

            if(j && d < dp[i - 1][j - 1] + pab[i] - mid){↵
                d = dp[i - 1][j - 1] + pab[i] - mid;↵
                o = 3;↵
            }↵
        }↵
    }↵

    int ret = 0, la = a;↵

    for(int i = n; i >= 1; i--){↵
        if(opt[i][la] > 1)↵
            ret++;↵

        if(opt[i][la] & 1)↵
            la--;↵
    }↵

    return ret;↵
}↵

int main(){↵
    ios_base::sync_with_stdio(false);↵

    cin >> n >> a >> b;↵

    for(int i = 1; i <= n; i++)↵
        cin >> pa[i];↵

    for(int i = 1; i <= n; i++)↵
        cin >> pb[i];↵

    for(int i = 1; i <= n; i++)↵
        pab[i] = pa[i] + pb[i] - pa[i] * pb[i];↵

    double lo = 0, hi = 1, mid;↵

    for(int it = 0; it < 50; it++){↵
        mid = (lo + hi) / 2;↵

        if(solve(mid) > b)↵
            lo = mid;↵
        else↵
            hi = mid;↵
    }↵

    int ans = solve(hi);↵

    cout << fixed << setprecision(10) << dp[n][a] + hi * b << endl;↵

    return 0;↵
}↵
~~~~~↵
</spoiler>↵

Kicking out another dimension?↵
------------------------------↵
But is this all? Can we do better? Why can't we kick out the 2nd dimension in the same way we kicked out the first one? It turns out that in this task, we actually can! We just define D as the cost that we deduct each time we use an A-pokeball, and then using binary search find the C for which we use exactly enough B-pokeballs, and reconstruct the solution to see if we've used too many or too little A-pokeballs. The function is again convex, so the same trick works! Using this I was able to get AC in O(n log^2 n), which is pretty amazing for a Div1 E task with N <= 2000. My friends [user:vilim_l,2017-01-11], [user:jklepec,2017-01-11], [user:lukatiger,2017-01-11] and me are still amazed that this can be done!↵


<spoiler summary="Code:">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵

typedef pair<int, int> pii;↵

const int maxn = 2020;↵
const double eps = 1e-8;↵
int n, a, b, opt[maxn];↵
double dp[maxn], pa[maxn], pb[maxn], pab[maxn];↵

pii solve(double &D, double &C){↵
    for(int i = 1; i <= n; i++){↵
        double &d = dp[i];↵
        int &o = opt[i];↵

        d = dp[i - 1];↵
        o = 0;↵

        if(d < dp[i - 1] + pa[i] - D){↵
            d = dp[i - 1] + pa[i] - D;↵
            o = 1;↵
        }↵

        if(d < dp[i - 1] + pb[i] - C){↵
            d = dp[i - 1] + pb[i] - C;↵
            o = 2;↵
        }↵

        if(d < dp[i - 1] + pab[i] - C - D){↵
            d = dp[i - 1] + pab[i] - C - D;↵
            o = 3;↵
        }↵
    }↵

    pii ret = pii(0, 0);↵

    for(int i = 1; i <= n; i++){↵
        if(opt[i] > 1)↵
            ret.second++;↵

        if(opt[i] & 1)↵
            ret.first++;↵
    }↵

    return ret;↵
}↵

int main(){↵
    ios_base::sync_with_stdio(false);↵

    cin >> n >> a >> b;↵

    for(int i = 1; i <= n; i++)↵
        cin >> pa[i];↵

    for(int i = 1; i <= n; i++)↵
        cin >> pb[i];↵

    for(int i = 1; i <= n; i++)↵
        pab[i] = pa[i] + pb[i] - pa[i] * pb[i];↵

    double lo = 0, hi = 1, mid, lo2, hi2, mid2;↵

    for(int it2 = 0; it2 < 50; it2++){↵
        mid = (lo + hi) / 2;↵

        lo2 = 0, hi2 = 1, mid2;↵

        for(int it = 0; it < 50; it++){↵
            mid2 = (lo2 + hi2) / 2;↵

            if(solve(mid, mid2).second > b)↵
                lo2 = mid2;↵
            else↵
                hi2 = mid2;↵
        }↵

        if(solve(mid, hi2).first > a)↵
            lo = mid;↵
        else↵
            hi = mid;↵
    }↵

    solve(hi, hi2);↵

    cout << fixed << setprecision(10) << dp[n] + hi2 * b + hi * a << endl;↵

    return 0;↵
}↵
~~~~~↵
</spoiler>↵

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
en5 English linkret 2017-01-11 06:14:55 3 Tiny change: ' again convex, so the s' -> ' again concave, so the s'
en4 English linkret 2017-01-11 06:07:22 3 Tiny change: '(x) is convex, i.e. f(x' -> '(x) is concave, i.e. f(x'
en3 English linkret 2017-01-11 06:06:42 1 Tiny change: 'sh; f(x) < f(x) &mda' -> 'sh; f(x) <= f(x) &mda'
en2 English linkret 2017-01-11 04:04:06 7
en1 English linkret 2017-01-11 02:48:38 7017 Initial revision (published)