I was recently upsolving AtCoder Beginner Contest 339 and came across a persistent segment tree problem: Problem G — Smaller Sum.
While implementing my own Persistent Segment Tree, I tried to make it as generic as possible.
Code:
#include <bits/stdc++.h>
using namespace std;
/**
 * Source: https://github.com/kth-competitive-programming/kactl/blob/main/content/various/BumpAllocator.h
 * Description: When you need to dynamically allocate many objects and don't
 * care about freeing them. "new X" otherwise has an overhead of something like
 * 0.05us + 16 bytes per allocation.
 *
 * Memory is handed out from the END of the global `buf`, moving towards its
 * start. Every Alloc object begins at the end of the same `buf`, so two
 * allocators used at the same time return overlapping memory: keep only one
 * of them in use at any moment.
 */
const std::size_t SZ = (450 << 20); // 450 mb
// Aligned so that an aligned offset into buf is also an aligned address.
alignas(std::max_align_t) static char buf[SZ];

class Alloc {
private:
  std::size_t ptr; // offset in buf of the lowest byte handed out so far

public:
  Alloc() : ptr(sizeof(buf)) {}

  /**
   * @brief reserve s bytes at the current end of the free region
   * @param s number of bytes requested
   * @param align required alignment of the result; must be a power of two
   *        not larger than alignof(std::max_align_t). The default is safe
   *        for any ordinary object type but may pad each request up to the
   *        next multiple of that alignment.
   * @return pointer to raw storage (no object is constructed in it) whose
   *         address is a multiple of align
   */
  void *alloc(std::size_t s, std::size_t align = alignof(std::max_align_t)) {
    assert(align != 0 && (align & (align - 1)) == 0);
    assert(align <= alignof(std::max_align_t));
    assert(s <= ptr); // out of arena memory otherwise
    // Step back by s, then round DOWN so the returned address is aligned.
    ptr = (ptr - s) & ~(align - 1);
    return static_cast<void *>(&buf[ptr]);
  }

  /** @brief forget every allocation; previously returned pointers get reused */
  void reset() { ptr = sizeof(buf); }
};
/**
 * One vertex of the pointer-based persistent segment tree.
 * Holds the stored value and links to the two children; the tree code sets
 * both links to nullptr for a leaf.
 */
template <typename Info> struct Node {
  Info info;   // value kept in this vertex
  Node *left;  // left child
  Node *right; // right child
};
template <typename Info> class PSegTree {
public:
typedef Node<Info> node_t;
node_t *root;
vector<node_t *> time;
int n;
Alloc ar;
PSegTree(size_t sz) : PSegTree(vector<Info>(sz, Info())) {}
PSegTree(const vector<Info> &info) {
root = nullptr;
ar = Alloc();
n = (int)info.size();
function<node_t *(int, int, int)> build = [&](int p, int l,
int r) -> node_t * {
if (l == r) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = info[l];
node->left = node->right = nullptr;
return node;
}
int m = l + (r - l) / 2;
return pull(build(2 * p + 1, l, m), build(2 * p + 2, m + 1, r));
};
root = build(0, 0, n - 1);
time.push_back(root);
}
node_t *pull(node_t *left, node_t *right) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = (left->info + right->info);
node->left = left, node->right = right;
return node;
}
void modify(int p, const Info &v) {
function<node_t *(node_t *, int, int)> _modify = [&](node_t *c, int l,
int r) -> node_t * {
if (l == r) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = v;
node->left = node->right = nullptr;
return node;
}
int m = l + (r - l) / 2;
node_t *left = c->left, *right = c->right;
if (p <= m) {
left = _modify(left, l, m);
} else {
right = _modify(right, m + 1, r);
}
return pull(left, right);
};
root = _modify(root, 0, n - 1);
time.push_back(root);
}
Info rangeQuery(int t, int x, int y) {
function<Info(node_t *, int, int)> query = [&](node_t *c, int l,
int r) -> Info {
if (y < l or r < x or c == nullptr) {
return Info();
}
if (x <= l and r <= y) {
return c->info;
}
int m = l + (r - l) / 2;
return query(c->left, l, m) + query(c->right, m + 1, r);
};
return query(time[t], 0, n - 1);
}
};
/**
 * Value type for sum queries: wraps one 64-bit integer.
 * A default-constructed Sum holds 0.
 */
