micho's blog

By micho, 20 months ago, In English

I don't know if this technique is very well-known or obvious, but I haven't seen it anywhere else before. So, here it is.

Problem

Consider the following problem. You are given an array $$$A = (A_1, A_2, A_3, \dots, A_N)$$$ of $$$N$$$ integers. We define a merge function $$$F$$$ that transforms an array $$$X$$$ into an array $$$Y = F(X)$$$:

  • Scan $$$X$$$ from left to right. For each index $$$i$$$ $$$(1 \leq i < len(X))$$$, if $$$X_i$$$ and $$$X_{i+1}$$$ are equal, merge them into one element $$$X_i + 1$$$, output that element, and skip the next element (with index $$$i+1$$$). Otherwise, output $$$X_i$$$. The last element, if it was not skipped, is output unchanged.

The problem asks you to find $$$F^N(A)$$$. In other words, what does the array look like after applying the merge operation $$$N$$$ times?

Example

Consider the following array:

  • $$$A = (3, 3, 4, 2, 2, 2, 2, 1)$$$
  • $$$F(A) = (4, 4, 3, 3, 1)$$$: we merge the first two 3s and both pairs of 2s.
  • $$$F^2(A) = (5, 4, 1)$$$: we merge the pair of 4s and the pair of 3s.
  • $$$F^K(A) = F^2(A)$$$ for every $$$K \geq 2$$$, so $$$F^2(A)$$$ is the answer.
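For reference, one application of $$$F$$$ can be transcribed directly into C++. This is a minimal sketch and not part of the original solution; the name `applyF` is hypothetical. Applying it twice to the example array gives $$$(5, 4, 1)$$$, and a third application leaves that array unchanged.

```cpp
#include <cstddef>
#include <vector>
using namespace std;

// One application of the merge function F (hypothetical helper name).
// Scans X from left to right; an equal adjacent pair becomes a single
// element with value + 1, and the second element of the pair is skipped.
vector<int> applyF(const vector<int>& X) {
    vector<int> Y;
    size_t i = 0;
    while (i < X.size()) {
        if (i + 1 < X.size() && X[i] == X[i + 1]) {
            Y.push_back(X[i] + 1); // merge the pair
            i += 2;                // skip the partner
        } else {
            Y.push_back(X[i]);     // no merge: copy the element as is
            i += 1;
        }
    }
    return Y;
}
```

Because the scan is greedy from the left, `applyF({1, 1, 1})` returns `{2, 1}` rather than `{1, 2}`.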

Breakdown

The naive solution is to repeatedly iterate from the start of the array, merging the elements into a new array, for as many passes as needed. However, the following counterexample shows that this approach takes $$$O(N^2)$$$ time in the worst case.

Consider the following array:

  • $$$A = (N - 1, N - 2, N - 3, \dots, 4, 3, 2, 1, 1)$$$

Each application of $$$F$$$ merges only one pair (first the two 1s, then the resulting pair of 2s, and so on), so you need $$$N - 1$$$ applications, and each one scans the whole array in $$$O(N)$$$. This gives the naive solution a running time of $$$O(N^2)$$$.
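The naive approach can be sketched as follows. Both names are hypothetical: `solveNaive` applies $$$F$$$ until nothing changes and also counts the passes that actually merged something, and `worstCase` builds the array above (assuming $$$N \geq 2$$$). On `worstCase(N)` it performs $$$N - 1$$$ merging passes while the array shrinks by only one element per pass, which is where the $$$O(N^2)$$$ comes from.

```cpp
#include <cstddef>
#include <utility>
#include <vector>
using namespace std;

// Naive approach (sketch): apply F from the start of the array until a
// pass merges nothing. Returns the final array and the number of passes
// that merged at least one pair.
pair<vector<int>, int> solveNaive(vector<int> A) {
    int passes = 0;
    while (true) {
        vector<int> B;
        size_t i = 0;
        while (i < A.size()) {               // one application of F
            if (i + 1 < A.size() && A[i] == A[i + 1]) {
                B.push_back(A[i] + 1);
                i += 2;
            } else {
                B.push_back(A[i]);
                i += 1;
            }
        }
        if (B.size() == A.size()) break;     // no merge happened: fixed point
        A = move(B);
        ++passes;
    }
    return {A, passes};
}

// Builds the worst case (N-1, N-2, ..., 2, 1, 1); assumes N >= 2.
vector<int> worstCase(int N) {
    vector<int> A;
    for (int v = N - 1; v >= 1; --v) A.push_back(v);
    A.push_back(1);
    return A;
}
```

For example, `solveNaive(worstCase(6))` returns `{6}` after 5 merging passes.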

We want to show an $$$O(N)$$$ solution.

Technique

The main idea behind the technique is to keep a list of pointers to the positions where merges will occur. It's similar to a multi-source BFS, but on a linked list. The first step is to iterate over the array and store pointers to the locations merged by the first application of $$$F$$$. After that, we go through the pointers and merge the corresponding elements in the list. While doing so, we look at the positions right next to each merge, because that is the only place where a new equal pair can appear, and create new pointers there for the next round. Each merge removes one element, so there are at most $$$N - 1$$$ merges in total, and only a constant number of pointers is created per merge; hence the amortized complexity is $$$O(N)$$$ at worst.

Here's the idea in code.

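vector<Node*> pointers;

// HERE find the pointers: one pass over the list, pushing every node
// that is merged with its successor by the first application of F
for (...)

while (!pointers.empty()) {
    vector<Node*> new_pointers;
    
    for (auto p: pointers) {
        // merge the pair at p, if it is still equal

        // check if we have to create a new pointer next to the merge
        if (...) new_pointers.push_back(...)
    }

    // move the new pointers in for the next round (clear() followed by
    // copy() into the now-empty vector would write past its end)
    pointers.swap(new_pointers);
}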

Here is a solution to the original problem implemented with a custom list. I'm sure you can make this solution cleaner.

#include <bits/stdc++.h>
using namespace std;

struct Node {
    int value = -1;
    Node* prev = nullptr;
    Node* next = nullptr;
    bool updated = false;
    Node* addNode() {
        next = new Node;
        cin >> next->value;
        next->prev = this;
        return next;
    }
    void print() {
        // iterative, so a long list cannot overflow the call stack
        for (Node* n = this; n != nullptr; n = n->next) cout << n->value << " ";
        cout << "\n";
    }
};

int main() {
    int N; cin >> N;
    Node* head = new Node; // sentinel before the first element (value -1, assumed never to appear in the input)
    Node* current = head;
    for (int i = 0; i < N; i++) current = current->addNode();
    vector<Node*> p; // find the initial pointers
    for (auto it = head->next; it != nullptr; it = it->next) {
        if (it->next == nullptr) break;
        if (it->value == it->next->value) {
            p.push_back(it);
            if ((it = it->next) == nullptr) break;
        }
    }

    // This is O(N) amortized.
    while (!p.empty()) {
        // Merge step: p is ordered by position in the list.
        for (size_t i = 0; i < p.size(); i++) {
            if (p[i]->updated) continue;                     // already merged into this round
            if (i && p[i] == p[i-1]) continue;               // duplicate pointer
            if (p[i]->next->value != p[i]->value) continue;  // pair is no longer equal
            p[i]->next->value++;
            p[i]->next->updated = true;
            // delete node at p[i] and move the pointer to the merged node
            p[i]->prev->next = p[i]->next;
            if (p[i]->next != nullptr) {
                p[i]->next->prev = p[i]->prev;
                p[i] = p[i]->next;
            }
        }
        // Collect the next round's pointers: new equal pairs can only
        // appear right next to the nodes touched in this round.
        vector<Node*> new_p;
        for (size_t i = 0; i < p.size(); i++) {
            p[i]->updated = false;
            if (p[i]->value == p[i]->prev->value) {
                new_p.push_back(p[i]->prev);
            }
            if (p[i]->next != nullptr && p[i]->value == p[i]->next->value) {
                new_p.push_back(p[i]);
            }
        }
        p.swap(new_p);
    }

    head->next->print();
    return 0;
}
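One possible cleaner variant is sketched below; it is not the author's code, and the names `solvePointers`, `val`, `prv`, `nxt`, `alive` and `updated` are hypothetical. It keeps the same idea but stores the list in index arrays instead of heap-allocated `Node`s, so nothing leaks and no sentinel value is needed. It also builds the next round's pointers only from the elements that were actually merged. That is enough because, after one application of $$$F$$$, two adjacent equal elements cannot both be unmerged copies of the previous array ($$$F$$$ would have merged them), so every round inspects at most two pointers per merge of the previous round, and there are at most $$$N - 1$$$ merges in total.

```cpp
#include <vector>
using namespace std;

// Array-based variant of the pointer technique (a sketch, not the
// author's code). Elements live at indices 1..N; index 0 means "none".
vector<int> solvePointers(const vector<int>& A) {
    int N = A.size();
    vector<int> val(N + 1), prv(N + 1), nxt(N + 1);
    vector<char> alive(N + 1, 1), updated(N + 1, 0);
    for (int i = 1; i <= N; ++i) {
        val[i] = A[i - 1];
        prv[i] = i - 1;                // 0: no previous element
        nxt[i] = (i < N) ? i + 1 : 0;  // 0: no next element
    }
    int head = (N > 0) ? 1 : 0;        // first element still in the list

    // Round 1 pointers: left element of every pair merged by the first F.
    vector<int> p;
    for (int it = head; it != 0 && nxt[it] != 0; it = nxt[it]) {
        if (val[it] == val[nxt[it]]) {
            p.push_back(it);
            it = nxt[it];              // skip the partner, as F does
        }
    }

    while (!p.empty()) {
        // Merge step. p is ordered by list position, so this replays F's
        // left-to-right greedy pairing on the candidate positions only.
        vector<int> merged;
        for (int x : p) {
            if (!alive[x] || updated[x]) continue;  // duplicate, or already merged into
            int y = nxt[x];
            if (y == 0 || val[x] != val[y]) continue;
            ++val[y];                  // the pair becomes one element, value + 1
            updated[y] = 1;            // y must not merge again this round
            alive[x] = 0;              // unlink x from the list
            if (prv[x] != 0) nxt[prv[x]] = y; else head = y;
            prv[y] = prv[x];
            merged.push_back(y);
        }
        // Next round: an equal pair must touch an element merged just now,
        // so only the neighbours of `merged` need to be checked.
        vector<int> next_p;
        for (int y : merged) {
            updated[y] = 0;
            if (prv[y] != 0 && val[prv[y]] == val[y]) next_p.push_back(prv[y]);
            if (nxt[y] != 0 && val[nxt[y]] == val[y]) next_p.push_back(y);
        }
        p.swap(next_p);
    }

    vector<int> result;
    for (int it = head; it != 0; it = nxt[it]) result.push_back(val[it]);
    return result;
}
```

On the example from the problem statement, `solvePointers({3, 3, 4, 2, 2, 2, 2, 1})` returns `{5, 4, 1}`, the same answer as repeated application of $$$F$$$.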

Other Problems

Here are some problems that I solved with this technique.
