Hello,
I tried the problem Segment Sum. The DP solution for this problem seemed pretty straightforward to me, so I decided to implement it in Python (you can find the Python code below).
I used memoization: the recursive function returns two numbers, i.e. my DP table maintains two values per state, the current sum and the number of ways that satisfy the constraints from the current state. After hours of debugging, I couldn't find any problem with my code, so I decided to remove
if dp[smaller][start][pos][mask][0] != -1:
    return dp[smaller][start][pos][mask]
to see if there is any issue with the DP table. Surprisingly, the program output correct results once I removed these two lines. It seems there is something wrong with returning tuples (or 2-element lists, as here) from a recursive function in Python.
To make sure that the method is correct, I reimplemented it in C++ and it got Accepted, as expected: 60403673. Could you please help me fix the issue in Python?
Python code:
import sys

mod = 998244353
MAX_LENGTH = 20
bound = [0] * MAX_LENGTH

def mul(a, b): return (a * b) % mod

def add(a, b):
    a += b
    if a < 0: a += mod
    if a >= mod: a -= mod
    return a

def digitize(num):
    # bound[i] = i-th decimal digit of num, least significant first
    for i in range(MAX_LENGTH):
        bound[i] = num % 10
        num //= 10

def rec(smaller, start, pos, mask):
    # returns [sum, ways] over the digits at positions pos..0
    global k
    if bit_count[mask] > k:
        return [0, 0]
    if pos == -1:
        return [0, 1]
    # if the two following lines are removed, the code returns correct results
    if dp[smaller][start][pos][mask][0] != -1:
        return dp[smaller][start][pos][mask]
    res_sum = res_ways = 0
    for digit in range(0, 10):
        if smaller == 0 and digit > bound[pos]:
            continue
        new_smaller = smaller | (digit < bound[pos])
        new_start = start | (digit > 0) | (pos == 0)
        new_mask = (mask | (1 << digit)) if new_start == 1 else 0
        cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
        res_sum = add(res_sum, add(mul(mul(digit, ten_pow[pos]), cur_ways), cur_sum))
        res_ways = add(res_ways, cur_ways)
    dp[smaller][start][pos][mask][0], dp[smaller][start][pos][mask][1] = res_sum, res_ways
    return dp[smaller][start][pos][mask]

def solve(upper_bound):
    global dp
    dp = 2 * [2 * [MAX_LENGTH * [(1 << 10) * [[-1, -1]]]]]
    digitize(upper_bound)
    ans = rec(0, 0, MAX_LENGTH - 1, 0)
    print(ans)
    return ans[0]

inp = [int(x) for x in sys.stdin.read().split()]
l, r, k = inp[0], inp[1], inp[2]
bit_count = [0] * (1 << 10)
for i in range(1, 1 << 10): bit_count[i] = bit_count[i & (i - 1)] + 1
ten_pow = [(10 ** i) % mod for i in range(0, MAX_LENGTH)]
print(add(solve(r), -solve(l - 1)))
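When a memoized DP disagrees with its non-memoized version, a brute-force reference on small ranges helps pin down which answers are wrong. Below is a minimal sketch, assuming (as the mask check in rec implies) that the task is to sum the numbers in [l, r] that use at most k distinct decimal digits, modulo 998244353. The name brute_segment_sum is a hypothetical helper, not part of the original code.

```python
MOD = 998244353


def brute_segment_sum(l, r, k):
    """Reference answer: sum (mod MOD) of x in [l, r] using at most k distinct digits.

    Linear in r - l, so it is only meant for cross-checking the DP on small ranges.
    """
    return sum(x for x in range(l, r + 1) if len(set(str(x))) <= k) % MOD


# Every number in [10, 50] has at most 2 distinct digits,
# so this is just 10 + 11 + ... + 50.
print(brute_segment_sum(10, 50, 2))  # 1230
```

Comparing this against solve(r) - solve(l - 1) for many small (l, r, k) triples would show quickly that the memoized version goes wrong, even though each individual transition looks correct.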
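In your solve function, the * operator only copies the references to the lists, so every dp[smaller][start][pos][mask] ends up being the very same [-1, -1] list. As soon as one state stores its result, every other state looks as if it had already been computed. This might be the problem here. You should use something like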
A = [[0, 0] for i in range(2)]
(or xrange in Python 2) instead.

There's also the fact that nested lists in Python are slow, so it's best not to use them. For example, a 4D list with shape $$$(256, 256, 256, 2)$$$ (33 million elements) takes about 5 seconds to create for me (Python 3), while a 1D version with
[-1] * (2*256**3)
takes so little time that I'd hit MLE before TLE. When NumPy arrays are an option (unfortunately not here), they are the way to go; otherwise, use a 1D list all the way. Regular nested lists are only a good fit when the sizes vary, and that always has a performance cost, just like ArrayList in Java.

Yes! That was the problem. Thank you very much.
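Putting both suggestions together, here is a sketch of the question's DP with the shared-reference table replaced by two flat 1D lists indexed by ((smaller * 2 + start) * MAX_LENGTH + pos) * MASKS + mask. It keeps the original recursion and state, but the names segment_prefix, segment_sum, memo_sum and memo_ways are hypothetical, and it assumes inputs below 10**20 (MAX_LENGTH = 20, as in the original). The first few lines just demonstrate the aliasing pitfall from the thread.

```python
MOD = 998244353
MAX_LENGTH = 20          # digits considered; enough for bounds below 10**20
MASKS = 1 << 10          # one bit per decimal digit 0-9

# The pitfall: 2 * [[-1, -1]] repeats a reference to ONE inner list.
shared = 2 * [[-1, -1]]
shared[0][0] = 7
print(shared)            # [[7, -1], [7, -1]]
# A comprehension evaluates [-1, -1] anew for every row.
independent = [[-1, -1] for _ in range(2)]
independent[0][0] = 7
print(independent)       # [[7, -1], [-1, -1]]

BIT_COUNT = [0] * MASKS  # BIT_COUNT[m] = number of set bits in m
for m in range(1, MASKS):
    BIT_COUNT[m] = BIT_COUNT[m & (m - 1)] + 1
TEN_POW = [pow(10, i, MOD) for i in range(MAX_LENGTH)]


def segment_prefix(upper_bound, k):
    """Sum (mod MOD) of all x in [0, upper_bound] with at most k distinct digits."""
    if upper_bound < 0:
        return 0
    bound = [(upper_bound // 10 ** i) % 10 for i in range(MAX_LENGTH)]
    # Flat memo tables: every (smaller, start, pos, mask) state owns its own slot.
    size = 2 * 2 * MAX_LENGTH * MASKS
    memo_sum = [-1] * size
    memo_ways = [-1] * size

    def rec(smaller, start, pos, mask):
        if BIT_COUNT[mask] > k:
            return 0, 0
        if pos == -1:
            return 0, 1
        idx = ((smaller * 2 + start) * MAX_LENGTH + pos) * MASKS + mask
        if memo_sum[idx] != -1:
            return memo_sum[idx], memo_ways[idx]
        res_sum = res_ways = 0
        for digit in range(10):
            if not smaller and digit > bound[pos]:
                continue
            new_smaller = smaller | (digit < bound[pos])
            new_start = start | (digit > 0) | (pos == 0)
            new_mask = (mask | (1 << digit)) if new_start else 0
            cur_sum, cur_ways = rec(new_smaller, new_start, pos - 1, new_mask)
            res_sum = (res_sum + digit * TEN_POW[pos] * cur_ways + cur_sum) % MOD
            res_ways = (res_ways + cur_ways) % MOD
        memo_sum[idx], memo_ways[idx] = res_sum, res_ways
        return res_sum, res_ways

    return rec(0, 0, MAX_LENGTH - 1, 0)[0]


def segment_sum(l, r, k):
    """Sum (mod MOD) of x in [l, r] with at most k distinct digits."""
    return (segment_prefix(r, k) - segment_prefix(l - 1, k)) % MOD
```

For a submission, read l, r, k from standard input and print segment_sum(l, r, k). Returning a plain tuple from rec is fine here; the bug was never the return type, only that all memo entries shared one list.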