drugkeeper's blog

By drugkeeper, 17 months ago, In English

I am asking for help from Python experts (e.g. pajenegod) with this issue.

Article: https://usaco.guide/general/choosing-lang

Problem: http://www.usaco.org/index.php?page=viewproblem2&cpid=992

The article says: "A comparable Python solution only passes the first five test cases:", "After some optimizations, this solution still passes only the first seven test cases:", and "We are not sure whether it is possible to modify the above approach to receive full credit (please let us know if you manage to succeed)! Of course, it is possible to pass this problem in Python using DSU (a Gold topic):"
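For comparison, the DSU route the article mentions can be sketched roughly as follows. This is a minimal illustration of my own, not the article's code: `wormsort_dsu` is a hypothetical name, it takes the input already parsed into memory (1-based positions, as in the input file), and it assumes that the cows can always be sorted, which is how I understand the problem's guarantee. Wormholes are added from widest to narrowest until every cow's position is connected to its target position. The width of the last wormhole added is the answer.

```python
def wormsort_dsu(n, loc, edges):
    """Hypothetical helper: widest possible minimum wormhole width, or -1.

    n     -- number of positions/cows
    loc   -- loc[i] is the (1-based) cow currently at position i
    edges -- list of (a, b, w): a wormhole of width w between positions a and b (1-based)
    """
    parent = list(range(n))

    def find(x):
        # Iterative find with path halving (no recursion depth issues).
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # p is the first position whose cow is not yet connected to its target position.
    p = 0
    while p < n and loc[p] - 1 == p:
        p += 1
    if p == n:
        return -1  # already sorted, no wormholes needed

    # Widest wormholes first; connectivity only grows, so p only moves forward.
    for a, b, w in sorted(edges, key=lambda e: e[2], reverse=True):
        ra, rb = find(a - 1), find(b - 1)
        if ra != rb:
            parent[ra] = rb
        while p < n and find(p) == find(loc[p] - 1):
            p += 1
        if p == n:
            # All wormholes wider than w were not enough, so w is the answer.
            return w
    raise ValueError("the wormholes cannot sort the cows")


print(wormsort_dsu(4, [3, 2, 1, 4], [(1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)]))  # 9
```

Each wormhole is processed once and `p` only moves forward, so apart from the sort this runs in near-linear time, with no binary search and no repeated graph traversals.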

So I tried to optimise that approach (binary search with DFS), but I could only get 9/10 test cases to pass reliably. Very rarely, all 10/10 pass, with the 10th test case taking ~3990 ms (the time limit is 4 s). Is it possible to get 10/10 to pass reliably?

I have tried many approaches, including faster IO, list comprehensions instead of for loops wherever possible, and bitwise operators to avoid slow tuple sorting.
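The bitwise trick in the code below packs each edge into a single int with the width in the highest bits, so sorting plain ints orders edges by width (ties are broken by endpoint, which is harmless here). A minimal sketch of the packing and unpacking, assuming node labels fit in 17 bits as in the posted code; `pack`, `unpack`, and `SHIFT` are hypothetical names:

```python
# Pack (w, a, b) into one int: w in the high bits, then a, then b (17 bits each).
# Assumes 0 <= a, b < 2**17, which holds for node labels up to 10**5.
SHIFT = 17
MASK = (1 << SHIFT) - 1  # same as 0b11111111111111111


def pack(w, a, b):
    # The three fields never overlap, so | and ^ give the same result here.
    return w << (2 * SHIFT) | a << SHIFT | b


def unpack(key):
    return key >> (2 * SHIFT), (key >> SHIFT) & MASK, key & MASK


edges = [(5, 1, 2), (9, 2, 3), (1, 3, 1)]
keys = sorted((pack(w, a, b) for w, a, b in edges), reverse=True)
print([unpack(k) for k in keys])  # [(9, 2, 3), (5, 1, 2), (1, 3, 1)]
```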

I have also profiled my code and found that the valid() function is the bottleneck.
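One way to get a per-function breakdown like the one posted further down the thread is the standard-library `cProfile` module together with `pstats`. In the output, `tottime` is time spent inside a function itself and `cumtime` also includes the functions it calls. A minimal sketch; `work` is a hypothetical stand-in for `main()`:

```python
import cProfile
import pstats


def work(n):
    # Stand-in for main(); any function can be profiled the same way.
    return sum(i * i for i in range(n))


profiler = cProfile.Profile()
result = profiler.runcall(work, 10**5)  # runs work(10**5) under the profiler
stats = pstats.Stats(profiler)
stats.sort_stats("tottime").print_stats(5)  # the 5 functions with the most own time
```

From the command line, `python -m cProfile -s tottime wormsort.py` profiles a whole script the same way.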

Here is the code:

from operator import methodcaller

def main():
	lines = open("wormsort.in","rb").readlines()
	n,m = map(int,lines[0].split())
	loc = [*map(int,lines[1].split())]
	edges = [[] for _ in range(n)]
	lo,hi,mask = 0,m,0b11111111111111111

	def valid(loc, mid):
		component = [-1]*n
		numcomps = 0
		for i in range(n):
			if component[i] < 0:
				todo = [i]
				component[i] = numcomps
				while todo:
					for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:
						component[child] = numcomps
						todo.append(child)
				numcomps += 1
			if component[i] != component[loc[i] - 1]:
				return False
		return True

	# bitwise to avoid tuple sort
	all_edges = [*map(lambda x: int(x[2]) << 34 ^ int(x[0]) << 17 ^ int(x[1]), 
	 				map(methodcaller("split", b" "), lines[2:]))]
	all_edges.sort(reverse=True)

	for i, val in enumerate(all_edges):
		rhs = (val & mask) - 1
		lhs = ((val >> 17) & mask) - 1
		edges[lhs].append((rhs,i))
		edges[rhs].append((lhs,i))

	while lo != hi:
		mid = (lo + hi) // 2
		if valid(loc, mid):
			hi = mid
		else:
			lo = mid+1

	open("wormsort.out","w").write(f"{-1 if lo == 0 else all_edges[lo-1] >> 34}\n")

main()

Any tips or advice on how to speed this up would be greatly appreciated. Thank you!

»
17 months ago, # |

Auto comment: topic has been updated by drugkeeper.

»
17 months ago, # |

There are tons of optimizations you can do. For one, your algorithm doesn't need the edges themselves to be sorted, only the edge weights, so there is no reason for the fancy all_edges sort.

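About your graph traversal: you say you are doing a DFS, but it is not one. It is a mix of DFS and BFS: nodes are popped from the end of todo like a stack, but every unvisited neighbour is marked and pushed as soon as its parent is popped, which is BFS-style discovery. It still works, but it is not a DFS.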
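To make the difference concrete, here is a small sketch comparing the post's traversal with a true iterative DFS on a triangle. Both reach the same nodes (so the component labels are unaffected), but they build different search trees. The function names and the `triangle` graph are hypothetical, chosen for illustration:

```python
def hybrid_parents(adj, start):
    # The post's traversal: pop like a stack, but mark children when they are pushed.
    parent = [-1] * len(adj)
    seen = [False] * len(adj)
    seen[start] = True
    todo = [start]
    while todo:
        node = todo.pop()
        for child in adj[node]:
            if not seen[child]:
                seen[child] = True
                parent[child] = node
                todo.append(child)
    return parent


def dfs_parents(adj, start):
    # Iterative DFS: mark a node only when it is popped, so it is claimed by the
    # most recent path that reaches it, matching what a recursive DFS would do.
    parent = [-1] * len(adj)
    seen = [False] * len(adj)
    stack = [(start, -1)]
    while stack:
        node, par = stack.pop()
        if seen[node]:
            continue
        seen[node] = True
        parent[node] = par
        for child in reversed(adj[node]):  # reversed: visit children in list order
            if not seen[child]:
                stack.append((child, node))
    return parent


triangle = [[1, 2], [0, 2], [0, 1]]
print(hybrid_parents(triangle, 0))  # [-1, 0, 0]: both neighbours hang off node 0, as in BFS
print(dfs_parents(triangle, 0))     # [-1, 0, 1]: DFS goes 0 -> 1 -> 2
```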

My preferred way to do a graph traversal is a BFS like this:

todo = [i]
component[i] = numcomps
for node in todo:
    for child,index in edges[node]:
        if component[child] == -1 and index < mid:
            component[child] = numcomps
            todo.append(child)

I think that changing your graph traversal to this could speed it up significantly.
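The snippet above works because iterating over a list with `for` also visits items appended during the loop, so `todo` acts as a FIFO queue without any pops. A sketch checking this against a textbook `collections.deque` BFS; the function names and the small square graph are hypothetical:

```python
from collections import deque


def bfs_list(adj, start):
    # Iterate over a list while appending to it. CPython's list iterator
    # compares its index with the list's current length on every step,
    # so nodes appended during the loop are visited too, in FIFO order.
    seen = [False] * len(adj)
    seen[start] = True
    todo = [start]
    for node in todo:
        for child in adj[node]:
            if not seen[child]:
                seen[child] = True
                todo.append(child)
    return todo  # todo doubles as the BFS visiting order


def bfs_deque(adj, start):
    # Textbook BFS with a deque, for comparison.
    seen = [False] * len(adj)
    seen[start] = True
    order = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in adj[node]:
            if not seen[child]:
                seen[child] = True
                queue.append(child)
    return order


adj = [[1, 2], [0, 3], [0, 3], [1, 2]]
print(bfs_list(adj, 0), bfs_deque(adj, 0))  # [0, 1, 2, 3] [0, 1, 2, 3]
```

The list version avoids the per-node `popleft()` call, at the cost of keeping every visited node of the component in `todo`, which a BFS queue can hold in the worst case anyway.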

  • »
    »
    17 months ago, # ^ |

    Hi, thank you so much for the advice! I actually tried tuples before and thought that it did not really affect the performance.
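Whether packing edges into ints beats a tuple sort is not obvious, since `key=itemgetter(0)` already makes the sort compare plain int widths. A rough way to measure it on random data; the sizes, the seed, and the names `sort_tuples`/`sort_packed` are assumptions for illustration, and the timings will vary by machine and Python version:

```python
import random
import timeit
from operator import itemgetter

rng = random.Random(0)
m = 10**5
# Random (width, a, b) triples with endpoints below 2**17, as in the problem.
triples = [(rng.randint(1, 10**9), rng.randint(1, m), rng.randint(1, m)) for _ in range(m)]
packed = [w << 34 ^ a << 17 ^ b for w, a, b in triples]


def sort_tuples():
    return sorted(triples, key=itemgetter(0), reverse=True)


def sort_packed():
    return sorted(packed, reverse=True)


# Both give the same width order (ties may be broken differently).
assert [t[0] for t in sort_tuples()] == [k >> 34 for k in sort_packed()]

for f in (sort_tuples, sort_packed):
    print(f.__name__, timeit.timeit(f, number=5))
```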

    I have changed my code to use a tuple sort and added your graph traversal:

    	all_edges = [*map(lambda x: (int(x[2]), int(x[0]), int(x[1])), 
    	  				map(methodcaller("split", b" "), lines[2:]))]
    	all_edges.sort(key=itemgetter(0), reverse=True)
    
    	for i, val in enumerate(all_edges):
    		_, lhs, rhs = val
    		edges[lhs-1].append((rhs-1,i))
    		edges[rhs-1].append((lhs-1,i))
    
    	def valid(loc, mid):
    		component = [-1]*n
    		numcomps = 0
    		for i in range(n):
    			if component[i] == -1:
    				todo = [i]
    				component[i] = numcomps
    				for node in todo:
    					for child,index in edges[node]:
    						if component[child] == -1 and index < mid:
    							component[child] = numcomps
    							todo.append(child)
    				numcomps += 1
    			if component[i] != component[loc[i] - 1]:
    				return False
    		return True
    

    It is faster now and passes the last test case more often, but it still fails more than 50% of the time. Do you have any more optimisation tips? Thank you!

    Also, correct me if I'm wrong, but I think your graph traversal is faster because it:

    1. uses a for loop instead of a while loop at the top, which is faster;
    2. does not need to call todo.pop(), saving a method call per node;
    3. uses == -1 instead of < 0 (is that faster?).

    I actually tried many ways to optimise this graph traversal (making a filtered list, calling todo.extend() instead of append, using list comprehensions to assign the components, etc.), but yours seems to be the fastest.
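Rather than guessing which of the differences listed above matters, the two traversal styles can be timed directly. A sketch on a random graph, where an edge's index plays the role of its rank by width (edges with index < mid are usable, as in valid()); `make_graph`, `label_pop`, `label_for`, and the graph sizes are hypothetical, and no particular timing result is implied:

```python
import random
import timeit


def make_graph(n, m, seed=0):
    # Random multigraph; edge idx plays the role of the edge's rank by width.
    rng = random.Random(seed)
    edges = [[] for _ in range(n)]
    for idx in range(m):
        a, b = rng.randrange(n), rng.randrange(n)
        edges[a].append((b, idx))
        edges[b].append((a, idx))
    return edges


def label_pop(n, edges, mid):
    # Original style: while loop, todo.pop(), filtered list comprehension, "< 0" test.
    component = [-1] * n
    numcomps = 0
    for i in range(n):
        if component[i] < 0:
            component[i] = numcomps
            todo = [i]
            while todo:
                for child in [x[0] for x in edges[todo.pop()] if component[x[0]] < 0 and x[1] < mid]:
                    component[child] = numcomps
                    todo.append(child)
            numcomps += 1
    return component


def label_for(n, edges, mid):
    # Suggested style: for loop over the growing list, tuple unpacking, "== -1" test.
    component = [-1] * n
    numcomps = 0
    for i in range(n):
        if component[i] == -1:
            component[i] = numcomps
            todo = [i]
            for node in todo:
                for child, index in edges[node]:
                    if component[child] == -1 and index < mid:
                        component[child] = numcomps
                        todo.append(child)
            numcomps += 1
    return component


if __name__ == "__main__":
    n = m = 10**5
    edges = make_graph(n, m)
    assert label_pop(n, edges, m // 2) == label_for(n, edges, m // 2)
    for f in (label_pop, label_for):
        print(f.__name__, timeit.timeit(lambda: f(n, edges, m // 2), number=3))
```

Swapping one difference at a time (for example only `< 0` vs `== -1`) in `label_for` isolates each of the three points.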

    • »
      »
      »
      17 months ago, # ^ |

      Here is the cProfile timing data (run on a randomly generated test case), along with my code:

      Profiling Data:

               3000034 function calls in 3.803 seconds
      
         Ordered by: standard name
      
         ncalls  tottime  percall  cumtime  percall filename:lineno(function)
              1    0.000    0.000    0.000    0.000 codecs.py:186(__init__)
              1    0.000    0.000    0.000    0.000 cp1252.py:18(encode)
              1    0.756    0.756    3.803    3.803 wormsort.py:1(<module>)
             19    0.049    0.003    0.049    0.003 wormsort.py:11(valid)
        1000000    0.422    0.000    0.422    0.000 wormsort.py:30(<lambda>)
              1    1.579    1.579    3.047    3.047 wormsort.py:4(main)
              1    0.373    0.373    0.373    0.373 wormsort.py:8(<listcomp>)
              1    0.000    0.000    0.000    0.000 {built-in method _codecs.charmap_encode}
              1    0.000    0.000    3.803    3.803 {built-in method builtins.exec}
              2    0.005    0.003    0.005    0.003 {built-in method io.open}
        2000000    0.173    0.000    0.173    0.000 {method 'append' of 'list' objects}
              1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
              1    0.061    0.061    0.061    0.061 {method 'readlines' of '_io._IOBase' objects}
              1    0.348    0.348    0.348    0.348 {method 'sort' of 'list' objects}
              2    0.037    0.018    0.037    0.018 {method 'split' of 'bytes' objects}
              1    0.000    0.000    0.000    0.000 {method 'write' of '_io.TextIOWrapper' objects}
      

      Code:

      from operator import itemgetter, methodcaller
      
      
      def main():
          lines = open("wormsort.in", "rb").readlines()
          n, m = map(int, lines[0].split())
          loc = [*map(int, lines[1].split())]
          edges = [[] for _ in range(n)]
          lo, hi = 0, m
      
          def valid(loc, mid):
              component = [-1] * n
              numcomps = 0
              for i in range(n):
                  if component[i] == -1:
                      todo = [i]
                      component[i] = numcomps
                      for node in todo:
                          for child, index in edges[node]:
                              if component[child] == -1 and index < mid:
                                  component[child] = numcomps
                                  todo.append(child)
                      numcomps += 1
                  if component[i] != component[loc[i] - 1]:
                      return False
              return True
      
          all_edges = [
              *map(
                  lambda x: (int(x[2]), int(x[0]), int(x[1])),
                  map(methodcaller("split", b" "), lines[2:]),
              )
          ]
          all_edges.sort(key=itemgetter(0), reverse=True)
      
          for i, val in enumerate(all_edges):
              _, lhs, rhs = val
              edges[lhs - 1].append((rhs - 1, i))
              edges[rhs - 1].append((lhs - 1, i))
      
          while lo != hi:
              mid = (lo + hi) // 2
              if valid(loc, mid):
                  hi = mid
              else:
                  lo = mid + 1
      
          open("wormsort.out", "w").write(f"{-1 if lo == 0 else all_edges[lo-1][0]}\n")
      
      
      main()
      
    • »
      »
      »
      17 months ago, # ^ |

      Less code, and simpler code, generally runs faster in Python. If you can just use a standard for loop, go for it. As for point 3, I don't think == -1 vs < 0 matters at all.

      By the way, about avoiding sorting the edges: this is what I meant when I said your algorithm doesn't need the edges to be sorted, only the edge weights:


      edges = [[] for _ in range(n)]
      weights = [10**9 + 1]  # sentinel wider than any wormhole
      for _ in range(m):
          a, b, w = [int(x) for x in input().split()]
          edges[a - 1].append((b - 1, w))
          edges[b - 1].append((a - 1, w))
          weights.append(w)
      weights.sort(reverse=True)  # sentinel first, then widths from widest to narrowest

      # valid(loc, minW) is the BFS check, using only edges with weight >= minW
      lo, hi = 0, m
      while lo != hi:
          mid = (lo + hi) // 2
          if valid(loc, weights[mid]):
              hi = mid
          else:
              lo = mid + 1
      print(-1 if lo == 0 else weights[lo])

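The sentinel idea can be isolated from the graph part. A sketch under stated assumptions: the widths are sorted in descending order with the sentinel first, and the check (here a plain predicate `ok`, standing in for valid(loc, w)) is False for large widths and True for small ones, and True for the narrowest wormhole. `widest_threshold` is a hypothetical name:

```python
def widest_threshold(weights_desc, ok):
    """Hypothetical helper for the sentinel trick.

    weights_desc[0] is a sentinel wider than every wormhole; the rest are the
    wormhole widths sorted in descending order. ok(w) must be monotone
    (False ... False, True ... True along the list) and True for the last entry.
    Returns -1 if ok(sentinel) holds (no wormhole needed), else the widest w with ok(w).
    """
    lo, hi = 0, len(weights_desc) - 1
    while lo != hi:
        mid = (lo + hi) // 2
        if ok(weights_desc[mid]):
            hi = mid
        else:
            lo = mid + 1
    return -1 if lo == 0 else weights_desc[lo]


weights = [10**9 + 1, 10, 9, 7, 3]
print(widest_threshold(weights, lambda w: w <= 9))  # 9
print(widest_threshold(weights, lambda w: True))    # -1 (already sorted)
```

Since no wormhole is as wide as the sentinel, the check can only pass at index 0 if the cows are already sorted, which is what makes `lo == 0` mean "print -1".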
      • »
        »
        »
        »
        17 months ago, # ^ |

        Thank you so much for the advice! I changed it up a bit, and now it passes all the test cases reliably (approx. 3.7 s / 4 s).

        Here is the code:

        def main():
        	f = open("wormsort.in","rb")
        	n,m = map(int, f.readline().split())
        	loc = [*map(int, f.readline().split())]
        	edges = [[] for _ in range(n)]
        	weights = []
        	lo,hi = 0, m  # hi = m, so lo == m means every width still works
        
        	def valid(loc, minW):
        		component = [-1]*n
        		numcomps = 0
        		for i in range(n):
        			if component[i] != component[loc[i] - 1]:
        				return False
        			elif component[i] == -1:
        				todo = [i]
        				component[i] = numcomps
        				for node in todo:
        					for child, weight in edges[node]:
        						if component[child] == -1 and weight >= minW:
        							component[child] = numcomps
        							todo.append(child)
        				numcomps += 1
        		return True
        
        	for line in f:
        		a,b,w = map(int, line.split())
        		edges[a - 1].append((b - 1, w))
        		edges[b - 1].append((a - 1, w))
        		weights.append(w)
        	weights.sort()
        
        	# lo ends at the first index whose width is too large, or at m if none is
        	while lo != hi:
        		mid = (lo + hi) // 2
        		if valid(loc, weights[mid]):
        			lo = mid + 1
        		else:
        			hi = mid
        
        	# already sorted -> -1; otherwise weights[lo-1] is the widest width that works
        	already_sorted = all(loc[i] == i + 1 for i in range(n))
        	open("wormsort.out","w").write(f"{-1 if already_sorted else weights[lo-1]}\n")
        
        main()