# Commutative Segment Trees

## Preface

This blog is about an alternative way to implement segment trees. You should not need to know segment trees already to follow it, but it is mostly aimed at people who know segment trees and don't know how to, for example, build a merge sort tree that supports updates.

Pros:

- You don't need a fast merge function (which lets you, for example, extend to 2D)
- Very easy to implement
- Fast

Cons:

- Merging has to be commutative
- Can't easily be turned lazy (meaning only update or query can be on a range, not both)

Disclaimer: This way of implementing a segment tree may be what you think of as "a standard segment tree". But the reason I made this blog is that I talked to people who were very familiar with segment trees but did not know how to make a 2D segment tree (which is a trivial modification of this way to implement segment trees).

If you just want to look at the code, it's at the bottom of the blog.

## Range Sum, Point Add

Consider this simple example problem.

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

- $$$p \enspace x$$$: $$$a_p := a_p + x$$$
- $$$l \enspace r$$$: print $$$\sum_{i=l}^{r} a_i$$$

### Solution

We use a binary tree, and look at $$$\mathcal{O}(\log n)$$$ nodes when querying a range or updating a value.

**How does this work?**

Let's define $$$rootpath(p)$$$ as the set of nodes we touch when walking from the leaf $$$p$$$ to the root (the nodes we look at when doing an update in this problem),

and $$$rangecover(l, r)$$$ as the set of nodes we look at when we find the sum of values at positions in $$$[l, r]$$$.

A more formal definition of $$$rangecover(l, r)$$$ is the smallest set of nodes such that:

- for all $$$p$$$ in $$$[l, r]$$$, there is exactly $$$1$$$ node in $$$rangecover(l, r)$$$ that is also in $$$rootpath(p)$$$.
- for all $$$p$$$ not in $$$[l, r]$$$, there are no nodes in $$$rangecover(l, r)$$$ that are also in $$$rootpath(p)$$$.
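
One common way to realize these two sets is the iterative bottom-up layout, sketched below. The layout (root at index $$$1$$$, leaf for position $$$p$$$ at index $$$p + n$$$, 0-indexed positions) and the extra `n` parameter are assumptions made for this illustration; the definitions above do not depend on them.

```python
def rootpath(p, n):
    """Yield the nodes on the path from the leaf of position p up to the root."""
    node = p + n              # leaves live at indices n .. 2n-1
    while node >= 1:
        yield node
        node //= 2            # move to the parent


def rangecover(l, r, n):
    """Yield the O(log n) nodes that together cover the inclusive range [l, r]."""
    lo, hi = l + n, r + n + 1     # half-open interval [lo, hi) on the current level
    while lo < hi:
        if lo & 1:                # lo is a right child: take it, step past it
            yield lo
            lo += 1
        if hi & 1:                # hi - 1 is a left child: take it
            hi -= 1
            yield hi
        lo //= 2
        hi //= 2
```

With $$$n = 5$$$, for example, `list(rootpath(1, 5))` walks the indices `6, 3, 1`, and `rangecover(0, 2, 5)` picks node `5` (position 0) and node `3` (positions 1 and 2).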

Now we can do updates like this:

```
def update(p, x):
    for node in rootpath(p):
        # instead of only updating the bottom layer and
        # then merging up along the path,
        # we just add x directly to every node on the path.
        value[node] += x

def query(l, r):
    result = 0
    for node in rangecover(l, r):
        result += value[node]
    return result
```

This works intuitively from the illustrations in the spoiler tag above, but here is an argument for why it works, using the definition of $$$rangecover$$$.

When we ask for the sum of a range, that is the same as asking for the sum of updates to positions within the range (ignoring the initial values, we can just assume they were 0 and update them).

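Let's consider a query to $$$[l, r]$$$ and some position $$$p$$$ in $$$[l, r]$$$. Since an update to $$$p$$$ is applied to every node in $$$rootpath(p)$$$, and exactly one of those nodes is in $$$rangecover(l, r)$$$, the update is counted exactly once by the query. For a position $$$p$$$ outside $$$[l, r]$$$, no node of $$$rootpath(p)$$$ is in $$$rangecover(l, r)$$$, so updates to it are not counted at all. The query therefore returns exactly the sum of updates to positions inside the range.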

## Range Add, Point Sum

This problem is the same as in the previous section, except query and update are swapped:

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

- $$$l \enspace r \enspace x$$$: For each $$$i$$$ in $$$[l, r]$$$, $$$a_i := a_i + x$$$
- $$$p$$$: print $$$a_p$$$

### Solution

Here we can use the same set definitions as before, but use a different definition for what value the nodes in the tree should hold.

In the previous problem we said:

For all nodes, the value in the node is the sum of $$$a_p$$$ for positions $$$p$$$ such that the node is in $$$rootpath(p)$$$.

But what if we instead say:

For all positions $$$p$$$, the sum of values in nodes in $$$rootpath(p)$$$ is $$$a_p$$$.

If this holds, queries are easy. Just sum up the values on the path to the root. But what about updates? Since $$$rangecover(l, r)$$$ contains exactly $$$1$$$ node from $$$rootpath(p)$$$ for each $$$p$$$ in the range, and none for positions outside it, we can add $$$x$$$ to each of those nodes. This raises the sum over $$$rootpath(p)$$$ by $$$x$$$ for exactly the positions in $$$[l, r]$$$, so the sum over $$$rootpath(p)$$$ is still $$$a_p$$$.

In code it will look like this:

```
def update(l, r, x):
    for node in rangecover(l, r):
        value[node] += x

def query(p):
    result = 0
    for node in rootpath(p):
        result += value[node]
    return result
```
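
For reference, here is a runnable sketch of the pseudocode above. The class name `RangeAddPointQuery`, the 0-indexed positions, and the bottom-up array layout (leaf of position $$$p$$$ at index $$$p + n$$$) are assumptions for this example, not part of the idea itself.

```python
class RangeAddPointQuery:
    def __init__(self, n):
        self.n = n
        self.value = [0] * (2 * n)   # index 1 is the root, n..2n-1 are leaves

    def update(self, l, r, x):
        """Add x to every position in the inclusive range [l, r]."""
        lo, hi = l + self.n, r + self.n + 1
        while lo < hi:               # walk over rangecover(l, r)
            if lo & 1:
                self.value[lo] += x
                lo += 1
            if hi & 1:
                hi -= 1
                self.value[hi] += x
            lo //= 2
            hi //= 2

    def query(self, p):
        """Return a[p] as the sum of the values on rootpath(p)."""
        result = 0
        node = p + self.n
        while node >= 1:
            result += self.value[node]
            node //= 2
        return result


tree = RangeAddPointQuery(5)
tree.update(1, 3, 2)
tree.update(0, 1, 5)
print(tree.query(1))  # 7
print(tree.query(4))  # 0
```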

## Commutativity

In the two problems above, the only properties of the updates we used are commutativity and associativity.

This means we run into problems with, for example, assignment on a range, because assigning $$$x$$$ and then $$$y$$$ is not the same as assigning $$$y$$$ and then $$$x$$$.

But in the point update / range query case, as long as merging would be commutative, we can often represent non-commutative updates as commutative updates by first examining the state of the array.

For range sum and point assign, we can do this:

## Dynamic Merge Sort Tree (Online)

The previous two problems were solvable with a standard merging segment tree as well. Now comes a problem that isn't (tell me in the comments if I'm wrong).

You have an array $$$a$$$ of length $$$n$$$ and have to process the following queries:

- $$$p \enspace x$$$: $$$a_p := x$$$
- $$$l \enspace r \enspace x$$$: print the number of positions $$$p$$$ in $$$[l, r]$$$ such that $$$a_p < x$$$.

### Solution

If there were no updates we could use a merge sort tree on the array, and in each node binary search to find the number of positions $$$p$$$ covered by that node such that $$$a_p < x$$$.

A merge sort tree is just a regular merging segment tree where the merge operation is to weave together two sorted lists into a new sorted list.

Merging everything once takes only $$$\mathcal{O}(n \log n)$$$ time, because merging a single node takes $$$\mathcal{O}(w)$$$ time, where $$$w$$$ is the number of leaves below that node, the sum of $$$w$$$ in a layer is $$$n$$$, and the number of layers is $$$\mathcal{O}(\log n)$$$.

But changing anything in this merge sort tree takes $$$\mathcal{O}(n)$$$ time, since even though only $$$\mathcal{O}(\log n)$$$ nodes need to be visited, some of them have big $$$w$$$.

A commutative segment tree solves this, since we don't need to merge anything. Instead of merging, we can keep a binary search tree in each node, and when we update something we just replace the element in every relevant BST.

In code it will look like this:

```
def update(p, x):
    for node in rootpath(p):
        bst[node].erase(a[p])
        bst[node].insert(x)
    # record the new value only after every BST on the path is updated
    a[p] = x

def query(l, r, x):
    res = 0
    for node in rangecover(l, r):
        res += bst[node].count_smaller(x)
    return res
```
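
Below is a runnable sketch of the same structure. Python has no built-in BST, so as a stand-in each node holds a plain sorted list maintained with the standard `bisect` module. This is an assumption of the example and it changes the complexity: inserting into or deleting from a list costs $$$\mathcal{O}(w)$$$, so this only illustrates the structure. To get the intended bounds you would swap in a balanced BST with order statistics. The class name `DynamicMergeSortTree` is also made up for this example.

```python
from bisect import bisect_left, insort


class DynamicMergeSortTree:
    def __init__(self, a):
        self.n = len(a)
        self.a = list(a)
        # items[node] is the sorted multiset of values below that node
        self.items = [[] for _ in range(2 * self.n)]
        for p, x in enumerate(self.a):
            node = p + self.n
            while node >= 1:
                self.items[node].append(x)
                node //= 2
        for lst in self.items:
            lst.sort()

    def update(self, p, x):
        """Set a[p] = x by replacing the old value in every node on rootpath(p)."""
        old = self.a[p]
        node = p + self.n
        while node >= 1:
            lst = self.items[node]
            del lst[bisect_left(lst, old)]   # erase one copy of the old value
            insort(lst, x)                   # insert the new value
            node //= 2
        self.a[p] = x

    def query(self, l, r, x):
        """Count positions p in [l, r] (inclusive) with a[p] < x."""
        res = 0
        lo, hi = l + self.n, r + self.n + 1
        while lo < hi:
            if lo & 1:
                res += bisect_left(self.items[lo], x)
                lo += 1
            if hi & 1:
                hi -= 1
                res += bisect_left(self.items[hi], x)
            lo //= 2
            hi //= 2
        return res


tree = DynamicMergeSortTree([5, 1, 4, 2, 3])
print(tree.query(0, 4, 3))  # 2
tree.update(1, 10)
print(tree.query(0, 4, 3))  # 1
```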

## Harder Practice Problems (With Explanations)

### Generalization of CF 406 Div.1 B. Legacy

In the original problem you are asked to find single source shortest paths to every node in a graph with some special edges. They state 3 different types of edges, all of which can be generalized to this:

$$$l_1 \enspace r_1 \enspace l_2 \enspace r_2 \enspace w$$$ means you can go from any node in $$$[l_1, r_1]$$$ to any node in $$$[l_2, r_2]$$$ for the cost of $$$w$$$.

The graph has up to $$$10^5$$$ nodes, and up to $$$10^5$$$ of these special edges.

**Target time complexity (better than the one in the editorial)**

**Solution**

### IATI 21 — News

You are given a rooted tree with each node having the value $$$0$$$ or $$$1$$$. Let $$$s(x, k)$$$ be the set of nodes in $$$x$$$'s subtree that are within a distance of $$$k$$$ from node $$$x$$$. Process $$$q$$$ of the following queries:

- $$$1 \enspace x \enspace k$$$: set the value of every node in $$$s(x, k)$$$ to $$$1$$$.
- $$$2 \enspace x \enspace k$$$: print the sum of values of nodes in $$$s(x, k)$$$.

$$$n, q \le 2 \cdot 10^5$$$

**Solution**

## Implementation

**Python**

**C++**