Prereq: [this digit dp blog](https://mirror.codeforces.com/blog/entry/53960)↵
↵
Cut the flag dimension↵
-------------------------↵
↵
Usually, whatever states you use in the recursive dp function, you will memoize it. And often you will have some thing like this ↵
↵
~~~~~↵
int memo[pos][...][...][low]↵
~~~~~↵
↵
Where `low` is the flag that checks if the current number is already smaller than the considered number.↵
↵
It is totally possible to subtract this dimension (half the memory needed) by manipulating it in the recursive function:↵
↵
Example problem: [Perfect Number](https://mirror.codeforces.com/contest/919/problem/B)↵
↵
This is what "normal" code would look like:↵
↵
<spoiler summary="normal code">↵
↵
~~~~~↵
ll mem[20][11][2];↵
↵
ll dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == n) return (sum == 10);↵
↵
ll& res = mem[pos][sum][lo];↵
if(res != -1) return res;↵
↵
res = 0;↵
↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
res += dp(pos + 1, sum + d, lo || (d < mx));↵
↵
return res;↵
}↵
~~~~~↵
↵
</spoiler>↵
↵
And this is the optimized code:↵
↵
<spoiler summary="optimized code">↵
~~~~~↵
ll mem[20][11];↵
↵
ll dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == n) return (sum == 10);↵
↵
ll& res = mem[pos][sum];↵
if(lo && res != -1) return res;↵
↵
ll ans = 0;↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
ans += dp(pos + 1, sum + d, lo || (d < mx));↵
↵
return lo ? res = ans : ans;↵
}↵
~~~~~↵
</spoiler>↵
↵
Basically this trick only store the number if the low flag is on, since low isn't necessary to be memoized because its only meaning is to set the limit for the current digit.↵
↵
_Full submission for the "normal" code: ["normal" code](https://mirror.codeforces.com/contest/919/submission/130424609)_↵
__↵
_Full submission for the optimized code: [optimized code](https://mirror.codeforces.com/contest/919/submission/130424695)_↵
↵
---↵
↵
Different ways to memset↵
---------------------------↵
↵
### "Normal" memset↵
↵
You memset every time you dp. This would takes a huge amount of time if you have to call dp many time or the memory is large. ↵
Example problem: [LIDS](https://toph.co/p/lids)↵
↵
Example slow code:↵
↵
<spoiler summary="slow code">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵
#define int long long↵
↵
// the maximum length of LIDS is 10↵
// so we can check for each length k,↵
// in how many ways can make a number with LIDS = k↵
↵
// then we can print the result we found for the maximum k↵
↵
int a, b;↵
vector<int> v;↵
↵
int mem[11][11][2][2][11];↵
↵
int dp(int pos, int last, bool small, bool nonzero, int need) {↵
if(pos == (int)v.size())↵
return need == 0;↵
↵
if(mem[pos][last + 1][small][nonzero][need] != -1)↵
return mem[pos][last + 1][small][nonzero][need];↵
↵
int res = 0; // res is the result for dp(pos, last + 1, small, nonzero, need)↵
int mx = small ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++) {↵
res += dp(pos + 1, last, (d < mx) || small, nonzero || d, need);↵
if(d > last && need && (nonzero || d))↵
res += dp(pos + 1, d, (d < mx) || small, 1, need - 1);↵
}↵
↵
return mem[pos][last + 1][small][nonzero][need] = res;↵
}↵
↵
void convert(int x) {↵
// convert into array of digit↵
v.clear();↵
while(x) {↵
v.push_back(x % 10);↵
x /= 10;↵
}↵
reverse(v.begin(), v.end());↵
}↵
↵
pair<int, int> solve(int st, int en) {↵
vector<int> lids(10);↵
↵
convert(en);↵
for(int i = 1; i < 10; i++) {↵
memset(mem, -1, sizeof mem);↵
lids[i] += dp(0, -1, 0, 0, i);↵
}↵
↵
convert(st - 1);↵
for(int i = 1; i < 10; i++) {↵
memset(mem, -1, sizeof mem);↵
lids[i] -= dp(0, -1, 0, 0, i);↵
}↵
↵
for(int i = 10 - 1; i >= 1; i--)↵
if(lids[i]) return {i, lids[i]};↵
↵
return {0, 1};↵
}↵
↵
signed main() {↵
ios::sync_with_stdio(0);↵
cin.tie(0);↵
↵
int T;↵
cin >> T;↵
↵
for(int tc = 1; tc <= T; tc++) {↵
cin >> a >> b;↵
pair<int, int> res = solve(a, b);↵
cout << "Case " << tc << ": " << res.first << ' ' << res.second;↵
if(tc != T) cout << "\n";↵
}↵
}↵
~~~~~↵
↵
</spoiler>↵
↵
You can clearly see that memset is executed many times for all digits, this would have complexity $\mathcal{O}(T * digits * memsize)$ where $T$ is the number of testcases, and $digits$ is the number of digits (from 0 to 9 in this case), and $memsize$ is the size of memory.↵
↵
This is extremely slow.↵
↵
### Improvement using "time"↵
↵
Now instead of memset every time you dp, you can keep an additional array `vis[pos][...][...]` which will store the "time" that the value in `mem[pos][...][...]` is set.↵
↵
<spoiler summary="code">↵
~~~~~↵
...↵
if(vis[pos][last + 1][small][nonzero][need] == cur)↵
return mem[pos][last + 1][small][nonzero][need];↵
vis[pos][last + 1][small][nonzero][need] = cur;↵
...↵
~~~~~↵
</spoiler>↵
↵
This way, the complexity is better but this is still too slow for many problems.↵
↵
### memset only once↵
↵
You might wonder, "but how? you are doing dp many times on many different numbers!". Well actually, we are doing dp on the **digits**.↵
↵
You might notice that we're always doing dp from the most significant digit to the least, usually from left to right, the most significant digit will be at position $0$ and the least at position $length - 1$.↵
↵
This way, the memory for each number is different, like number $100$ will have different memory from number $1234$ since they have different $length$ and other states.↵
↵
However, what if we let the most significant digit to be at position $length - 1$ and the least at position $0$?↵
↵
Now, every digits of every number line up, and you only need to memset once only.↵
↵
Example solution of: [Perfect Number](https://mirror.codeforces.com/contest/919/problem/B)↵
↵
<spoiler summary="code">↵
~~~~~↵
int dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == -1) return (sum == 10);↵
↵
int& res = mem[pos][sum];↵
if(lo && res != -1) return res;↵
↵
int ans = 0;↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
ans += dp(pos - 1, sum + d, lo || (d < mx));↵
↵
return lo ? res = ans : ans;↵
}↵
↵
int solve(int x) {↵
v.clear();↵
while(x) {↵
v.push_back(x % 10);↵
x /= 10;↵
}↵
// notice that I don't reverse the number anymore↵
↵
// start from length - 1↵
return dp((int)v.size() - 1, 0, 0);↵
}↵
↵
int main() {↵
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);↵
↵
int k;↵
cin >> k;↵
↵
int l = 1;↵
int r = 2e7;↵
↵
memset(mem, -1, sizeof mem);↵
while(l < r - 1) {↵
...↵
~~~~~↵
</spoiler>↵
↵
**This is an extremely important optimization for digit dp**↵
↵
## Other optimizations problem-wise↵
↵
### Check sum of digits divisibility↵
↵
##### _For a single number_↵
↵
If you want to check if the sum of digits of a number is divisible by $D$. Instead of storing the whole sum(could lead to MLE), you can store only the remainder of the sum when divided by $D$.↵
↵
##### _For many numbers_↵
↵
Example problem: [WORKCHEF](https://www.codechef.com/problems/WORKCHEF) (highly recommended, you will need to use a lot of optimizations to AC)↵
↵
For many numbers, instead of having a state for the remainder for each number, eg: `dp[...][rem2][rem3][...]` you can store the remainder of their LCM, eg: checking sum of digits divisible by 1, 2, 3, ... , 9 -> check divisibility by $LCM(1, 2, ..., 9) = 2520$.↵
↵
##### _For numbers with special properties_↵
↵
If you want to check divisibility by 5, the last digit need to be 0 or 5.↵
For 10, the last digit obviously must be 0.↵
...↵
There are also many properties for different numbers.↵
↵
### Another way of digit dp↵
↵
From [this stackoverflow question](//http://stackoverflow.com/questions/22394257/how-to-count-integers-between-large-a-and-b-with-a-certain-property/22394258#22394258)↵
↵
This can be very handy when handling problems relating to the structure of the numbers, eg: [Palindromic Numbers](https://vjudge.net/problem/LightOJ-1205)↵
↵
Example code:↵
↵
<spoiler summary="code">↵
~~~~~↵
// i - position, l - leftmostlower, h - leftmosthigher, ze - numbers of leading zeros↵
ll dp(int i, int l, int h, int ze) {↵
// imagine it as n - i - 1, and plus the offset of leading zeros↵
// i is already offset by leading zeros↵
int j = n - i - 1 + ze;↵
↵
if(i > j) return l <= h;↵
if(vis[i][l][h][ze] == cur) return mem[i][l][h][ze];↵
vis[i][l][h][ze] = cur;↵
↵
ll res = 0;↵
for(int d = 0; d <= 9; d++) {↵
int nl = l;↵
int nh = h;↵
↵
if(d < v[i] && i < nl) nl = i;↵
if(d < v[j] && j < nl) nl = j;↵
if(d > v[i] && i < nh) nh = i;↵
if(d > v[j] && j < nh) nh = j;↵
↵
res += dp(i + 1, nl, nh, ze + (i == ze && d == 0));↵
}↵
↵
return mem[i][l][h][ze] = res;↵
}↵
~~~~~↵
</spoiler>↵
↵
---↵
↵
Feel free to share any tricks or anything that people should know when doing digit dp!↵
If there is any mistakes or suggestions, please let me know.
↵
Cut the flag dimension↵
-------------------------↵
↵
Usually, whatever states you use in the recursive dp function, you will memoize it. And often you will have some thing like this ↵
↵
~~~~~↵
int memo[pos][...][...][low]↵
~~~~~↵
↵
Where `low` is the flag that checks if the current number is already smaller than the considered number.↵
↵
It is totally possible to subtract this dimension (half the memory needed) by manipulating it in the recursive function:↵
↵
Example problem: [Perfect Number](https://mirror.codeforces.com/contest/919/problem/B)↵
↵
This is what "normal" code would look like:↵
↵
<spoiler summary="normal code">↵
↵
~~~~~↵
ll mem[20][11][2];↵
↵
ll dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == n) return (sum == 10);↵
↵
ll& res = mem[pos][sum][lo];↵
if(res != -1) return res;↵
↵
res = 0;↵
↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
res += dp(pos + 1, sum + d, lo || (d < mx));↵
↵
return res;↵
}↵
~~~~~↵
↵
</spoiler>↵
↵
And this is the optimized code:↵
↵
<spoiler summary="optimized code">↵
~~~~~↵
ll mem[20][11];↵
↵
ll dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == n) return (sum == 10);↵
↵
ll& res = mem[pos][sum];↵
if(lo && res != -1) return res;↵
↵
ll ans = 0;↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
ans += dp(pos + 1, sum + d, lo || (d < mx));↵
↵
return lo ? res = ans : ans;↵
}↵
~~~~~↵
</spoiler>↵
↵
Basically this trick only store the number if the low flag is on, since low isn't necessary to be memoized because its only meaning is to set the limit for the current digit.↵
↵
_Full submission for the "normal" code: ["normal" code](https://mirror.codeforces.com/contest/919/submission/130424609)_↵
__↵
_Full submission for the optimized code: [optimized code](https://mirror.codeforces.com/contest/919/submission/130424695)_↵
↵
---↵
↵
Different ways to memset↵
---------------------------↵
↵
### "Normal" memset↵
↵
You memset every time you dp. This would takes a huge amount of time if you have to call dp many time or the memory is large. ↵
Example problem: [LIDS](https://toph.co/p/lids)↵
↵
Example slow code:↵
↵
<spoiler summary="slow code">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵
#define int long long↵
↵
// the maximum length of LIDS is 10↵
// so we can check for each length k,↵
// in how many ways can make a number with LIDS = k↵
↵
// then we can print the result we found for the maximum k↵
↵
int a, b;↵
vector<int> v;↵
↵
int mem[11][11][2][2][11];↵
↵
int dp(int pos, int last, bool small, bool nonzero, int need) {↵
if(pos == (int)v.size())↵
return need == 0;↵
↵
if(mem[pos][last + 1][small][nonzero][need] != -1)↵
return mem[pos][last + 1][small][nonzero][need];↵
↵
int res = 0; // res is the result for dp(pos, last + 1, small, nonzero, need)↵
int mx = small ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++) {↵
res += dp(pos + 1, last, (d < mx) || small, nonzero || d, need);↵
if(d > last && need && (nonzero || d))↵
res += dp(pos + 1, d, (d < mx) || small, 1, need - 1);↵
}↵
↵
return mem[pos][last + 1][small][nonzero][need] = res;↵
}↵
↵
void convert(int x) {↵
// convert into array of digit↵
v.clear();↵
while(x) {↵
v.push_back(x % 10);↵
x /= 10;↵
}↵
reverse(v.begin(), v.end());↵
}↵
↵
pair<int, int> solve(int st, int en) {↵
vector<int> lids(10);↵
↵
convert(en);↵
for(int i = 1; i < 10; i++) {↵
memset(mem, -1, sizeof mem);↵
lids[i] += dp(0, -1, 0, 0, i);↵
}↵
↵
convert(st - 1);↵
for(int i = 1; i < 10; i++) {↵
memset(mem, -1, sizeof mem);↵
lids[i] -= dp(0, -1, 0, 0, i);↵
}↵
↵
for(int i = 10 - 1; i >= 1; i--)↵
if(lids[i]) return {i, lids[i]};↵
↵
return {0, 1};↵
}↵
↵
signed main() {↵
ios::sync_with_stdio(0);↵
cin.tie(0);↵
↵
int T;↵
cin >> T;↵
↵
for(int tc = 1; tc <= T; tc++) {↵
cin >> a >> b;↵
pair<int, int> res = solve(a, b);↵
cout << "Case " << tc << ": " << res.first << ' ' << res.second;↵
if(tc != T) cout << "\n";↵
}↵
}↵
~~~~~↵
↵
</spoiler>↵
↵
You can clearly see that memset is executed many times for all digits, this would have complexity $\mathcal{O}(T * digits * memsize)$ where $T$ is the number of testcases, and $digits$ is the number of digits (from 0 to 9 in this case), and $memsize$ is the size of memory.↵
↵
This is extremely slow.↵
↵
### Improvement using "time"↵
↵
Now instead of memset every time you dp, you can keep an additional array `vis[pos][...][...]` which will store the "time" that the value in `mem[pos][...][...]` is set.↵
↵
<spoiler summary="code">↵
~~~~~↵
...↵
if(vis[pos][last + 1][small][nonzero][need] == cur)↵
return mem[pos][last + 1][small][nonzero][need];↵
vis[pos][last + 1][small][nonzero][need] = cur;↵
...↵
~~~~~↵
</spoiler>↵
↵
This way, the complexity is better but this is still too slow for many problems.↵
↵
### memset only once↵
↵
You might wonder, "but how? you are doing dp many times on many different numbers!". Well actually, we are doing dp on the **digits**.↵
↵
You might notice that we're always doing dp from the most significant digit to the least, usually from left to right, the most significant digit will be at position $0$ and the least at position $length - 1$.↵
↵
This way, the memory for each number is different, like number $100$ will have different memory from number $1234$ since they have different $length$ and other states.↵
↵
However, what if we let the most significant digit to be at position $length - 1$ and the least at position $0$?↵
↵
Now, every digits of every number line up, and you only need to memset once only.↵
↵
Example solution of: [Perfect Number](https://mirror.codeforces.com/contest/919/problem/B)↵
↵
<spoiler summary="code">↵
~~~~~↵
int dp(int pos, int sum, bool lo) {↵
if(sum > 10) return 0;↵
if(pos == -1) return (sum == 10);↵
↵
int& res = mem[pos][sum];↵
if(lo && res != -1) return res;↵
↵
int ans = 0;↵
int mx = lo ? 9 : v[pos];↵
for(int d = 0; d <= mx; d++)↵
ans += dp(pos - 1, sum + d, lo || (d < mx));↵
↵
return lo ? res = ans : ans;↵
}↵
↵
int solve(int x) {↵
v.clear();↵
while(x) {↵
v.push_back(x % 10);↵
x /= 10;↵
}↵
// notice that I don't reverse the number anymore↵
↵
// start from length - 1↵
return dp((int)v.size() - 1, 0, 0);↵
}↵
↵
int main() {↵
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);↵
↵
int k;↵
cin >> k;↵
↵
int l = 1;↵
int r = 2e7;↵
↵
memset(mem, -1, sizeof mem);↵
while(l < r - 1) {↵
...↵
~~~~~↵
</spoiler>↵
↵
**This is an extremely important optimization for digit dp**↵
↵
## Other optimizations problem-wise↵
↵
### Check sum of digits divisibility↵
↵
##### _For a single number_↵
↵
If you want to check if the sum of digits of a number is divisible by $D$. Instead of storing the whole sum(could lead to MLE), you can store only the remainder of the sum when divided by $D$.↵
↵
##### _For many numbers_↵
↵
Example problem: [WORKCHEF](https://www.codechef.com/problems/WORKCHEF) (highly recommended, you will need to use a lot of optimizations to AC)↵
↵
For many numbers, instead of having a state for the remainder for each number, eg: `dp[...][rem2][rem3][...]` you can store the remainder of their LCM, eg: checking sum of digits divisible by 1, 2, 3, ... , 9 -> check divisibility by $LCM(1, 2, ..., 9) = 2520$.↵
↵
##### _For numbers with special properties_↵
↵
If you want to check divisibility by 5, the last digit need to be 0 or 5.↵
For 10, the last digit obviously must be 0.↵
...↵
There are also many properties for different numbers.↵
↵
### Another way of digit dp↵
↵
From [this stackoverflow question](//http://stackoverflow.com/questions/22394257/how-to-count-integers-between-large-a-and-b-with-a-certain-property/22394258#22394258)↵
↵
This can be very handy when handling problems relating to the structure of the numbers, eg: [Palindromic Numbers](https://vjudge.net/problem/LightOJ-1205)↵
↵
Example code:↵
↵
<spoiler summary="code">↵
~~~~~↵
// i - position, l - leftmostlower, h - leftmosthigher, ze - numbers of leading zeros↵
ll dp(int i, int l, int h, int ze) {↵
// imagine it as n - i - 1, and plus the offset of leading zeros↵
// i is already offset by leading zeros↵
int j = n - i - 1 + ze;↵
↵
if(i > j) return l <= h;↵
if(vis[i][l][h][ze] == cur) return mem[i][l][h][ze];↵
vis[i][l][h][ze] = cur;↵
↵
ll res = 0;↵
for(int d = 0; d <= 9; d++) {↵
int nl = l;↵
int nh = h;↵
↵
if(d < v[i] && i < nl) nl = i;↵
if(d < v[j] && j < nl) nl = j;↵
if(d > v[i] && i < nh) nh = i;↵
if(d > v[j] && j < nh) nh = j;↵
↵
res += dp(i + 1, nl, nh, ze + (i == ze && d == 0));↵
}↵
↵
return mem[i][l][h][ze] = res;↵
}↵
~~~~~↵
</spoiler>↵
↵
---↵
↵
Feel free to share any tricks or anything that people should know when doing digit dp!↵
If there is any mistakes or suggestions, please let me know.