llir-opt  0.0.1
Low-Level Post-Link Optimiser for OCaml and C
union_find.h
1 // This file is part of the llir-opt project.
2 // Licensing information can be found in the LICENSE file.
3 // (C) 2018 Nandor Licker. All rights reserved.
4 
5 #pragma once
6 
7 #include <vector>
8 #include <memory>
9 
10 #include "core/adt/id.h"
11 
12 
13 
17 template <typename T>
18 class UnionFind {
19 private:
21  struct Entry {
23  mutable ID<T> Parent;
25  mutable unsigned Rank;
27  std::unique_ptr<T> Element;
28 
29  Entry(ID<T> n, std::unique_ptr<T> &&element)
30  : Parent(n)
31  , Rank(0)
32  , Element(std::move(element))
33  {
34  }
35  };
36 
37 public:
41  class iterator : public std::iterator<std::forward_iterator_tag, T *> {
42  public:
43  bool operator==(const iterator &that) const { return it_ == that.it_; }
44  bool operator!=(const iterator &that) const { return !operator==(that); }
45 
46  iterator &operator++() {
47  ++it_;
48  Skip();
49  return *this;
50  }
51 
52  iterator operator++(int) {
53  auto tmp = *this;
54  ++*this;
55  return tmp;
56  }
57 
58  T *operator*() const { return it_->Element.get(); }
59 
60  private:
61  friend class UnionFind<T>;
62 
63  iterator(
64  UnionFind<T> *that,
65  typename std::vector<Entry>::iterator it)
66  : that_(that)
67  , it_(it)
68  {
69  Skip();
70  }
71 
72  void Skip()
73  {
74  while (it_ != that_->entries_.end() && !it_->Element.get()) {
75  ++it_;
76  }
77  }
78 
79  private:
81  UnionFind<T> *that_;
83  typename std::vector<Entry>::iterator it_;
84  };
85 
86 public:
87  UnionFind() : size_(0) {}
88 
89  template <typename... Args>
90  ID<T> Emplace(Args&&... args)
91  {
92  size_++;
93  unsigned n = entries_.size();
94  ID<T> id(n);
95  entries_.emplace_back(
96  n,
97  std::make_unique<T>(id, std::forward<Args>(args)...)
98  );
99  return id;
100  }
101 
102  ID<T> Union(ID<T> idA, ID<T> idB)
103  {
104  unsigned idxA = Find(idA);
105  unsigned idxB = Find(idB);
106  if (idxA == idxB) {
107  return idxB;
108  }
109 
110  size_--;
111 
112  Entry &entryA = entries_[idxA];
113  Entry &entryB = entries_[idxB];
114  T *a = entryA.Element.get();
115  T *b = entryB.Element.get();
116 
117  if (entryA.Rank < entryB.Rank) {
118  entryA.Parent = idxB;
119  b->Union(*a);
120  entries_[idxA].Element = nullptr;
121  return idxB;
122  } else {
123  entryB.Parent = idxA;
124  a->Union(*b);
125  entries_[idxB].Element = nullptr;
126  if (entryA.Rank == entryB.Rank) {
127  entryA.Rank += 1;
128  }
129  return idxA;
130  }
131  }
132 
133  T *Map(ID<T> id) const
134  {
135  return entries_[Find(id)].Element.get();
136  }
137 
138  T *Get(ID<T> id) const
139  {
140  return entries_[id].Element.get();
141  }
142 
143  ID<T> Find(ID<T> id) const
144  {
145  unsigned root = id;
146  while (entries_[root].Parent != root) {
147  root = entries_[root].Parent;
148  }
149  while (entries_[id].Parent != id) {
150  unsigned parent = entries_[id].Parent;
151  entries_[id].Parent = root;
152  id = parent;
153  }
154  return id;
155  }
156 
158  iterator begin() { return iterator(this, entries_.begin()); }
160  iterator end() { return iterator(this, entries_.end()); }
161 
162  unsigned Size() const { return size_; }
163 
164 private:
166  std::vector<Entry> entries_;
168  unsigned size_;
169 };
UnionFind::begin
iterator begin()
Iterator over root elements - begin.
Definition: union_find.h:158
ID
Definition: id.h:19
UnionFind::end
iterator end()
Iterator over root elements - end.
Definition: union_find.h:160
UnionFind::iterator
Definition: union_find.h:41
UnionFind
Definition: union_find.h:18