drugkeeper's blog

By drugkeeper, 6 months ago, In English

Introduction

Hello everyone, I have received some messages asking how to improve Python code to avoid TLE. There have been blogs on this before, but they are rather old, so I will share some tips; I hope it helps! These are tips I have learnt over the years, and some of them may not be well known to people who don't main Python.

Do let me know in the comments if you have more tips as well.

Essentials (Most questions can be solved using these tips)

Basically, try to make your code as "Pythonic" as possible: the more work Python can hand off to its underlying C functions, the faster it runs.

  1. Use PyPy 3 64-bit; it is usually much faster than Python 3.
  2. For fast I/O, I use input = sys.stdin.readline (the returned line keeps its trailing newline, so call .split() or .rstrip() on it). For output with multiple lines, I store the outputs in a list and call print("\n".join(lst)) once at the end, because many separate print calls are slow in Python.
  3. Put all your code into functions, like a main function, because local variable lookup in Python is faster than global variable lookup. I think your code will run roughly 30% faster just by doing this.
  4. List creation: [0]*n is fastest, and list comprehensions are also fast. For 2D lists I use [[0]*m for _ in range(n)], because [[0]*m]*n will not work: all n rows would be references to the same list. Also, .extend() is faster than repeated .append() calls when adding multiple elements at a time.
  5. Use for x in arr instead of for i in range(n): arr[i], as it is faster. If you need both the index and the value, use enumerate(), which is also faster. You can also unpack directly, e.g. for x, y in arr: to iterate over pairs.
  6. Use built-in functions like map() and sum(), as they are faster than implementing the same thing yourself. The more Pythonic, the better.
  7. For deep recursion, use an explicit stack instead of a recursive function. Recursion in Python is slow and memory intensive.
  8. Keep your code as simple as possible, especially inside for loops, both to reduce the amount of computation and so that PyPy can optimise your code easily.
  9. String concatenation in a for loop is slow in Python, because a new intermediate string is created every time. Instead, .append() (or .extend()) the pieces to a list and call "".join(lst) at the end.
  10. Try not to change the data type of variables and collections; each collection should store a single data type. For example, PyPy has special dict strategies that speed up dicts whose keys all have the same type. I believe lists and other collections get similar optimisations, since Python can predict the type better and avoid the overhead of mixed types.
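Several of the essentials above fit together in a small submission template. This is a minimal sketch, not code from any particular submission. The toy task (labelling the connected components of t undirected graphs read from stdin) and the names solve and main are hypothetical. The readline/write parameters exist only so the function can be exercised without a real stdin.

```python
import sys


def solve(n, edges):
    """Label connected components (0-indexed edges) with an explicit-stack DFS."""
    adj = [[] for _ in range(n)]  # not [[]] * n, which would alias one list n times
    for u, v in edges:            # unpack pairs directly instead of indexing
        adj[u].append(v)
        adj[v].append(u)
    comp = [-1] * n               # [x] * n is the fastest way to allocate a list
    count = 0
    for start in range(n):
        if comp[start] >= 0:
            continue
        comp[start] = count
        stack = [start]           # explicit stack instead of recursion
        while stack:
            node = stack.pop()
            for nxt in adj[node]:
                if comp[nxt] < 0:
                    comp[nxt] = count
                    stack.append(nxt)
        count += 1
    return comp


def main(readline=sys.stdin.readline, write=sys.stdout.write):
    # readline/write are parameters only so this can be tested;
    # a submission just calls main() at the bottom of the file.
    input = readline              # fast line reads, and a local name lookup
    out = []
    for _ in range(int(input())):
        n, m = map(int, input().split())
        edges = [tuple(int(x) - 1 for x in input().split()) for _ in range(m)]
        out.append(" ".join(map(str, solve(n, edges))))
    write("\n".join(out) + "\n")  # one write instead of many print calls
```

Everything runs inside functions, and the per-test outputs are collected in out and written once.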

🦀-Python

As mentioned in this comment: https://mirror.codeforces.com/blog/entry/106541?#comment-949024

Non-Pythonic ways to improve Python performance. You won't need most of these unless the time limits are tight.

  1. Flattening 2D lists into 1D lists. This TLEs: https://cses.fi/paste/f6db6ab935e23d2072b133/ This passes: https://cses.fi/paste/c7fa7fcb039aa28f72b17f/
  2. Packing tuples into ints before sorting, then unpacking them afterwards, because sorting tuples is slow. This TLEs: https://cses.fi/paste/972de727f25b8f1171102a/ This passes: https://cses.fi/paste/25ce406e0f776551711046/
  3. Iterating over multiple lists at once using zip. This TLEs: https://cses.fi/paste/1c4afcf48c0400b071e0f8/ This passes: https://cses.fi/paste/85c4f40b683a6eec71e10d/
  4. In some cases deque() can be faster than a list, at least for some of the CSES graph problems I've done.
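A minimal sketch of the first three tricks, assuming non-negative integers. The function names and the 20-bit shift in sort_pairs_packed are illustrative choices, not taken from the linked submissions. Pick the shift so that every second component fits below 2**shift.

```python
def flatten(grid):
    """Flatten an n x m grid row by row; cell (i, j) ends up at index i * m + j."""
    m = len(grid[0])
    return [x for row in grid for x in row], m


def sort_pairs_packed(pairs, shift=20):
    """Sort (a, b) pairs by packing each into one int, then unpack.
    Assumes a >= 0 and 0 <= b < 2**shift, so int order matches tuple order."""
    mask = (1 << shift) - 1
    packed = [a << shift | b for a, b in pairs]  # << binds tighter than |
    packed.sort()
    return [(p >> shift, p & mask) for p in packed]


def elementwise_diff(xs, ys):
    """Walk two lists in lockstep with zip instead of indexing both."""
    return [y - x for x, y in zip(xs, ys)]
```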

Other useful tips:

  1. Learn your standard library properly (take a look at the documentation); Python has some very useful functions and classes. For example, I sometimes use the Decimal class when I need more precision than floats give (but note that the precision cannot be too high, or you will TLE). Counter and bisect are nice to use as well.
  2. PyRival is good: it has lots of algorithms and data structures, with useful things like SortedList() and prime factorisation.
  3. To avoid TLE from anti-hash tests against set() and dict() with integer keys, you can random.shuffle() the array if order is not important. Otherwise, you can use str(x) as the key instead of the int, or XOR each int with a random number before using it as a key.
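A minimal sketch of the random-XOR idea from tip 3, assuming integer keys. RANDOM_MASK and count_values are hypothetical names. The same wrapping works for set() membership, as long as every insert and lookup is XORed with the same mask.

```python
import random
from collections import Counter

# One random 60-bit mask per run (hypothetical name). XOR with a fixed mask is
# a bijection, so distinct keys stay distinct, but the stored keys (and hence
# their hashes) change from run to run, so a precomputed anti-hash test should
# no longer line up.
RANDOM_MASK = random.getrandbits(60)


def count_values(arr):
    """Count occurrences of each int in arr using XOR-wrapped dict keys."""
    wrapped = Counter(x ^ RANDOM_MASK for x in arr)
    # Unwrap the keys (x ^ mask ^ mask == x) before returning
    return {k ^ RANDOM_MASK: c for k, c in wrapped.items()}
```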


By drugkeeper, 6 months ago, In English

Blue


By drugkeeper, 9 months ago, In English

By drugkeeper, 12 months ago, In English

I am not sure how to find an efficient solution to this problem that I just thought of:

Problem: Given an unsorted array of n integers (each up to 10^9), you need to answer q queries. Each query has the form (l, r, v1, v2): count the elements at indices l to r whose value x satisfies v1 <= x <= v2.

Constraints: 1 <= l <= r <= n <= 10^5, 1 <= q <= 10^5, 1 <= v1 <= v2 <= 10^9.

Example:

  1. Array is [4, 2, 3, 1, 5, 6], we have 2 queries: (1, 4, 2, 3) and (1, 6, 4, 5)
  2. For the first query, output 2, because we have 2 elements in the subarray [4, 2, 3, 1] with a value from 2 to 3.
  3. For the next query, output 2, our subarray is [4, 2, 3, 1, 5, 6], we have 2 elements with a value from 4 to 5.

Is there a fast way to solve this problem?

For v1 == v2 (i.e. if we just need to count the elements from l to r with value v), we can build one dictionary for each element, each block of 2 elements, each block of 4 elements and so on, like a segment tree of dictionaries. A query on l and r then adds up the count of v from the O(log n) dictionaries that cover that range.
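A sketch of the v1 == v2 idea above (a segment tree whose nodes are dictionaries), assuming 1-indexed inclusive l and r as in the example. build and count_equal are hypothetical names, and this only answers the single-value case, not v1 < v2. Every element appears in one Counter per level, so memory is O(n log n), and each query reads O(log n) nodes.

```python
from collections import Counter


def build(arr):
    """Segment tree whose node k stores a Counter of the values in its range."""
    size = 1
    while size < len(arr):
        size *= 2
    tree = [Counter() for _ in range(2 * size)]
    for i, x in enumerate(arr):
        tree[size + i][x] += 1                    # leaves hold one element each
    for k in range(size - 1, 0, -1):
        tree[k] = tree[2 * k] + tree[2 * k + 1]   # merge the two children
    return tree, size


def count_equal(tree, size, l, r, v):
    """Count elements equal to v at positions l..r (1-indexed, inclusive)."""
    lo, hi = l - 1 + size, r + size               # half-open leaf range [lo, hi)
    total = 0
    while lo < hi:
        if lo & 1:                                # lo is a right child: take it
            total += tree[lo][v]                  # Counter gives 0 for missing keys
            lo += 1
        if hi & 1:                                # hi - 1 is a left child: take it
            hi -= 1
            total += tree[hi][v]
        lo //= 2
        hi //= 2
    return total
```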

However, for v1 != v2 I do not know of an efficient way, please help!


By drugkeeper, 12 months ago, In English

I was doing AtCoder Educational DP Contest, Knapsack 1: https://atcoder.jp/contests/dp/tasks/dp_d

I wanted to loop over capacities instead of looping over each item in the array. All existing solutions I have found loop over the items in the outermost loop. But what if I want capacity to be the outermost loop? (It is kind of impractical and I tunnel-visioned, but I digress.)

Here is the solution I came up with:

My dp is two-dimensional: dp[i][0] is the max value for capacity i, and dp[i][1] stores all the remaining items that have not been used yet, as a list of tuples. For every capacity i, we take its remaining items and try to use each of them. For item j with weight wj, we update the state dp[i + wj] if the new value is larger. The remaining items of dp[i + wj] are then just the remaining items of dp[i] with item j removed.

Here is the code, which TLEs at the 7th testcase:
Here is the optimised code, which passes all testcases:

The main change was to separate the dp array into dp1 and dp2, so that we avoid unnecessary multidimensional array accesses. I believe this can be optimised further, but it is good enough to pass all testcases.
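A rough sketch of the capacity-outer-loop DP as described above, with the dp array already split into dp1 and dp2. The original code is not reproduced here. knapsack_capacity_outer is a hypothetical name, and -1 in dp1 marks capacities that cannot be reached exactly. The sketch follows the description step by step rather than the standard item-outer-loop knapsack.

```python
def knapsack_capacity_outer(W, items):
    """items: list of (weight, value) tuples. Capacity is the outermost loop."""
    dp1 = [-1] * (W + 1)    # dp1[c]: best value found with total weight exactly c (-1 = unreached)
    dp2 = [None] * (W + 1)  # dp2[c]: items not yet used by that best selection
    dp1[0] = 0
    dp2[0] = list(items)
    for c in range(W + 1):
        base = dp1[c]
        if base < 0:
            continue
        rem = dp2[c]
        for j, (w, v) in enumerate(rem):
            nc = c + w
            if nc <= W and base + v > dp1[nc]:
                dp1[nc] = base + v
                dp2[nc] = rem[:j] + rem[j + 1:]  # remaining items minus item j
    return max(dp1)
```

On the first sample of the problem (W = 8, items (3, 30), (4, 50), (5, 60)) it returns 90.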

Thanks for reading!


By drugkeeper, 12 months ago, In English

I was doing G1 (https://mirror.codeforces.com/contest/1822/problem/G1) and came up with this solution, which is O(n * sqrt(m)) as stated in the editorial. It TLEs at testcase 13 due to constant factors.

TLE Code 1

I optimised my code further as shown here, but it TLEs at testcase 16:

TLE Code 2

Main changes:

  1. Use faster output
  2. Use Counter instead of an array to count, which is faster (this helped the most, and got me past testcase 13). In another solution, I initialised the count array once at the very start instead of for each testcase, which let me pass everything.
  3. Got rid of unnecessary variables, if-else checks and array / Counter accesses.
  4. Use for loop instead of while loop

However, this still TLEs at testcase 16, which is an anti-hash test against Counter. If I sort the input beforehand, my code passes!

Code that works

Note that Python is too slow to pass all testcases (it fails at testcase 13), but PyPy works.

I have a few questions for the people well versed in python:

  1. Why does sorting the input help get rid of the Counter hash hack testcase TLE?
  2. Is it possible to make Python pass all testcases? (Edit: I found a submission that works in python https://mirror.codeforces.com/contest/1822/submission/203361435)

I hope you can answer my questions; if not, I hope you have learnt something from my blog. Thank you!


By drugkeeper, 17 months ago, In English

I am asking for help from Python experts (e.g. pajenegod) regarding this issue.

Article: https://usaco.guide/general/choosing-lang

Problem: http://www.usaco.org/index.php?page=viewproblem2&cpid=992

In this article, they mention that: "A comparable Python solution only passes the first five test cases:", "After some optimizations, this solution still passes only the first seven test cases:", "We are not sure whether it is possible to modify the above approach to receive full credit (please let us know if you manage to succeed)! Of course, it is possible to pass this problem in Python using DSU (a Gold topic):"

So I tried to optimise that approach (binary search with DFS), but I could only get 9/10 testcases to pass reliably. Very rarely, all 10/10 testcases pass, with the 10th testcase taking ~3990 ms. Is it possible to get 10/10 testcases to pass reliably?

I have tried many things, including speeding up I/O, using list comprehensions instead of for loops wherever possible, and using bitwise operators to avoid slow tuple sorting.

I have also profiled my code and found that the valid() function is the bottleneck.

Here is the code:

from operator import methodcaller

def main():
	lines = open("wormsort.in","rb").readlines()
	n,m = map(int,lines[0].split())
	loc = [*map(int,lines[1].split())]
	edges = [[] for _ in range(n)]
	lo,hi,mask = 0,m,0b11111111111111111

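	def valid(loc, mid):
		# True if every cow can reach its target using only the mid widest wormholes
		component = [-1]*n
		numcomps = 0
		for i in range(n):
			if component[i] < 0:
				# Label i's component with an iterative DFS (explicit stack, no recursion)
				todo = [i]
				component[i] = numcomps
				while todo:
					for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:
						component[child] = numcomps
						todo.append(child)
				numcomps += 1
			# The cow at position i must be able to reach position loc[i] - 1
			if component[i] != component[loc[i] - 1]:
				return False
		return True

	# Pack (width, a, b) into one int as width << 34 ^ a << 17 ^ b, to avoid slow tuple sorting
	all_edges = [*map(lambda x: int(x[2]) << 34 ^ int(x[0]) << 17 ^ int(x[1]),
					map(methodcaller("split", b" "), lines[2:]))]
	all_edges.sort(reverse=True)  # widest wormholes first

	# Unpack the endpoints; i is the wormhole's rank by width, so x[1] < mid keeps the mid widest
	for i, val in enumerate(all_edges):
		rhs = (val & mask) - 1
		lhs = ((val >> 17) & mask) - 1
		edges[lhs].append((rhs,i))
		edges[rhs].append((lhs,i))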

	while lo != hi:
		mid = (lo + hi) // 2
		if valid(loc, mid):
			hi = mid
		else:
			lo = mid+1

	open("wormsort.out","w").write(f"{-1 if lo == 0 else all_edges[lo-1] >> 34}\n")

main()

Any tips or advice on how to speed this up would be greatly appreciated. Thank you!
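For comparison, here is a sketch of the DSU approach that the article mentions. It swaps the binary search plus DFS for a single pass over the wormholes from widest to narrowest. wormsort_dsu is a hypothetical name and file I/O is left out. loc and the (a, b, w) triples are assumed to be 1-indexed, as in the input file.

```python
def wormsort_dsu(n, loc, edges):
    """n positions; loc[i] is the (1-indexed) cow at position i; edges is a
    list of (a, b, w) with 1-indexed endpoints. Returns the largest possible
    minimum wormhole width, or -1 if the cows are already sorted."""
    parent = list(range(n))

    def find(x):
        # Iterative find with path halving (no recursion)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # j = first position whose cow cannot reach its target yet. Connectivity
    # only grows as wormholes are added, so j never has to move backwards.
    j = 0
    while j < n and find(j) == find(loc[j] - 1):
        j += 1
    if j == n:
        return -1
    for a, b, w in sorted(edges, key=lambda e: e[2], reverse=True):  # widest first
        ra, rb = find(a - 1), find(b - 1)
        if ra != rb:
            parent[ra] = rb
        while j < n and find(j) == find(loc[j] - 1):
            j += 1
        if j == n:
            return w  # w is the narrowest wormhole added so far
    return -1  # only reached if the cows can never be sorted
```

On a small example (4 positions, loc = [3, 2, 1, 4], wormholes (1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)) it returns 9. Because the pointer j only moves forward, the pass costs roughly one sort plus one union per wormhole and a few finds per position.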
