Help needed! Python code optimisation for USACO Silver: Wormhole Sort

Revision en3, by drugkeeper, 2022-11-29 11:00:43

I am requesting help from Python experts (e.g. pajenegod) with this issue.

Article: https://usaco.guide/general/choosing-lang

Problem: http://www.usaco.org/index.php?page=viewproblem2&cpid=992

This article states: "A comparable Python solution only passes the first five test cases:", then "After some optimizations, this solution still passes only the first seven test cases:", and finally "We are not sure whether it is possible to modify the above approach to receive full credit (please let us know if you manage to succeed)! Of course, it is possible to pass this problem in Python using DSU (a Gold topic):"
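For comparison, here is a minimal sketch of the DSU idea the article refers to. It is not the article's code, and `wormsort_dsu` is a hypothetical name. The sketch assumes:

- `loc` is the list from input line 2 (1-based values).
- Edges are `(a, b, w)` tuples with 1-based endpoints.

It adds wormholes in decreasing width and stops as soon as every out-of-place position is connected to the position its cow must reach.

```python
def wormsort_dsu(n, loc, edges):
    """Hypothetical helper: the largest possible minimum wormhole width, or -1 if loc is sorted.

    loc[i] is the cow at position i + 1 (1-based values, as on input line 2);
    edges is a list of (a, b, w) tuples with 1-based endpoints a, b.
    """
    parent = list(range(n))

    def find(x):
        # Iterative find with path halving (avoids Python recursion limits).
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Only positions whose cow is out of place impose a connectivity constraint.
    pending = [i for i in range(n) if loc[i] != i + 1]
    if not pending:
        return -1

    j = 0  # invariant: pending[:j] are already connected to their targets
    for a, b, w in sorted(edges, key=lambda e: e[2], reverse=True):
        ra, rb = find(a - 1), find(b - 1)
        if ra != rb:
            parent[ra] = rb
        # Connectivity only grows, so a satisfied constraint stays satisfied.
        while j < len(pending) and find(pending[j]) == find(loc[pending[j]] - 1):
            j += 1
        if j == len(pending):
            return w  # narrowest wormhole added so far
    return -1  # only reached if the cows cannot be sorted at all


print(wormsort_dsu(4, [3, 2, 1, 4], [(1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)]))
# → 9
```

The pointer `j` only moves forward, so each added edge causes at most one failed check on top of the union itself. The check mirrors the `component[i] != component[loc[i] - 1]` test in `valid()`.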

So I tried to optimise that approach (binary search with DFS), but I could only get 9/10 test cases to pass reliably. Very rarely, all 10/10 pass, with the 10th test case taking ~3990 ms. Is it possible to get 10/10 test cases to pass reliably?

I have tried many approaches, including faster IO, list comprehensions instead of for loops wherever possible, and packing each edge into a single integer with bitwise operators to avoid slow tuple sorting.
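To make the bit layout in the `all_edges` line concrete, here is a small sketch. `pack`, `unpack` and `MASK` are hypothetical names, and the sketch assumes node ids below 2**17 = 131072, the range covered by the post's 17-bit mask.

Because the width sits in the highest bits, sorting the packed integers orders edges by width first, without comparing tuples.

```python
MASK = (1 << 17) - 1  # same value as 0b11111111111111111 in the post

def pack(w, a, b):
    # Layout: w in bits 34 and up, a in bits 17..33, b in bits 0..16.
    # The fields do not overlap (a, b < 2**17), so | gives the same result as the post's ^.
    return w << 34 | a << 17 | b

def unpack(v):
    # Inverse of pack: returns (w, a, b).
    return v >> 34, (v >> 17) & MASK, v & MASK

edges = [(9, 1, 2), (7, 1, 3), (10, 2, 3), (3, 2, 4)]  # (w, a, b)
packed = sorted((pack(*e) for e in edges), reverse=True)
print([unpack(v) for v in packed])
# → [(10, 2, 3), (9, 1, 2), (7, 1, 3), (3, 2, 4)]
```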

I have also profiled my code and found that the valid() function is the bottleneck.
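For readers who want to reproduce this kind of measurement, here is one way to profile a single call with the standard-library `cProfile` and `pstats` modules. `profile_call` and `toy` are hypothetical names. Timings under a profiler are only a rough guide, especially if the judge runs a different interpreter from the one used locally.

```python
import cProfile
import io
import pstats

def profile_call(func, *args, top=10):
    """Run func(*args) under cProfile; return (result, text report)."""
    prof = cProfile.Profile()
    result = prof.runcall(func, *args)
    out = io.StringIO()
    # Sort by cumulative time so expensive callees (e.g. valid) and their callers surface.
    pstats.Stats(prof, stream=out).sort_stats("cumulative").print_stats(top)
    return result, out.getvalue()

def toy(n):
    # Stand-in workload; in practice you would profile main() or valid().
    return sum(i * i for i in range(n))

result, report = profile_call(toy, 1000)
print(report)
```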

Here is the code:

from operator import methodcaller

def main():
	lines = open("wormsort.in","rb").readlines()
	n,m = map(int,lines[0].split())
	loc = [*map(int,lines[1].split())]
	edges = [[] for _ in range(n)]
	lo,hi,mask = 0,m,0b11111111111111111  # mask selects one 17-bit node-id field

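	def valid(loc, mid):
		# True iff, using only the mid widest wormholes (edge index < mid),
		# every position i is in the same component as position loc[i] - 1.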
		component = [-1]*n
		numcomps = 0
		for i in range(n):
			if component[i] < 0:
				todo = [i]
				component[i] = numcomps
				while todo:
					for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:
						component[child] = numcomps
						todo.append(child)
				numcomps += 1
			if component[i] != component[loc[i] - 1]:
				return False
		return True

	# bitwise to avoid tuple sort
	all_edges = [*map(lambda x: int(x[2]) << 34 ^ int(x[0]) << 17 ^ int(x[1]), 
	 				map(methodcaller("split", b" "), lines[2:]))]
	all_edges.sort(reverse=True)

	# i is the edge's rank by decreasing width, so "x[1] < mid" keeps the mid widest edges
	for i, val in enumerate(all_edges):
		rhs = (val & mask) - 1
		lhs = ((val >> 17) & mask) - 1
		edges[lhs].append((rhs,i))
		edges[rhs].append((lhs,i))

	while lo != hi:
		mid = (lo + hi) // 2
		if valid(loc, mid):
			hi = mid
		else:
			lo = mid+1

	open("wormsort.out","w").write(f"{-1 if lo == 0 else all_edges[lo-1] >> 34}\n")

main()
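The loop at the end of `main()` is a "first true" binary search. `valid(loc, mid)` is monotone in `mid` because allowing more (narrower) wormholes can only merge components. A minimal sketch of that pattern in isolation, with `first_true` as a hypothetical helper name and a toy predicate:

```python
def first_true(lo, hi, pred):
    """Smallest k in [lo, hi] with pred(k) True.

    Assumes pred is monotone (False ... False True ... True) and pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid      # mid works, so the answer is mid or smaller
        else:
            lo = mid + 1  # mid fails, so the answer is larger
    return lo

print(first_true(0, 10, lambda k: k >= 4))
# → 4
```

In the post, a result of `lo == 0` means no wormhole is needed (output -1). Otherwise, the answer is the width of the narrowest allowed edge, `all_edges[lo-1] >> 34`.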

Any tips or advice on how to speed this up would be greatly appreciated. Thank you!

Tags python, usaco, optimization, pypy

Revisions

Rev. Lang. By When Δ Comment
en3 English drugkeeper 2022-11-29 11:00:43 131 Tiny change: 'ng-lang)\nProblem:' -> 'ng-lang)\n\nProblem:'
en2 English drugkeeper 2022-11-29 10:58:33 0 (published)
en1 English drugkeeper 2022-11-29 10:57:56 2614 Initial revision (saved to drafts)