Union Find


Description:

Example:

Note:

Idea:

用 Union Find（并查集）求一组元素中关系网（连通分量）的个数，以及每个关系网中的元素个数。

介绍见：https://neo1218.github.io/unionfind/ （注意：该文中的代码有 bug）。

这份代码同时用来练习 C++ 继承（inheritance）的语法。

Code:

#include <iostream>
#include <vector>
#include <climits>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <stack>
#include <queue>
#include <utility>
#include <memory>
#include <cmath>

using namespace std;

// Abstract base class for the union-find (disjoint-set) variants below.
// Elements are labelled 0 .. n-1 and each one starts in its own singleton
// group (id_[i] == i). Subclasses decide how id_ is interpreted and how
// unionCombine() updates it.
class UnionFind{
public:
    // n: number of elements.
    // Members are initialised in declaration order (num_group, id_,
    // total_size_); the initializer list is written in that same order so it
    // reads the way it executes and does not trigger -Wreorder.
    UnionFind(int n): num_group(n), id_(n), total_size_(n) {
        for(int i=0; i<n; ++i){
            id_[i]=i;
        }
    }
    virtual ~UnionFind() = default;

    // Number of disjoint groups. Starts at n; a union that merges two
    // previously separate groups is expected to decrement it.
    size_t num_group;

    // Merge the groups containing p and q.
    virtual void unionCombine(int p, int q)=0;
    // True when p and q are in the same group.
    virtual bool connected(int p, int q)=0;

    // Write two rows to stm: the element indices, then the contents of id_.
    // Returns stm so the call can be chained, e.g. print() << '\n'.
    virtual ostream & print(ostream & stm = cout){
        for(size_t i=0; i<total_size_; ++i){
            stm<<i<<' ';
        }
        stm<<'\n';

        for(int label : id_){
            stm<<label<<' ';
        }
        stm<<'\n';

        return stm;
    }

protected:
    vector<int> id_;      // per-element label; its meaning is defined by the subclass
    size_t total_size_;   // number of elements (equals id_.size())
};

// id vector stores the group number of each node
// Quick-find: id_[i] is the group label of element i, so connected() is a
// single comparison while unionCombine() relabels every member of p's group.
class QuickFind: public UnionFind{
public:
    QuickFind(int n): UnionFind(n){
    }

    // O(n): give every element that carries p's label the label of q.
    void unionCombine(int p, int q) override {
        // Copy both labels before relabelling. std::replace takes its values by
        // const reference; passing id_[p] directly would alias an element that
        // is overwritten mid-scan, so later members of p's group would be missed.
        const int old_label = id_[p];
        const int new_label = id_[q];
        if(old_label == new_label) return;

        std::replace(id_.begin(), id_.end(), old_label, new_label);
        num_group--;
    }

    // O(1): equal labels mean the same group.
    bool connected(int p, int q) override {
        return id_[p] == id_[q];
    }

};

// id vector stores the root location of each node
// Quick-union: id_[i] holds the parent of element i. An element that is its
// own parent is a root, and the root identifies the whole group.
class QuickUnion: public UnionFind{
public:
    QuickUnion(int n) : UnionFind(n) {}

    // Hang the root of p's tree directly beneath the root of q's tree.
    void unionCombine(int p, int q) override {
        const int root_p = root_(p);
        const int root_q = root_(q);
        if(root_p != root_q){
            id_[root_p] = root_q;   // a root is its own parent, so this equals id_[root_q]
            --num_group;
        }
    }

    // Two elements are connected exactly when their trees share a root.
    bool connected(int p, int q) override {
        return root_(p) == root_(q);
    }

protected:
    // Follow parent links from node up to its self-parented root.
    // There is no path compression, so the cost grows with the node's depth.
    int root_(int node){
        for(; id_[node] != node; node = id_[node]){
        }
        return node;
    }
};

// id vector stores the root location of each node
// weight vector stores the size of the subtree at each node
class WeightedUnion: public QuickUnion{
public:
    WeightedUnion(int n):QuickUnion(n), weight_(n, 1){
    }

    void unionCombine(int p, int q) override {
        int rp = root_(p);
        int rq = root_(q);
        if(rp == rq) return;

        // change smaller tree root to bigger tree root
        // change size of bigger tree root
        if(weight_[rp] < weight_[rq]){
            id_[rp]=id_[rq];
            weight_[rq] += weight_[rp];          
        }
        else{
            id_[rq]=id_[rp];
            weight_[rp] += weight_[rq];                   
        }

        num_group--;
    }

    ostream & print(ostream & stm = cout) override {

        for(int i=0; i<total_size_; i++){
            stm<<i<<' ';
        } stm<<'\n';

        for(int i=0; i<total_size_; i++){
            stm<<id_[i]<<' '; 
        } stm<<'\n';
        for(int i=0; i<total_size_; i++){
            stm<<weight_[i]<<' '; 
        } stm<<'\n';
        return stm;
    }

protected:
    vector<int> weight_;
};



int main(){
    vector<pair<int, int>> connect={{0, 1}, {0, 3}, {2, 3}, {3, 4}, {4, 5}};
    int n=8;
    // vector<pair<int, int>> connect={{0, 1}, {0, 3}};
    // int n=4;

    UnionFind * union_set1 = new QuickFind(n);
    UnionFind * union_set2 = new QuickUnion(n);
    UnionFind * union_set3 = new WeightedUnion(n);
    UnionFind * union_set_test = union_set3;

    union_set_test->print()<<'\n';
    for(auto e: connect){
        union_set_test->unionCombine(e.first, e.second);
        // union_set1.print()<<'\n';
    }
    cout<<"Number of groups: "<<union_set_test->num_group<<endl;
    // cout<<union_set1.connected(0, 1)<<endl;
    union_set_test->print()<<'\n';

    delete union_set1;
    delete union_set2;
    delete union_set3;
}

results matching ""

    No results matching ""