struct Sum {
  std::int64_t x{0};
  Sum() {}
  Sum(std::int64_t _x) : x(_x) {}
};

/** @return a Sum holding lf.x + rt.x */
Sum operator+(const Sum &lf, const Sum &rt) {
  Sum total;
  total.x = lf.x + rt.x;
  return total;
}
/**
 * Reads the array and the XOR-encoded queries from stdin and prints one
 * answer per query.
 *
 * Version i of the tree holds, for every distinct value, the sum of its
 * occurrences among the first i array elements. A query (l, r, x) is answered
 * as (sum over values <= x in version r) - (the same in version l - 1).
 * Each query is decoded by XOR-ing it with the previous answer.
 */
void solve() {
  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (int &value : A) {
    std::cin >> value;
  }

  // Coordinate compression: rank of a value = its index in the sorted list
  // of distinct values.
  std::vector<int> vals = A;
  std::sort(vals.begin(), vals.end());
  vals.erase(std::unique(vals.begin(), vals.end()), vals.end());

  PSegTree<Sum> seg(vals.size());
  for (int i = 0; i < N; i++) {
    const int idx = static_cast<int>(
        std::lower_bound(vals.begin(), vals.end(), A[i]) - vals.begin());
    const std::int64_t val = seg.rangeQuery(i, idx, idx).x;
    seg.modify(idx, Sum(val + A[i]));
  }

  int Q;
  std::cin >> Q;
  std::int64_t b = 0; // previous answer, used to decode the next query
  for (int q = 0; q < Q; q++) {
    std::int64_t l, r, x;
    std::cin >> l >> r >> x;
    l ^= b;
    r ^= b;
    x ^= b;
    b = 0;
    if (x != 0) {
      // Number of distinct values <= x. The comparison is done on the 64-bit
      // x, so a large decoded value is not truncated to int first.
      const auto cnt = std::upper_bound(vals.begin(), vals.end(), x) - vals.begin();
      if (cnt > 0) {
        const int idx = static_cast<int>(cnt) - 1;
        const std::int64_t rt = seg.rangeQuery(static_cast<int>(r), 0, idx).x;
        const std::int64_t lf = seg.rangeQuery(static_cast<int>(l - 1), 0, idx).x;
        b = rt - lf;
      }
    }
    std::cout << b << '\n';
  }
}
/** Entry point: switches the C++ streams to unsynchronised, untied mode and runs solve(). */
int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);
  solve();
}
Here are a few things I would like your help and opinions on:
- Is there a more generic and efficient way to implement a Persistent Segment Tree, and how can the above code be improved?
- How can you identify whether a problem calls for a Persistent Segment Tree?
- Is there a better way to implement a custom allocator for competitive programming (or for C++/C projects)?
- How do you implement a Lazy Persistent Segment Tree?
- How should documentation be added to personal library code?
Upd: I re-wrote the Persistent Segment Tree with documentation and without custom allocator. Thanks to lrvideckis and NimaAryan for help.
Code:
#include <bits/stdc++.h>
using namespace std;
/**
 * Persistent segment tree over positions [0, n - 1], stored in index-linked
 * arrays instead of pointers.
 *
 * Info must be default constructible (Info() is the value of an empty
 * range), copy assignable, and provide operator+ to combine two adjacent
 * ranges.
 *
 * Versions: time[0] is the tree built by the constructor; every modify() or
 * modifyTime() call appends one new root. A new version shares every
 * untouched subtree with the version it was derived from.
 */
