llir-opt  0.0.1
Low-Level Post-Link Optimiser for OCaml and C
bitset.h
1 // This file is part of the llir-opt project.
2 // Licensing information can be found in the LICENSE file.
3 // (C) 2018 Nandor Licker. All rights reserved.
4 
5 #pragma once
6 
7 #include <cassert>
8 #include <cstdint>
9 #include <climits>
10 #include <limits>
11 #include <map>
12 
13 #include "core/adt/id.h"
14 
15 #include <iostream>
16 #include <bitset>
17 
18 
19 
23 template<typename T, unsigned N = 8>
24 class BitSet final {
25 private:
27  static constexpr uint64_t kBitsInBucket = sizeof(uint64_t) * CHAR_BIT;
29  static constexpr uint64_t kBitsInChunk = kBitsInBucket * N;
30 
32  struct Node {
33  public:
34  Node()
35  {
36  for (unsigned i = 0; i < N; ++i) {
37  arr[i] = 0ull;
38  }
39  }
40 
41  bool Insert(unsigned bit)
42  {
43  const uint64_t bucket = bit / kBitsInBucket;
44  const uint64_t mask = 1ull << (bit - bucket * kBitsInBucket);
45  const uint64_t val = arr[bucket];
46  const bool inserted = !(val & mask);
47  arr[bucket] = val | mask;
48  return inserted;
49  }
50 
51  bool Contains(unsigned bit) const
52  {
53  uint64_t bucket = bit / kBitsInBucket;
54  const uint64_t mask = 1ull << (bit - bucket * kBitsInBucket);
55  return arr[bucket] & mask;
56  }
57 
58  bool Erase(unsigned bit)
59  {
60  const uint64_t bucket = bit / kBitsInBucket;
61 
62  arr[bucket] &= ~(1ull << (bit - bucket * kBitsInBucket));
63 
64  for (unsigned i = 0; i < N; ++i) {
65  if (arr[i] != 0ull) {
66  return false;
67  }
68  }
69  return true;
70  }
71 
72  size_t Size() const
73  {
74  size_t size = 0;
75  for (unsigned i = 0; i < N; ++i) {
76  size += __builtin_popcountll(arr[i]);
77  }
78  return size;
79  }
80 
81  bool operator==(const Node &that) const
82  {
83  for (unsigned i = 0; i < N; ++i) {
84  if (arr[i] != that.arr[i]) {
85  return false;
86  }
87  }
88  return true;
89  }
90 
91  unsigned Union(const Node &that)
92  {
93  unsigned changed = 0;
94  for (unsigned i = 0; i < N; ++i) {
95  uint64_t old = arr[i];
96  arr[i] |= that.arr[i];
97  changed += __builtin_popcountll(old ^ arr[i]);
98  }
99  return changed;
100  }
101 
102  bool Subtract(const Node &that)
103  {
104  bool zero = true;
105  for (unsigned i = 0; i < N; ++i) {
106  arr[i] = arr[i] & ~that.arr[i];
107  zero = zero && (arr[i] == 0);
108  }
109  return zero;
110  }
111 
112  bool And(const Node &that)
113  {
114  bool zero = true;
115  for (unsigned i = 0; i < N; ++i) {
116  arr[i] = arr[i] & that.arr[i];
117  zero = zero && (arr[i] == 0);
118  }
119  return zero;
120  }
121 
122  unsigned Next(unsigned bit) const
123  {
124  const unsigned bucket = bit / kBitsInBucket;
125  const unsigned offset = bit - bucket * kBitsInBucket;
126 
127  uint64_t mask;
128  if (offset + 1 == kBitsInBucket) {
129  mask = 0;
130  } else {
131  mask = arr[bucket] & ~((1ull << (offset + 1)) - 1);
132  }
133 
134  if (mask) {
135  return bucket * kBitsInBucket + __builtin_ctzll(mask);
136  } else {
137  for (unsigned i = bucket + 1; i < N; ++i) {
138  if (arr[i]) {
139  return i * kBitsInBucket + __builtin_ctzll(arr[i]);
140  }
141  }
142  return 0;
143  }
144  }
145 
146  unsigned First() const
147  {
148  for (unsigned i = 0; i < N; ++i) {
149  if (arr[i]) {
150  return i * kBitsInBucket + __builtin_ctzll(arr[i]);
151  }
152  }
153  return 0;
154  }
155 
156  unsigned Prev(unsigned bit) const
157  {
158  const unsigned bucket = bit / kBitsInBucket;
159  const unsigned offset = bit - bucket * kBitsInBucket;
160 
161  uint64_t mask;
162  if (offset == 0) {
163  mask = 0;
164  } else {
165  mask = arr[bucket] & ((1ull << offset) - 1);
166  }
167 
168  if (mask) {
169  return (bucket + 1) * kBitsInBucket - __builtin_clzll(mask) - 1;
170  } else {
171  for (int i = bucket - 1; i >= 0; --i) {
172  if (arr[i]) {
173  return (i + 1) * kBitsInBucket - __builtin_clzll(arr[i]) - 1;
174  }
175  }
176  return kBitsInChunk;
177  }
178  }
179 
180  unsigned Last() const
181  {
182  for (int i = N - 1; i >= 0; --i) {
183  if (arr[i]) {
184  return (i + 1) * kBitsInBucket - __builtin_clzll(arr[i]) - 1;
185  }
186  }
187  return kBitsInChunk;
188  }
189 
190  private:
191  uint64_t arr[N];
192  };
193 
195  using NodeIt = typename std::map<uint32_t, Node>::const_iterator;
196 
197 public:
199  class iterator final {
200  public:
201  iterator(const iterator &that)
202  : set_(that.set_)
203  , it_(that.it_)
204  , current_(that.current_)
205  {
206  }
207 
208  iterator(const BitSet<T> &set, int64_t current)
209  : set_(&set)
210  , it_(set.nodes_.find(current / kBitsInChunk))
211  , current_(current)
212  {
213  }
214 
215  iterator &operator=(const iterator &that)
216  {
217  set_ = that.set_;
218  it_ = that.it_;
219  current_ = that.current_;
220  return *this;
221  }
222 
223  ID<T> operator*() const
224  {
225  return current_;
226  }
227 
228  iterator operator++(int)
229  {
230  iterator it(*this);
231  ++*this;
232  return it;
233  }
234 
235  iterator operator++()
236  {
237  if (current_ == set_->last_) {
238  current_ = set_->last_ + 1;
239  } else {
240  unsigned currPos = current_ & (kBitsInChunk - 1);
241  unsigned nextPos = it_->second.Next(currPos);
242  if (nextPos == 0) {
243  ++it_;
244  current_ = it_->first * kBitsInChunk + it_->second.First();
245  } else {
246  current_ = it_->first * kBitsInChunk + nextPos;
247  }
248  }
249  return *this;
250  }
251 
252  bool operator == (const iterator &that) const
253  {
254  return current_ == that.current_;
255  }
256 
257  bool operator != (const iterator &that) const
258  {
259  return !(*this == that);
260  }
261 
262  private:
264  const BitSet<T> *set_;
266  NodeIt it_;
268  uint64_t current_;
269  };
270 
272  class reverse_iterator final {
273  public:
275  : set_(that.set_)
276  , it_(that.it_)
277  , current_(that.current_)
278  {
279  }
280 
281  reverse_iterator(const BitSet<T> &set, int64_t current)
282  : set_(set)
283  , it_(set.nodes_.find(current / kBitsInChunk))
284  , current_(current)
285  {
286  }
287 
288  ID<T> operator * () const
289  {
290  return current_;
291  }
292 
293  reverse_iterator operator ++ (int)
294  {
295  reverse_iterator it(*this);
296  ++*this;
297  return it;
298  }
299 
300  reverse_iterator operator ++ ()
301  {
302  if (current_ == set_.first_) {
303  current_ = set_.first_ - 1;
304  } else {
305  unsigned currPos = current_ & (kBitsInChunk - 1);
306  unsigned nextPos = it_->second.Prev(currPos);
307  if (nextPos == kBitsInChunk) {
308  --it_;
309  current_ = it_->first * kBitsInChunk + it_->second.Last();
310  } else {
311  current_ = it_->first * kBitsInChunk + nextPos;
312  }
313  }
314  return *this;
315  }
316 
317  bool operator==(const reverse_iterator &that) const
318  {
319  return current_ == that.current_;
320  }
321 
322  bool operator!=(const reverse_iterator &that) const
323  {
324  return !(*this == that);
325  }
326 
327  private:
329  const BitSet<T> &set_;
331  NodeIt it_;
333  int64_t current_;
334  };
335 
337  explicit BitSet()
338  : first_(std::numeric_limits<uint32_t>::max())
339  , last_(std::numeric_limits<uint32_t>::min())
340  {
341  }
342 
344  explicit BitSet(ID<T> id)
345  : BitSet()
346  {
347  Insert(id);
348  }
349 
352  {
353  }
354 
356  iterator begin() const
357  {
358  return Empty() ? end() : iterator(*this, first_);
359  }
360 
362  iterator end() const
363  {
364  return iterator(*this, static_cast<int64_t>(last_) + 1ull);
365  }
366 
368  reverse_iterator rbegin() const
369  {
370  return Empty() ? rend() : reverse_iterator(*this, last_);
371  }
372 
374  reverse_iterator rend() const
375  {
376  return reverse_iterator(*this, static_cast<int64_t>(first_) - 1ull);
377  }
378 
380  bool Empty() const { return nodes_.empty(); }
381 
383  void Clear()
384  {
385  first_ = std::numeric_limits<uint32_t>::max();
386  last_ = std::numeric_limits<uint32_t>::min();
387  nodes_.clear();
388  }
389 
391  bool Insert(const ID<T> &item)
392  {
393  first_ = std::min(first_, static_cast<uint32_t>(item));
394  last_ = std::max(last_, static_cast<uint32_t>(item));
395 
396  auto &node = nodes_[item / kBitsInChunk];
397  return node.Insert(item - (item / kBitsInChunk) * kBitsInChunk);
398  }
399 
401  void Erase(const ID<T> &item)
402  {
403  if (item == first_ && item == last_) {
404  first_ = std::numeric_limits<uint32_t>::max();
405  last_ = std::numeric_limits<uint32_t>::min();
406  } else if (item == first_) {
407  first_ = *++begin();
408  } else if (item == last_) {
409  last_ = *++rbegin();
410  }
411 
412  auto &node = nodes_[item / kBitsInChunk];
413  if (node.Erase(item - (item / kBitsInChunk) * kBitsInChunk)) {
414  nodes_.erase(item / kBitsInChunk);
415  }
416  }
417 
419  bool Contains(const ID<T> &item) const
420  {
421  if (item < first_ || last_ < item) {
422  return false;
423  }
424  auto it = nodes_.find(item / kBitsInChunk);
425  if (it == nodes_.end()) {
426  return false;
427  }
428  return it->second.Contains(item - (item / kBitsInChunk) * kBitsInChunk);
429  }
430 
434  unsigned Union(const BitSet &that)
435  {
436  unsigned changed = 0;
437 
438  for (auto &thatNode : that.nodes_) {
439  changed += nodes_[thatNode.first].Union(thatNode.second);
440  }
441 
442  first_ = std::min(first_, that.first_);
443  last_ = std::max(last_, that.last_);
444 
445  return changed;
446  }
447 
449  void Subtract(const BitSet &that)
450  {
451  auto it = nodes_.begin();
452  auto tt = that.nodes_.begin();
453  while (it != nodes_.end() && tt != that.nodes_.end()) {
454  // Advance iterators until indices match.
455  while (it != nodes_.end() && it->first < tt->first) {
456  ++it;
457  }
458  if (it == nodes_.end()) {
459  break;
460  }
461  while (tt != that.nodes_.end() && tt->first < it->first) {
462  ++tt;
463  }
464  if (tt == that.nodes_.end()) {
465  break;
466  }
467 
468  // Erase the node if all bits are deleted.
469  if (it->first == tt->first) {
470  if (it->second.Subtract(tt->second)) {
471  nodes_.erase(it++);
472  } else {
473  ++it;
474  }
475  }
476  }
477 
478  ResetFirstLast();
479  }
480 
482  void Intersect(const BitSet &that)
483  {
484  auto it = nodes_.begin();
485  auto tt = that.nodes_.begin();
486  while (it != nodes_.end() && tt != that.nodes_.end()) {
487  // Advance iterators until indices match.
488  while (it != nodes_.end() && it->first < tt->first) {
489  nodes_.erase(it++);
490  }
491  if (it == nodes_.end()) {
492  break;
493  }
494  while (tt != that.nodes_.end() && tt->first < it->first) {
495  ++tt;
496  }
497  if (tt == that.nodes_.end()) {
498  break;
499  }
500 
501  // Erase the node if all bits are deleted.
502  if (it->first == tt->first) {
503  if (it->second.And(tt->second)) {
504  nodes_.erase(it++);
505  } else {
506  ++it;
507  }
508  ++tt;
509  }
510  }
511 
512  ResetFirstLast();
513  }
514 
516  size_t Size() const
517  {
518  size_t size = 0;
519  for (auto &[id, node] : nodes_) {
520  size += node.Size();
521  }
522  return size;
523  }
524 
526  bool operator == (const BitSet &that) const
527  {
528  if (first_ != that.first_) {
529  return false;
530  }
531  if (last_ != that.last_) {
532  return false;
533  }
534 
535  if (nodes_.size() != that.nodes_.size()) {
536  return false;
537  }
538 
539  return std::equal(
540  nodes_.begin(), nodes_.end(),
541  that.nodes_.begin(), that.nodes_.end()
542  );
543  }
544 
546  bool operator != (const BitSet &that) const
547  {
548  return !operator==(that);
549  }
550 
552  BitSet operator-(const BitSet &that) const
553  {
554  BitSet copy(*this);
555  copy.Subtract(that);
556  return copy;
557  }
558 
560  BitSet operator|(const BitSet &that) const
561  {
562  BitSet copy(*this);
563  copy.Union(that);
564  return copy;
565  }
566 
568  BitSet &operator|=(const BitSet &that)
569  {
570  Union(that);
571  return *this;
572  }
573 
575  BitSet operator&(const BitSet &that) const
576  {
577  BitSet copy(*this);
578  copy.Intersect(that);
579  return copy;
580  }
581 
582 private:
584  void ResetFirstLast()
585  {
586  if (nodes_.empty()) {
587  first_ = std::numeric_limits<uint32_t>::max();
588  last_ = std::numeric_limits<uint32_t>::min();
589  } else {
590  auto &[fi, fo] = *nodes_.begin();
591  auto &[li, lo] = *nodes_.rbegin();
592  first_ = fi * kBitsInChunk + fo.First();
593  last_ = li * kBitsInChunk + lo.Last();
594  }
595  }
596 
597 private:
599  uint32_t first_;
601  uint32_t last_;
603  std::map<uint32_t, Node> nodes_;
604 };
605 
607 template <typename T>
608 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const BitSet<T> &s)
609 {
610  os << "{";
611  bool first = true;
612  for (auto id : s) {
613  if (first) {
614  first = false;
615  } else {
616  os << ", ";
617  }
618  os << id;
619  }
620  os << "}";
621  return os;
622 }
BitSet::iterator
Iterator over the bitset items.
Definition: bitset.h:199
BitSet::operator==
bool operator==(const BitSet &that) const
Checks if two bitsets are equal.
Definition: bitset.h:526
BitSet::Contains
bool Contains(const ID< T > &item) const
Checks if a bit is set.
Definition: bitset.h:419
BitSet::Erase
void Erase(const ID< T > &item)
Erases a bit.
Definition: bitset.h:401
BitSet::Size
size_t Size() const
Returns the number of elements in the set.
Definition: bitset.h:516
BitSet::begin
iterator begin() const
Start iterator.
Definition: bitset.h:356
BitSet::operator|
BitSet operator|(const BitSet &that) const
Bitwise or.
Definition: bitset.h:560
BitSet::rend
reverse_iterator rend() const
Reverse end iterator.
Definition: bitset.h:374
BitSet::BitSet
BitSet()
Constructs a new bitset.
Definition: bitset.h:337
BitSet::Subtract
void Subtract(const BitSet &that)
Subtracts a bitset from another.
Definition: bitset.h:449
BitSet::end
iterator end() const
End iterator.
Definition: bitset.h:362
ID
Definition: id.h:19
BitSet::operator!=
bool operator!=(const BitSet &that) const
Checks if two bitsets are different.
Definition: bitset.h:546
BitSet::Clear
void Clear()
Clears the set.
Definition: bitset.h:383
BitSet::BitSet
BitSet(ID< T > id)
Constructs a singleton bitset.
Definition: bitset.h:344
BitSet::operator|=
BitSet & operator|=(const BitSet &that)
Bitwise or.
Definition: bitset.h:568
BitSet
Definition: bitset.h:24
BitSet::Empty
bool Empty() const
Checks if the set is empty.
Definition: bitset.h:380
BitSet::rbegin
reverse_iterator rbegin() const
Reverse start iterator.
Definition: bitset.h:368
BitSet::Intersect
void Intersect(const BitSet &that)
Intersects this bitset with another.
Definition: bitset.h:482
BitSet::Insert
bool Insert(const ID< T > &item)
Inserts an item into the bitset.
Definition: bitset.h:391
BitSet::operator-
BitSet operator-(const BitSet &that) const
Subtraction.
Definition: bitset.h:552
BitSet::reverse_iterator
Reverse iterator over the bitset items.
Definition: bitset.h:272
BitSet::Union
unsigned Union(const BitSet &that)
Definition: bitset.h:434
BitSet::~BitSet
~BitSet()
Deletes the bitset.
Definition: bitset.h:351
Node::Node
Node(Kind kind)
Creates a new node.
Definition: node.cpp:12
BitSet::operator&
BitSet operator&(const BitSet &that) const
Bitwise and.
Definition: bitset.h:575