Python Optimization Tricks in Competitive Programming
For a better reading experience in Chinese, please refer to: LeetCode.
Preface
This article introduces tricks for reducing the runtime of Python programs in competitive programming while keeping the algorithm's complexity correct and unchanged. Each trick below is rated along three dimensions: how complex the modification is, how significant the speedup is, and how useful it is in practice.
It is worth noting that, due to Python's inherent nature, no amount of optimization will let it pass problems with large data ranges and high complexity[^1]. Additionally, writing Python and then optimizing it is not necessarily faster than writing C++ directly. Therefore, although Python has some fantastic language features, if you want to do competitive programming seriously, it may be better to let go of the attachment to Python and learn C++ instead.
For the following reasons, the content of this article is mostly empirical and may be imprecise in places: the underlying principles are not fully clear to me; the claims need comparative experiments to verify; and performance may differ between CPython and PyPy and across Python versions. Corrections are welcome.
I/O
【Simple, Significant, Useful】 I/O is one of the main time bottlenecks. When there are many input lines, reading through sys.stdin is far faster than calling the built-in input.
import sys
input = lambda: sys.stdin.readline().rstrip() # Strip the trailing newline (rstrip() also removes any other trailing whitespace)
II = lambda: int(input())
LII = lambda: list(map(int, input().split()))
【Simple, Significant, Useful】 To avoid reading line by line, you can also read the entire input into memory at once and iterate over the tokens.
import sys
it = map(int, sys.stdin.read().split())
II = lambda: next(it)
# If the input contains strings, it can be modified to
# it = iter(sys.stdin.read().split())
# SI = lambda: next(it)
# II = lambda: int(SI())
【Simple, Insignificant, Useful】 Similarly, to avoid frequent output calls, you can buffer all results and print them at once. Personally, I prefer print(*output) over print(' '.join(map(str, output))). Writing to sys.stdout directly may be faster but is more cumbersome.
output = []
for _ in range(n):
    ans = solve()
    output.append(ans)
print(*output, sep='\n')
【Complex】 The IOWrapper implementation is too cumbersome, so I have not tested it.
Data Types
int
【Simple, Insignificant, Useful】 Since Python's int has no length limit, and division (including %) is more expensive than addition and multiplication, in problems that ask for an answer modulo some number you can accumulate directly and take the modulo once at the end, as long as the intermediate result does not grow too large (roughly within the __int128 range).
# precalculation
comb = lambda m, n: fac[m] * ifac[n] * ifac[m-n] % MOD
# Before
ans = 0
for i in range(n):
    ans = (ans + comb(n, i) * pow(2, i, MOD) % MOD) % MOD
# After
ans = 0
for i in range(n):
    ans += comb(n, i) * pow(2, i, MOD)
ans %= MOD
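The `# precalculation` line above leaves the factorial tables implicit. Below is a minimal runnable sketch of one common way to build them and compare the two loops; the bound `N` and the helper names `eager` and `delayed` are illustrative, not from the original. By the binomial theorem the sum equals $$$3^n - 2^n$$$ (all terms of $$$(1+2)^n$$$ except $$$i=n$$$), which gives an independent check.

```python
MOD = 10**9 + 7
N = 10**5  # hypothetical upper bound on n

# fac[i] = i! mod MOD; ifac[i] = inverse of i! mod MOD (one Fermat inverse, then a backward pass)
fac = [1] * (N + 1)
for i in range(1, N + 1):
    fac[i] = fac[i - 1] * i % MOD
ifac = [1] * (N + 1)
ifac[N] = pow(fac[N], MOD - 2, MOD)
for i in range(N, 0, -1):
    ifac[i - 1] = ifac[i] * i % MOD  # 1/(i-1)! = i * (1/i!)

comb = lambda m, k: fac[m] * ifac[k] * ifac[m - k] % MOD


def eager(n: int) -> int:
    """Reduce modulo MOD after every step."""
    ans = 0
    for i in range(n):
        ans = (ans + comb(n, i) * pow(2, i, MOD) % MOD) % MOD
    return ans


def delayed(n: int) -> int:
    """Accumulate the exact sum; each term is below MOD**2, so ans stays below n * MOD**2."""
    ans = 0
    for i in range(n):
        ans += comb(n, i) * pow(2, i, MOD)
    return ans % MOD
```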
【Simple, Insignificant, Useful】 float('inf') is a floating-point number, and comparing an int with a float is slower than comparing two ints, so prefer a large integer as infinity.
# Before
from math import inf
# Or
inf = float('inf')
# After
inf = 1 << 60
dis = [inf] * n
str
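【Simple, Significant, Useful】 Each string += concatenation can cost $$$O(n)$$$: strings are immutable, so a new string may have to be allocated and the old contents copied every time, and building a string piece by piece can then take $$$O(n^2)$$$ in total (CPython can sometimes resize in place, but this is not guaranteed).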
# Before
ans = ''
for s in strs:
    ans += s
# After
ans = ''.join(strs)
【Simple, Insignificant, Useless】 bytearray(s, encoding) creates a mutable byte string, similar to a C++ std::string (this works cleanly for ASCII text, where each character is one byte).
# Before
t = list(s)
t[0] = 'a'
s = ''.join(t)
# After
t = bytearray(s, encoding='ascii')
t[0] = ord('a')
s = t.decode('ascii')
list
【Simple, Insignificant, Useful】 When traversing a list, using enumerate to get both the index and the value is more efficient than getting the index first and then the value.
# Before
for i in range(len(nums)):
    x = nums[i]
    ...
# After
for i, x in enumerate(nums):
    ...
【Simple, Insignificant, Useless】 Preallocate the list to avoid repeated list.append calls, which grow the list and occasionally reallocate its storage.
# Before
nums = []
for i in range(n):
    nums.append(i)
# After
nums = [0] * n
for i in range(n):
    nums[i] = i
【Simple, Significant, Useful】 For a multi-dimensional list, put the larger dimension innermost. This may be friendlier to the cache, and it also reduces the number of inner lists that have to be created (k lists of length n instead of n lists of length k).
n, k = 10**5, 20
# Before
dp = [[0] * k for _ in range(n)]
# After
dp = [[0] * n for _ in range(k)]
【Simple, Significant, Useful】 Sometimes a multi-dimensional list can be flattened into one dimension with index arithmetic.
# Before
dp = [[0] * n for _ in range(m)]
# After
dp = [0] * (m*n)
compress = lambda i, j: i*n+j
decompress = lambda k: divmod(k, n)
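As an illustration of the flattened layout, here is a sketch of a grid path-counting DP that stores one value per cell in a single list. The problem (count right/down paths from the top-left to the bottom-right cell, avoiding `#` cells) and the function name `count_paths` are hypothetical examples. The neighbours of cell `k = i*n + j` are simply `k - n` (up) and `k - 1` (left).

```python
from typing import List


def count_paths(grid: List[str]) -> int:
    """Count right/down paths avoiding '#' cells, using a flattened dp list."""
    m, n = len(grid), len(grid[0])
    dp = [0] * (m * n)  # dp[i*n + j]: number of paths reaching cell (i, j)
    for i in range(m):
        row = grid[i]
        for j in range(n):
            k = i * n + j  # same as compress(i, j) above
            if row[j] == '#':
                continue  # blocked cells keep 0
            if k == 0:
                dp[k] = 1
            else:
                up = dp[k - n] if i else 0    # cell (i-1, j)
                left = dp[k - 1] if j else 0  # cell (i, j-1)
                dp[k] = up + left
    return dp[-1]
```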
【Complex, Significant, Useful】 Use a chained forward star (linked edge lists stored in flat arrays) instead of an adjacency list to build a graph. This avoids the overhead of creating n separate lists.
# Before
g = [[] for _ in range(n)]
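def add_edge(u: int, v: int, w: int):
    g[u].append((v, w))
# After
head = [-1] * n   # head[u]: latest edge leaving u, -1 if none
to = [-1] * m     # to[e]: target of edge e
weight = [0] * m  # weight[e]: weight of edge e
nxt = [-1] * m    # nxt[e]: previous edge leaving the same vertex, -1 if none
ptr = 0
def add_edge(u: int, v: int, w: int):
    global ptr  # module-level counter; `nonlocal` only works inside a nested function
    to[ptr] = v
    weight[ptr] = w
    nxt[ptr] = head[u]
    head[u] = ptr
    ptr += 1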
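The snippet above only builds the structure. Here is a self-contained sketch of building and traversing it; the helper names `build_forward_star` and `out_edges` are my own. Note that the edges leaving a vertex come out in reverse insertion order.

```python
def build_forward_star(n, edges):
    """Pack directed edges (u, v, w) into flat arrays, as add_edge does above."""
    m = len(edges)
    head = [-1] * n
    to = [0] * m
    weight = [0] * m
    nxt = [-1] * m
    for ptr, (u, v, w) in enumerate(edges):
        to[ptr] = v
        weight[ptr] = w
        nxt[ptr] = head[u]  # link to the previous edge leaving u
        head[u] = ptr
    return head, to, weight, nxt


def out_edges(u, head, to, weight, nxt):
    """Collect (v, w) for the edges leaving u; in hot loops, inline this while loop instead."""
    res = []
    e = head[u]
    while e != -1:
        res.append((to[e], weight[e]))
        e = nxt[e]
    return res
```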
【Simple, Significant, Useful】 Use array.array instead of list. When the list has a fixed length, holds only integers, and is accessed and modified frequently, the savings in space and time can be substantial. Note that typecode 'i' is a C int (usually 32-bit); use 'q' for 64-bit values.
# Before
nums = [0] * n
# After
from array import array
nums = array('i', [0] * n)
【Simple, Significant, Useless】 Following the preceding trick, a bytearray can replace a list of booleans, with a noticeable speedup.
# Before
vis = [False] * n
# After
vis = bytearray(n)  # n zero bytes; set vis[i] = 1 to mark i as visited
【Simple, Significant, Useless】 Moreover, you can also use C arrays from ctypes instead of list.
# Before
rank = [0] * n
pa = list(range(n))
# After
from ctypes import c_int32
rank = (c_int32 * n)()
pa = (c_int32 * n)(*range(n))
tuple
【Simple, Insignificant, Useless】 Instead of packing multiple fields into tuples and storing them in one list, store each field in its own list. For example:
# Before
items = [(w1, v1), (w2, v2), ...]
# After
weights = [w1, w2, ...]
values = [v1, v2, ...]
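A sketch of what the parallel-list layout looks like in use, with a standard 0/1 knapsack; the function name `knapsack` and the sample data are illustrative.

```python
def knapsack(weights, values, cap):
    """0/1 knapsack reading item i from weights[i] and values[i] (no (w, v) tuples)."""
    dp = [0] * (cap + 1)  # dp[c]: best value with total weight at most c
    for w, v in zip(weights, values):
        for c in range(cap, w - 1, -1):  # descending, so each item is used at most once
            cand = dp[c - w] + v
            if cand > dp[c]:
                dp[c] = cand
    return dp[cap]
```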
dict
【Simple, Significant, Useful】 Looking up a hash table by key is far slower than indexing a list. Therefore, if the data range allows, use list instead of dict whenever possible. The same applies to set.
# Before
g = defaultdict(list)
# After
g = [[] for _ in range(n)]
【Simple, Insignificant, Useful】 When traversing a dict, using dict.items to get both keys and values is more efficient than getting the key first and then accessing the value.
# Before
for k in mp:
    v = mp[k]
    ...
# After
for k, v in mp.items():
    ...
【Simple, Insignificant, Useless】 When clearing a dictionary, dict.clear is not as efficient as simply creating a new one; the same applies to set. Trust the garbage collector, but note that rebinding does not empty other references to the old object.
# Before
mp.clear()
# After
mp = {}
【Simple, Significant, Useful】 For simple counting, use defaultdict(int) instead of Counter.
# Before
cnt = Counter()
# After
cnt = defaultdict(int)
【Simple, Insignificant, Useful】 For a defaultdict, when reading a key that may not exist, use defaultdict.get instead of indexing (defaultdict.__getitem__), which would insert the missing key.
mp = defaultdict(int)
# Before
x = mp[k]
# After
x = mp.get(k, 0)
【Simple, Significant, Useful】 On Codeforces, some problems are hacked with tests crafted to cause massive hash collisions in Python's dict and set (which use open addressing), leading to TLE. If the positions of the numbers do not matter, you can simply shuffle them first; if they do, you can transform each number before using it as a key, e.g. by adding/subtracting or XORing a random value (and apply the same transform when looking it up).
# position irrelevant
from random import shuffle
shuffle(nums)
cnt = defaultdict(int)
for x in nums:
    cnt[x] += 1
# position relevant
from random import getrandbits
RD = getrandbits(31)
pos = defaultdict(list)
for i, x in enumerate(nums):
    pos[x ^ RD].append(i)
【Simple, Significant, Useful】 When only the equality (or relative order) of numbers matters, not their actual values, you can discretize them first, i.e. map them to ranks 0, 1, 2, .... How can discretization be made faster? My usual approach is set + sorting + dict. Alternatively, deduplication can imitate C++'s std::unique instead of using a set, and the mapping can use binary search instead of a hash table. Surprisingly, neither alternative is faster than the usual approach.
# Deduplicate using a set
sarr = sorted(set(nums))
# Map values to ranks using a hash table
mp = {x: i for i, x in enumerate(sarr)}
nums = [mp[x] for x in nums]
# Alternative deduplication: sorting and two pointers (like std::unique)
def unique(nums):
    sarr = sorted(nums)
    ptr = 0
    for x in sarr:
        if x != sarr[ptr]:
            ptr += 1
            sarr[ptr] = x
    del sarr[ptr+1:]
    return sarr
sarr = unique(nums)
# Alternative mapping: binary search
from bisect import bisect_left
nums = [bisect_left(sarr, x) for x in nums]
deque
【Significant, Useless】 The efficiency of deque is not as high as that of an array-simulated queue.
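# g[u]: children of u in a rooted tree (assumed), so each vertex is enqueued exactly once
# Before
q = deque([root])
while q:
    u = q.popleft()
    for v in g[u]:
        q.append(v)
# After
q = [0] * n
q[0] = root
head, tail = 0, 1
while head < tail:
    u = q[head]
    head += 1
    for v in g[u]:
        q[tail] = v
        tail += 1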
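The loop above only shows the queue mechanics. Below is a sketch of a full BFS computing unweighted distances with the preallocated-list queue on a general adjacency list `g`; the name `bfs_dist` is hypothetical. The visited check (dist == -1) guarantees each vertex is enqueued at most once, so `n` slots suffice.

```python
def bfs_dist(g, root):
    """Unweighted shortest distances from root; -1 marks unreachable vertices."""
    n = len(g)
    dist = [-1] * n
    dist[root] = 0
    q = [0] * n  # each vertex enters at most once
    q[0] = root
    head, tail = 0, 1
    while head < tail:
        u = q[head]
        head += 1
        du = dist[u] + 1
        for v in g[u]:
            if dist[v] == -1:
                dist[v] = du
                q[tail] = v
                tail += 1
    return dist
```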
【Significant, Useless】 deque supports indexing, but it is implemented as a doubly linked list of fixed-size blocks, so indexing costs $$$O(n/B)$$$, where $$$B=64$$$ is the block size.
# Before
q = deque()
...
val = q[pos]
# After
q = [0] * n
head, tail = 0, 1
...
val = q[pos + head]
Functions
【Simple, Insignificant, Useful】 Because they are implemented in C, built-in and library functions are usually faster than handwritten equivalents, though there are exceptions.
# Before
pres = [0] * (n+1)
for i, x in enumerate(nums):
    pres[i+1] = pres[i] + x
# After
from itertools import accumulate
pres = list(accumulate(nums, initial=0))
【Simple, Significant, Useful】 Manually write min and max for comparing two numbers to avoid additional overhead such as type checking.
# Before
x, y = min(x, y), max(x, y)
# After
fmin = lambda x, y: x if x < y else y
fmax = lambda x, y: x if x > y else y
x, y = fmin(x, y), fmax(x, y)
【Simple, Significant, Useful】 Write fast exponentiation by hand instead of calling pow, to avoid additional overhead such as type checking. In addition, when exp is negative, pow(base, exp, mod) computes the modular inverse with the extended Euclidean algorithm, which is also slightly slower than handwritten fast exponentiation via Fermat's little theorem (valid when MOD is prime).
MOD = 10**9+7
# Before
inv = pow(x, -1, MOD)
# After
def qpow(x, k):
    res = 1
    while k:
        if k & 1:
            res = res * x % MOD
        x = x * x % MOD
        k >>= 1
    return res
inv = qpow(x, MOD-2)  # Fermat's little theorem: requires MOD prime and x % MOD != 0
【Simple, Insignificant, Useful】 For functions that accept an iterable, you can pass a generator expression instead of first building an intermediate list.
# Before
s = sum([x**2 for x in range(n)])
# After
s = sum(x**2 for x in range(n))
【Simple, Significant, Useless】 Do not use sum(lsts, []) to concatenate multiple lists, as it creates a new list at every step. For strings, sum(strs, '') is not even allowed: it raises a TypeError that suggests ''.join instead.
# Before
longlist = sum(lsts, [])
# longstr = sum(strs, '')  # raises TypeError: sum() can't sum strings
# After
longlist = []
for lst in lsts:
    longlist.extend(lst)
longstr = ''.join(strs)
【Simple, Insignificant, Useless】 Since local variables are accessed faster than global ones, consider wrapping the main program in a function.
def main():
    ...
main()
【Simple, Insignificant, Useful】 In some scenarios, yield can turn a function into a generator that produces results one at a time, instead of collecting them all into a list first.
# Before
def all_subsets(mask):
    subs = []
    cur = mask
    while cur:
        subs.append(cur)
        cur = (cur - 1) & mask
    return subs
# After
def all_subsets(mask):
    cur = mask
    while cur:
        yield cur
        cur = (cur - 1) & mask
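A short usage sketch: because the generator is lazy, a consumer that stops early never produces the remaining submasks. The helper `first_submask_with_popcount` is a hypothetical example.

```python
def all_subsets(mask: int):
    """Yield the non-empty submasks of mask in decreasing order (same generator as above)."""
    cur = mask
    while cur:
        yield cur
        cur = (cur - 1) & mask


def first_submask_with_popcount(mask: int, k: int) -> int:
    """Largest non-empty submask with exactly k set bits, or -1; stops at the first hit."""
    return next((s for s in all_subsets(mask) if bin(s).count('1') == k), -1)
```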
【Complex, Insignificant, Useful】 In memoized search, avoid functools.cache. On every call, @cache (equivalent to lru_cache(maxsize=None)) builds a key from the arguments, hashes it, and looks it up in a dict; lru_cache with a finite maxsize additionally maintains a doubly linked list for LRU eviction. This whole process is too slow compared with indexing a list.
# Before
@cache
def dp(i: int, pre: int, islim: bool, isnum: bool) -> int:
    ...
# After
memo = [[-1] * 10 for _ in range(n)]  # n: number of digits
def dp(i: int, pre: int, islim: bool, isnum: bool) -> int:
    if not islim and not isnum and memo[i][pre] != -1:
        return memo[i][pre]
    ...  # compute res, then store memo[i][pre] = res under the same condition before returning
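A complete, runnable digit DP in the list-memo style, on a hypothetical problem: count the integers in [1, n] whose adjacent digits all differ. This sketch memoizes the states where `islim` is False and `isnum` is True, a common choice for this template; which flag combination is safe to cache under the key `(i, pre)` depends on the problem. The function name `count_distinct_adjacent` is my own.

```python
def count_distinct_adjacent(n: int) -> int:
    """Count x in [1, n] whose adjacent decimal digits all differ."""
    s = str(n)
    L = len(s)
    memo = [[-1] * 10 for _ in range(L)]  # memo[i][pre], only for islim=False, isnum=True

    def dp(i: int, pre: int, islim: bool, isnum: bool) -> int:
        if i == L:
            return 1 if isnum else 0
        if not islim and isnum and memo[i][pre] != -1:
            return memo[i][pre]
        res = 0
        if not isnum:
            # Still no digit placed (leading zero): the upper bound no longer applies
            res += dp(i + 1, pre, False, False)
        lo = 0 if isnum else 1
        hi = int(s[i]) if islim else 9
        for d in range(lo, hi + 1):
            if isnum and d == pre:
                continue  # equal adjacent digits are not allowed
            res += dp(i + 1, d, islim and d == hi, True)
        if not islim and isnum:
            memo[i][pre] = res
        return res

    return dp(0, 0, True, False)
```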
【Complex, Significant, Useful】 Function frames are expensive in Python, so deeply recursive functions usually cannot be used directly, and raising the limit with sys.setrecursionlimit does not always work, because the interpreter can still run out of stack. A more general approach is an "infinite recursion" decorator such as bootstrap from PyRival[^2], which simulates recursion with generators. Recursive functions that are not too complex can also be rewritten iteratively. For example, a stack can simulate the recursion to perform a pre-order traversal of a tree stored as an adjacency list:
# Before
def dfs(u: int, pa: int):
    for v in tree[u]:
        if v != pa:
            ...
            dfs(v, u)
# After
order = []
parents = [-1] * len(tree)
stk = [root]
while stk:
    u = stk.pop()
    order.append(u)
    for v in tree[u]:
        if parents[u] != v:
            parents[v] = u
            stk.append(v)
    ...
Reversing this pre-order yields a post-order, in which every node appears after all of its descendants.
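For instance, a bottom-up DP such as subtree sizes can run over the reversed order with no recursion at all. A sketch using the same iterative traversal as above (the function name `subtree_sizes` is hypothetical):

```python
def subtree_sizes(tree, root=0):
    """Subtree size of every node of a tree given as an undirected adjacency list."""
    n = len(tree)
    parents = [-1] * n
    order = []
    stk = [root]
    while stk:  # iterative pre-order, as above
        u = stk.pop()
        order.append(u)
        for v in tree[u]:
            if v != parents[u]:
                parents[v] = u
                stk.append(v)
    size = [1] * n
    for u in reversed(order):  # children are finished before their parent
        p = parents[u]
        if p != -1:
            size[p] += size[u]
    return size
```

Because nothing recurses, a path of $$$10^5$$$ nodes works without touching the recursion limit.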
【Complex, Significant, Useful】 Use an iterative (bottom-up) segment tree[^4]. It handles most segment tree problems and is significantly faster than a recursive one.
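# Single-point modification, maximum subarray sum (the empty subarray is allowed)
from math import inf
from typing import Tuple
# node = (max prefix sum, max suffix sum, total sum, max subarray sum)
def op(lnode: Tuple[int, ...], rnode: Tuple[int, ...]):
    llmx, lrmx, lsum, lres = lnode
    rlmx, rrmx, rsum, rres = rnode
    return max(llmx, lsum + rlmx), max(rrmx, rsum + lrmx), lsum + rsum, max(lres, rres, lrmx + rlmx)
e = (0, 0, 0, -inf)
# Leaves clamp prefix/suffix/answer at 0, consistent with e; SegmentTree is an iterative segment tree class
seg = SegmentTree([(max(x, 0), max(x, 0), x, max(x, 0)) for x in nums], e, op)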
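The snippet relies on a `SegmentTree` class that is not shown. Below is a minimal sketch of a bottom-up segment tree with the constructor order `(values, e, op)` used above; that order, and the method names `set` and `prod`, are assumptions and not necessarily the interface of the library in [^4]. The query keeps separate left and right accumulators, so `op` only needs to be associative, not commutative, which the maximum-subarray node requires.

```python
from math import inf


class SegmentTree:
    """Iterative segment tree over a monoid (op, e); leaves live at d[size:size+n]."""

    def __init__(self, values, e, op):
        n = len(values)
        size = 1
        while size < n:
            size <<= 1
        self.size, self.e, self.op = size, e, op
        self.d = [e] * (2 * size)
        self.d[size:size + n] = values
        for i in range(size - 1, 0, -1):
            self.d[i] = op(self.d[2 * i], self.d[2 * i + 1])

    def set(self, p, x):
        """Point update values[p] = x, then recompute the ancestors."""
        p += self.size
        self.d[p] = x
        p >>= 1
        while p:
            self.d[p] = self.op(self.d[2 * p], self.d[2 * p + 1])
            p >>= 1

    def prod(self, l, r):
        """op over values[l:r] (half-open), combining left and right parts in order."""
        sml, smr = self.e, self.e
        l += self.size
        r += self.size
        while l < r:
            if l & 1:
                sml = self.op(sml, self.d[l])
                l += 1
            if r & 1:
                r -= 1
                smr = self.op(self.d[r], smr)
            l >>= 1
            r >>= 1
        return self.op(sml, smr)


# Usage with the maximum-subarray node from above
def op(a, b):
    llmx, lrmx, lsum, lres = a
    rlmx, rrmx, rsum, rres = b
    return max(llmx, lsum + rlmx), max(rrmx, rsum + lrmx), lsum + rsum, max(lres, rres, lrmx + rlmx)


e = (0, 0, 0, -inf)
leaf = lambda x: (max(x, 0), max(x, 0), x, max(x, 0))

nums = [2, -5, 3, -1, 4]
seg = SegmentTree([leaf(x) for x in nums], e, op)
best = seg.prod(0, len(nums))[3]  # 3 + (-1) + 4
```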
Classes
【Complex, Significant, Useful】 In short, class instances are slow. For tree-like data structures, node objects can almost always be replaced with arrays.
# Before
class TrieNode:
    def __init__(self):
        self.children = [None] * 26
        self.isend = False
        self.cnt = 0
class Trie:
    def __init__(self):
        self.root = TrieNode()
    ...
# After
class StaticTrie:
    def __init__(self, lengths):  # lengths: total number of characters to insert
        lengths += 1  # one extra node for the root
        self.children = [[-1] * lengths for _ in range(26)]
        self.isend = [False] * lengths
        self.cnt = [0] * lengths
        self.ptr = 1
    ...
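A fuller sketch of the array-based trie, keeping the layout `children[letter][node]` from above. The methods `insert`, `count_prefix` and `contains`, and the reading of `cnt[node]` as the number of inserted words passing through a node, are my assumptions rather than the original's:

```python
class StaticTrie:
    """Array-based trie over 'a'..'z'; node 0 is the root."""
    __slots__ = 'children', 'isend', 'cnt', 'ptr'

    def __init__(self, lengths: int):
        lengths += 1  # one extra node for the root
        self.children = [[-1] * lengths for _ in range(26)]  # children[c][node]
        self.isend = [False] * lengths
        self.cnt = [0] * lengths  # words passing through each node (assumed meaning)
        self.ptr = 1              # next free node id

    def insert(self, word: str) -> None:
        children, cnt = self.children, self.cnt  # local aliases
        node = 0
        for ch in word:
            row = children[ord(ch) - 97]
            if row[node] == -1:
                row[node] = self.ptr
                self.ptr += 1
            node = row[node]
            cnt[node] += 1
        self.isend[node] = True

    def _walk(self, s: str) -> int:
        """Node reached by following s from the root, or -1."""
        node = 0
        for ch in s:
            node = self.children[ord(ch) - 97][node]
            if node == -1:
                return -1
        return node

    def count_prefix(self, prefix: str) -> int:
        node = self._walk(prefix)
        return 0 if node == -1 else self.cnt[node]

    def contains(self, word: str) -> bool:
        node = self._walk(word)
        return node != -1 and self.isend[node]
```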
【Simple, Significant, Useful】 Declare __slots__ to list the instance attribute names; this makes attribute access faster and saves memory[^3].
class DSU:
    __slots__ = 'parent', 'size'
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n
【Simple, Insignificant, Useful】 Cache frequently accessed member variables in local variables.
class DSU:
    def find(self, u: int):
        parent = self.parent  # local alias avoids repeated attribute lookups
        while u != parent[u]:
            u = parent[u]
        return u
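Putting the last two tricks together, here is a sketch of a complete DSU. Path halving in `find` and union by size in `union` are my additions, not part of the original snippet:

```python
class DSU:
    __slots__ = 'parent', 'size'

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, u: int) -> int:
        parent = self.parent  # cached in a local
        while u != parent[u]:
            parent[u] = parent[parent[u]]  # path halving
            u = parent[u]
        return u

    def union(self, u: int, v: int) -> bool:
        """Merge the sets of u and v; return False if they were already joined."""
        ru, rv = self.find(u), self.find(v)
        if ru == rv:
            return False
        size = self.size
        if size[ru] < size[rv]:
            ru, rv = rv, ru  # attach the smaller tree under the larger one
        self.parent[rv] = ru
        size[ru] += size[rv]
        return True
```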
References
[^1]: Python performance tips. https://mirror.codeforces.com/blog/entry/21851
[^2]: PyRival. https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
[^3]: Python Docs. https://docs.python.org/zh-cn/3.13/reference/datamodel.html#object.__slots__
[^4]: AtCoder Library Python. https://github.com/not522/ac-library-python