template <typename Info> class PSegTree {
public:
  int root;                           // root of the most recently created version
  std::vector<Info> info;             // info[i] = value stored in node i
  std::vector<int> time, left, right; // time[t] = root of version t; children (-1 = none)
  int n, index, size;                 // positions; nodes in use; slots available
  /**
   * Create a new Persistent Segment Tree
   * @brief constructor
   * @param sz defines the size of range [0, sz - 1]
   * @time O(n)
   * @space 8 * n node slots are preallocated
   */
  PSegTree(std::size_t sz) : PSegTree(std::vector<Info>(sz, Info())) {}
  /**
   * Create a new Persistent Segment Tree
   * @brief constructor
   * @param a initial values; defines the range [0, len(a) - 1]. May be empty,
   *        in which case every query returns Info()
   * @time O(n)
   * @space 8 * n node slots are preallocated
   */
  PSegTree(const std::vector<Info> &a) {
    root = -1;
    index = 0;
    n = static_cast<int>(a.size());
    size = 8 * n;
    info.assign(size, Info());
    left.assign(size, -1);
    right.assign(size, -1);
    if (n > 0) {
      root = build(a, 0, n - 1);
    }
    time.push_back(root);
  }
  /**
   * @brief adds a leaf
   * @param v value stored in the leaf
   * @return index of created leaf
   * @time O(1) amortized
   */
  int add_leaf(const Info &v) { return new_node(v, -1, -1); }
  /**
   * @brief adds a parent of two existing nodes
   * @param left_idx and right_idx define the children of the parent
   * @return index of created parent, whose value is
   *         info[left_idx] + info[right_idx]
   * @time O(1) amortized
   */
  int pull(int left_idx, int right_idx) {
    return new_node(info[left_idx] + info[right_idx], left_idx, right_idx);
  }
  /**
   * @brief modify the value a[p] = v on latest version
   * @param p index of value to modify, in [0, n - 1]
   * @param v new value
   * @time O(log(n))
   * @space O(log(n)) new nodes
   */
  void modify(int p, const Info &v) {
    assert(0 <= p and p < n);
    root = update(root, 0, n - 1, p, v);
    time.push_back(root);
  }
  /**
   * @brief create a new version equal to version t with a[p] = v
   * @param t defines the version to start from, 0 <= t < time.size()
   * @param p index of value to modify, in [0, n - 1]
   * @param v new value
   * @time O(log(n))
   * @space O(log(n)) new nodes
   */
  void modifyTime(int t, int p, const Info &v) {
    assert(0 <= t and t < static_cast<int>(time.size()));
    assert(0 <= p and p < n);
    root = update(time[t], 0, n - 1, p, v);
    time.push_back(root);
  }
  /**
   * @brief find the range query for [x, y] on version t
   * @param t defines the version, 0 <= t < time.size()
   * @param x, y defines the range [x, y]; parts outside [0, n - 1]
   *        contribute Info()
   * @return a[x] + a[x + 1] + ... + a[y - 1] + a[y] on version t
   * @time O(log(n))
   * @space O(log(n)) recursion depth
   */
  Info rangeQuery(int t, int x, int y) const {
    assert(0 <= t and t < static_cast<int>(time.size()));
    return query(time[t], 0, n - 1, x, y);
  }

private:
  // Stores one node and returns its index. Uses a preallocated slot while
  // one is free and grows all three arrays by one element otherwise.
  int new_node(const Info &value, int left_idx, int right_idx) {
    if (index < size) {
      info[index] = value;
      left[index] = left_idx;
      right[index] = right_idx;
    } else {
      assert(info.size() == left.size() and left.size() == right.size());
      assert(index == size);
      info.push_back(value);
      left.push_back(left_idx);
      right.push_back(right_idx);
      size++;
    }
    return index++;
  }
  // Builds the subtree covering [l, r] from a[l..r] and returns its root.
  int build(const std::vector<Info> &a, int l, int r) {
    if (l == r) {
      return add_leaf(a[l]);
    }
    const int m = l + (r - l) / 2;
    const int lhs = build(a, l, m);
    const int rhs = build(a, m + 1, r);
    return pull(lhs, rhs);
  }
  // Returns the root of a copy of subtree c (covering [l, r]) in which
  // position p holds v. Only the nodes on the path to p are copied.
  int update(int c, int l, int r, int p, const Info &v) {
    if (l == r) {
      return add_leaf(v);
    }
    const int m = l + (r - l) / 2;
    int left_ptr = left[c], right_ptr = right[c];
    if (p <= m) {
      left_ptr = update(left_ptr, l, m, p, v);
    } else {
      right_ptr = update(right_ptr, m + 1, r, p, v);
    }
    return pull(left_ptr, right_ptr);
  }
  // Combines the part of [x, y] that lies inside subtree c covering [l, r].
  Info query(int c, int l, int r, int x, int y) const {
    if (y < l or r < x or c == -1) {
      return Info();
    }
    if (x <= l and r <= y) {
      return info[c];
    }
    const int m = l + (r - l) / 2;
    return query(left[c], l, m, x, y) + query(right[c], m + 1, r, x, y);
  }
};
/**
 * Value type for sum queries: one 64-bit integer, 0 by default.
 */
