Union Find
Description:
Example:
Note
Idea:
用 Union Find 来找出一群元素里的关系网(连通分量)个数,以及每个关系网的 element 个数。
介绍见: https://neo1218.github.io/unionfind/ (他的code有bug)。
这个code我用来练习inheritance的语法。
Code:
#include <iostream>
#include <vector>
#include <climits>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <stack>
#include <queue>
#include <utility>
#include <memory>
#include <cmath>
using namespace std;
// Abstract base class for union-find (disjoint-set) structures over the
// elements 0 .. n-1.
//
// The base class owns the shared state (id_, total_size_, num_group) and a
// printing helper. Subclasses decide what the entries of id_ mean and
// implement unionCombine() and connected().
class UnionFind{
public:
    // Builds n singleton groups by setting id_[i] = i for every element.
    // n is expected to be non-negative.
    // `explicit` keeps a plain int from silently converting to this type.
    // The initializers are listed in declaration order (num_group, id_,
    // total_size_), which is the order in which they actually run.
    explicit UnionFind(int n): num_group(n), id_(n), total_size_(n) {
        for (int i = 0; i < n; ++i) {
            id_[i] = i;
        }
    }

    virtual ~UnionFind() = default;

    // Current number of disjoint groups; starts at n. Subclasses are
    // expected to decrement it whenever two distinct groups are merged.
    std::size_t num_group;

    // Merges the group containing the first element with the group
    // containing the second one.
    virtual void unionCombine(int, int)=0;

    // Returns true when both elements belong to the same group.
    virtual bool connected(int, int)=0;

    // Writes two lines to stm: the element indices 0 .. total_size_-1,
    // then the matching id_ entries. Every value is followed by a single
    // space and every line ends with '\n'. Returns stm to allow chaining.
    virtual std::ostream & print(std::ostream & stm = std::cout){
        for (std::size_t i = 0; i < total_size_; ++i) {
            stm << i << ' ';
        }
        stm << '\n';
        for (std::size_t i = 0; i < total_size_; ++i) {
            stm << id_[i] << ' ';
        }
        stm << '\n';
        return stm;
    }

protected:
    // Per-element bookkeeping; its meaning is defined by the subclass.
    std::vector<int> id_;
    // Number of elements the structure was created with.
    std::size_t total_size_;
};
// Quick-find variant: id_[i] stores the group label of element i, so two
// elements are connected exactly when their labels are equal.
// connected() is a constant-time lookup; unionCombine() rewrites every
// entry carrying the old label and is therefore linear in the element count.
class QuickFind: public UnionFind{
public:
    // Creates n singleton groups (see the base-class constructor).
    QuickFind(int n): UnionFind(n){
    }

    // Merges the group of p into the group of q.
    // Throws std::out_of_range (via vector::at) when p or q is not a valid
    // element index; a negative index converts to a huge unsigned value and
    // is rejected the same way.
    void unionCombine(int p, int q) override {
        // Copy both labels up front: id_[p] is itself overwritten during the
        // relabeling pass, so re-reading it mid-pass would stop the rewrite
        // early.
        const int from = id_.at(p);
        const int to = id_.at(q);
        if (from == to) return;
        std::replace(id_.begin(), id_.end(), from, to);
        --num_group;
    }

    // Returns true when p and q carry the same group label.
    // Throws std::out_of_range when p or q is not a valid element index.
    bool connected(int p, int q) override {
        return id_.at(p) == id_.at(q);
    }
};
// Quick-union variant: id_[i] stores the parent of element i, and a root is
// an element that is its own parent. Two elements are connected exactly
// when following their parent links ends at the same root.
class QuickUnion: public UnionFind{
public:
    // Creates n single-node trees (see the base-class constructor).
    QuickUnion(int n): UnionFind(n){
    }

    // Attaches the root of p's tree beneath the root of q's tree.
    // Throws std::out_of_range when p or q is not a valid element index.
    void unionCombine(int p, int q) override {
        const int rp = root_(p);
        const int rq = root_(q);
        if (rp == rq) return;
        // rq is a root, so id_[rq] == rq; assigning rq directly is the same
        // link the parent lookup would produce.
        id_[rp] = rq;
        --num_group;
    }

    // Returns true when p and q share a root.
    // Throws std::out_of_range when p or q is not a valid element index.
    bool connected(int p, int q) override {
        return root_(p) == root_(q);
    }

protected:
    // Follows parent links from p until reaching an element that is its own
    // parent and returns that element.
    // The first lookup uses vector::at so that an invalid (including
    // negative) caller-supplied index throws std::out_of_range instead of
    // reading outside the vector.
    int root_(int p){
        int node = p;
        int parent = id_.at(node);
        while (parent != node) {
            node = parent;
            parent = id_[node];
        }
        return node;
    }
};
// id vector stores the root location of each node
// weight vector stores the size of the subtree at each node
class WeightedUnion: public QuickUnion{
public:
WeightedUnion(int n):QuickUnion(n), weight_(n, 1){
}
void unionCombine(int p, int q) override {
int rp = root_(p);
int rq = root_(q);
if(rp == rq) return;
// change smaller tree root to bigger tree root
// change size of bigger tree root
if(weight_[rp] < weight_[rq]){
id_[rp]=id_[rq];
weight_[rq] += weight_[rp];
}
else{
id_[rq]=id_[rp];
weight_[rp] += weight_[rq];
}
num_group--;
}
ostream & print(ostream & stm = cout) override {
for(int i=0; i<total_size_; i++){
stm<<i<<' ';
} stm<<'\n';
for(int i=0; i<total_size_; i++){
stm<<id_[i]<<' ';
} stm<<'\n';
for(int i=0; i<total_size_; i++){
stm<<weight_[i]<<' ';
} stm<<'\n';
return stm;
}
protected:
vector<int> weight_;
};
int main(){
vector<pair<int, int>> connect={{0, 1}, {0, 3}, {2, 3}, {3, 4}, {4, 5}};
int n=8;
// vector<pair<int, int>> connect={{0, 1}, {0, 3}};
// int n=4;
UnionFind * union_set1 = new QuickFind(n);
UnionFind * union_set2 = new QuickUnion(n);
UnionFind * union_set3 = new WeightedUnion(n);
UnionFind * union_set_test = union_set3;
union_set_test->print()<<'\n';
for(auto e: connect){
union_set_test->unionCombine(e.first, e.second);
// union_set1.print()<<'\n';
}
cout<<"Number of groups: "<<union_set_test->num_group<<endl;
// cout<<union_set1.connected(0, 1)<<endl;
union_set_test->print()<<'\n';
delete union_set1;
delete union_set2;
delete union_set3;
}