class Sum {
public:
  int64_t x;
  Sum() : Sum(0) {}
  Sum(int64_t _x) : x{_x} {}
};
/** @return a Sum whose x is lf.x + rt.x */
Sum operator+(const Sum &lf, const Sum &rt) {
  const int64_t combined = lf.x + rt.x;
  return Sum(combined);
}
/**
 * Reads the array and the XOR-encoded queries from stdin and prints one
 * answer per query.
 *
 * Version i of the tree holds, for every distinct value, the sum of its
 * occurrences among the first i array elements. A query (l, r, x) is answered
 * as (sum over values <= x in version r) - (the same in version l - 1).
 * Each query is decoded by XOR-ing it with the previous answer.
 */
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (int &value : A) {
    std::cin >> value;
  }

  // Coordinate compression: rank of a value = its index in the sorted list
  // of distinct values.
  std::vector<int> vals = A;
  std::sort(vals.begin(), vals.end());
  vals.erase(std::unique(vals.begin(), vals.end()), vals.end());

  PSegTree<Sum> seg(vals.size());
  for (int i = 0; i < N; i++) {
    const int idx = static_cast<int>(
        std::lower_bound(vals.begin(), vals.end(), A[i]) - vals.begin());
    const std::int64_t val = seg.rangeQuery(i, idx, idx).x;
    // Version i is the newest one here, so this creates version i + 1.
    seg.modifyTime(i, idx, Sum(val + A[i]));
  }

  int Q;
  std::cin >> Q;
  std::int64_t b = 0; // previous answer, used to decode the next query
  for (int q = 0; q < Q; q++) {
    std::int64_t l, r, x;
    std::cin >> l >> r >> x;
    l ^= b;
    r ^= b;
    x ^= b;
    b = 0;
    if (x != 0) {
      // Number of distinct values <= x. The comparison is done on the 64-bit
      // x, so a large decoded value is not truncated to int first.
      const auto cnt = std::upper_bound(vals.begin(), vals.end(), x) - vals.begin();
      if (cnt > 0) {
        const int idx = static_cast<int>(cnt) - 1;
        const std::int64_t rt = seg.rangeQuery(static_cast<int>(r), 0, idx).x;
        const std::int64_t lf = seg.rangeQuery(static_cast<int>(l - 1), 0, idx).x;
        b = rt - lf;
      }
    }
    std::cout << b << '\n';
  }
  return 0;
}
Implementing an allocator class, I always think about memory alignment. Do you?
Not really, I usually use a char array for creating buffers, even for TCP/IP stuff.
For competitive programming purposes, using an
`inline static std::deque<node_t>`
declared inside the class, and a static `node_t* new_node()`
function should be enough (my lazy persistent segment tree template uses it). It has very little overhead compared to the allocator you used, and it's generic and resizable. We use `std::deque`
instead of `std::vector`
because reference invalidation is not an issue in `std::deque`: appending to it never moves the existing elements.

I didn't get the
`deque`
part; a bit of code would help. Can you provide me your implementation of a lazy persistent segment tree?

Upd: Someone helped me understand how to use
`deque`
instead of a custom allocator. It seems way better than a custom allocator.

Your new code seems to use vector and indices instead of deque. My point was that you could do something like this:
This doesn't work with
`std::vector`
since on `emplace_back`
the pointers to the elements in the vector might not remain valid due to potential reallocation of the vector.

And if you want to give it multiple-constructor-like syntax, you can do something like this:
https://github.com/amenotiomoi/template?tab=readme-ov-file#segment-treesingle-point-change-interval-query-persistence
Here's my implementation of the persistent segtree, I've separated the node's merge rules into a new structure, and I'm only using O(1) space in each structure, which means you can use separate persistent segtrees as if they were normal variables, though this causes me to not release memory very well (I don't have a good way to determine the life cycle). It's CP though, memory largely doesn't matter lol
It might just be me, but I need my code to reclaim memory before program termination 👽. Btw, that is a really cool way to document a code library on GitHub 👍, but I want to add the documentation in the code itself.
Here is my Persistent Segment Tree without pointers :)
it's more efficient and easier to code.
I have added an updated version of code. Thanks for help.