1 | /*
|
2 | * Souffle - A Datalog Compiler
|
3 | * Copyright (c) 2017 The Souffle Developers. All rights reserved
|
4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
5 | * - https://opensource.org/licenses/UPL
|
6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
7 | */
|
8 |
|
9 | /************************************************************************
|
10 | *
|
11 | * @file UnionFind.h
|
12 | *
|
13 | * Defines a union-find data-structure
|
14 | *
|
15 | ***********************************************************************/
|
16 |
|
17 | #pragma once
|
18 |
|
19 | #include "souffle/datastructure/LambdaBTree.h"
|
20 | #include "souffle/datastructure/PiggyList.h"
|
21 | #include "souffle/utility/MiscUtil.h"
|
22 | #include <atomic>
|
23 | #include <cstddef>
|
24 | #include <cstdint>
|
25 | #include <functional>
|
26 | #include <utility>
|
27 |
|
28 | namespace souffle {
|
29 |
|
// branch predictor hacks
// Each macro is only defined if the name is still free, so an earlier definition from another header
// is not redefined (a differing redefinition would be ill-formed). __builtin_expect is a GCC/Clang builtin.
#ifndef unlikely
#define unlikely(x) __builtin_expect((x), 0)
#endif
#ifndef likely
#define likely(x) __builtin_expect((x), 1)
#endif
|
33 |
|
// rank of a node; stored in the low split_size bits of a block_t
using rank_t = uint8_t;
/* technically uint56_t, but that type doesn't exist. Be careful not to store more than 2^56 elements. */
using parent_t = uint64_t;

// number of bits that the rank occupies (the low bits of a block_t)
constexpr uint8_t split_size = 8u;

// block_t packs the parent into the upper (64 - split_size) bits and the rank into the lower split_size bits
using block_t = uint64_t;
// block_t & rank_mask extracts the rank
constexpr block_t rank_mask = (static_cast<block_t>(1) << split_size) - 1;

// The rank field must be exactly as wide as rank_t: a narrower field would truncate ranks,
// a wider one would leave bits that neither the rank nor the parent owns.
static_assert(sizeof(rank_t) * 8u == split_size, "rank_t must occupy exactly split_size bits");
|
45 |
|
46 | /**
|
47 | * Structure that emulates a Disjoint Set, i.e. a data structure that supports efficient union-find operations
|
48 | */
|
49 | class DisjointSet {
|
50 | template <typename TupleType>
|
51 | friend class EquivalenceRelation;
|
52 |
|
53 | PiggyList<std::atomic<block_t>> a_blocks;
|
54 |
|
55 | public:
|
56 | DisjointSet() = default;
|
57 |
|
58 | // copy ctor
|
59 | DisjointSet(DisjointSet& other) = delete;
|
60 | // move ctor
|
61 | DisjointSet(DisjointSet&& other) = delete;
|
62 |
|
63 | // copy assign ctor
|
64 | DisjointSet& operator=(DisjointSet& ds) = delete;
|
65 | // move assign ctor
|
66 | DisjointSet& operator=(DisjointSet&& ds) = delete;
|
67 |
|
68 | /**
|
69 | * Return the number of elements in this disjoint set (not the number of pairs)
|
70 | */
|
71 | inline std::size_t size() {
|
72 | auto sz = a_blocks.size();
|
73 | return sz;
|
74 | };
|
75 |
|
76 | /**
|
77 | * Yield reference to the node by its node index
|
78 | * @param node node to be searched
|
79 | * @return the parent block of the specified node
|
80 | */
|
81 | inline std::atomic<block_t>& get(parent_t node) const {
|
82 | auto& ret = a_blocks.get(node);
|
83 | return ret;
|
84 | };
|
85 |
|
86 | /**
|
87 | * Equivalent to the find() function in union/find
|
88 | * Find the highest ancestor of the provided node - flattening as we go
|
89 | * @param x the node to find the parent of, whilst flattening its set-tree
|
90 | * @return The parent of x
|
91 | */
|
92 | parent_t findNode(parent_t x) {
|
93 | // while x's parent is not itself
|
94 | while (x != b2p(get(x))) {
|
95 | block_t xState = get(x);
|
96 | // yield x's parent's parent
|
97 | parent_t newParent = b2p(get(b2p(xState)));
|
98 | // construct block out of the original rank and the new parent
|
99 | block_t newState = pr2b(newParent, b2r(xState));
|
100 |
|
101 | this->get(x).compare_exchange_strong(xState, newState);
|
102 |
|
103 | x = newParent;
|
104 | }
|
105 | return x;
|
106 | }
|
107 |
|
108 | private:
|
109 | /**
|
110 | * Update the root of the tree of which x is, to have y as the base instead
|
111 | * @param x : old root
|
112 | * @param oldrank : old root rank
|
113 | * @param y : new root
|
114 | * @param newrank : new root rank
|
115 | * @return Whether the update succeeded (fails if another root update/union has been perfomed in the
|
116 | * interim)
|
117 | */
|
118 | bool updateRoot(const parent_t x, const rank_t oldrank, const parent_t y, const rank_t newrank) {
|
119 | block_t oldState = get(x);
|
120 | parent_t nextN = b2p(oldState);
|
121 | rank_t rankN = b2r(oldState);
|
122 |
|
123 | if (nextN != x || rankN != oldrank) return false;
|
124 | // set the parent and rank of the new record
|
125 | block_t newVal = pr2b(y, newrank);
|
126 |
|
127 | return this->get(x).compare_exchange_strong(oldState, newVal);
|
128 | }
|
129 |
|
130 | public:
|
131 | /**
|
132 | * Clears the DisjointSet of all nodes
|
133 | * Invalidates all iterators
|
134 | */
|
135 | void clear() {
|
136 | a_blocks.clear();
|
137 | }
|
138 |
|
139 | /**
|
140 | * Check whether the two indices are in the same set
|
141 | * @param x node to be checked
|
142 | * @param y node to be checked
|
143 | * @return where the two indices are in the same set
|
144 | */
|
145 | bool sameSet(parent_t x, parent_t y) {
|
146 | while (true) {
|
147 | x = findNode(x);
|
148 | y = findNode(y);
|
149 | if (x == y) return true;
|
150 | // if x's parent is itself, they are not the same set
|
151 | if (b2p(get(x)) == x) return false;
|
152 | }
|
153 | }
|
154 |
|
155 | /**
|
156 | * Union the two specified index nodes
|
157 | * @param x node to be unioned
|
158 | * @param y node to be unioned
|
159 | */
|
160 | void unionNodes(parent_t x, parent_t y) {
|
161 | while (true) {
|
162 | x = findNode(x);
|
163 | y = findNode(y);
|
164 |
|
165 | // no need to union if both already in same set
|
166 | if (x == y) return;
|
167 |
|
168 | rank_t xrank = b2r(get(x));
|
169 | rank_t yrank = b2r(get(y));
|
170 |
|
171 | // if x comes before y (better rank or earlier & equal node)
|
172 | if (xrank > yrank || ((xrank == yrank) && x > y)) {
|
173 | std::swap(x, y);
|
174 | std::swap(xrank, yrank);
|
175 | }
|
176 | // join the trees together
|
177 | // perhaps we can optimise the use of compare_exchange_strong here, as we're in a pessimistic loop
|
178 | if (!updateRoot(x, xrank, y, yrank)) {
|
179 | continue;
|
180 | }
|
181 | // make sure that the ranks are orderable
|
182 | if (xrank == yrank) {
|
183 | updateRoot(y, yrank, y, yrank + 1);
|
184 | }
|
185 | break;
|
186 | }
|
187 | }
|
188 |
|
189 | /**
|
190 | * Create a node with its parent as itself, rank 0
|
191 | * @return the newly created block
|
192 | */
|
193 | inline block_t makeNode() {
|
194 | // make node and find out where we've added it
|
195 | std::size_t nodeDetails = a_blocks.createNode();
|
196 |
|
197 | a_blocks.get(nodeDetails).store(pr2b(nodeDetails, 0));
|
198 |
|
199 | return a_blocks.get(nodeDetails).load();
|
200 | };
|
201 |
|
202 | /**
|
203 | * Extract parent from block
|
204 | * @param inblock the block to be masked
|
205 | * @return The parent_t contained in the upper half of block_t
|
206 | */
|
207 | static inline parent_t b2p(const block_t inblock) {
|
208 | return (parent_t)(inblock >> split_size);
|
209 | };
|
210 |
|
211 | /**
|
212 | * Extract rank from block
|
213 | * @param inblock the block to be masked
|
214 | * @return the rank_t contained in the lower half of block_t
|
215 | */
|
216 | static inline rank_t b2r(const block_t inblock) {
|
217 | return (rank_t)(inblock & rank_mask);
|
218 | };
|
219 |
|
220 | /**
|
221 | * Yield a block given parent and rank
|
222 | * @param parent the top half bits
|
223 | * @param rank the lower half bits
|
224 | * @return the resultant block after merge
|
225 | */
|
226 | static inline block_t pr2b(const parent_t parent, const rank_t rank) {
|
227 | return (((block_t)parent) << split_size) | rank;
|
228 | };
|
229 | };
|
230 |
|
/**
 * Three-way comparator for pair-like entries that orders and identifies them by their
 * `first` member only; `second` is ignored entirely.
 * All members are const so the comparator can also be used through a const object or reference.
 */
template <typename StorePair>
struct EqrelMapComparator {
    /**
     * @return -1 if a.first < b.first, 1 if b.first < a.first, otherwise 0
     */
    int operator()(const StorePair& a, const StorePair& b) const {
        if (a.first < b.first) {
            return -1;
        }
        if (b.first < a.first) {
            return 1;
        }
        return 0;
    }

    /** @return whether a orders strictly before b (by first) */
    bool less(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) < 0;
    }

    /** @return whether a and b have equivalent first members */
    bool equal(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) == 0;
    }
};
|
251 |
|
252 | template <typename SparseDomain>
|
253 | class SparseDisjointSet {
|
254 | DisjointSet ds;
|
255 |
|
256 | template <typename TupleType>
|
257 | friend class EquivalenceRelation;
|
258 |
|
259 | using PairStore = std::pair<SparseDomain, parent_t>;
|
260 | using SparseMap =
|
261 | LambdaBTreeSet<PairStore, std::function<parent_t(PairStore&)>, EqrelMapComparator<PairStore>>;
|
262 | using DenseMap = RandomInsertPiggyList<SparseDomain>;
|
263 |
|
264 | typename SparseMap::operation_hints last_ins;
|
265 |
|
266 | SparseMap sparseToDenseMap;
|
267 | // mapping from union-find val to souffle, union-find encoded as index
|
268 | DenseMap denseToSparseMap;
|
269 |
|
270 | public:
|
271 | /**
|
272 | * Retrieve dense encoding, adding it in if non-existent
|
273 | * @param in the sparse value
|
274 | * @return the corresponding dense value
|
275 | */
|
276 | parent_t toDense(const SparseDomain in) {
|
277 | // insert into the mapping - if the key doesn't exist (in), the function will be called
|
278 | // and a dense value will be created for it
|
279 | PairStore p = {in, -1};
|
280 | return sparseToDenseMap.insert(p, [&](PairStore& p) {
|
281 | parent_t c2 = DisjointSet::b2p(this->ds.makeNode());
|
282 | this->denseToSparseMap.insertAt(c2, p.first);
|
283 | p.second = c2;
|
284 | return c2;
|
285 | });
|
286 | }
|
287 |
|
288 | public:
|
289 | SparseDisjointSet() = default;
|
290 |
|
291 | // copy ctor
|
292 | SparseDisjointSet(SparseDisjointSet& other) = delete;
|
293 |
|
294 | // move ctor
|
295 | SparseDisjointSet(SparseDisjointSet&& other) = delete;
|
296 |
|
297 | // copy assign ctor
|
298 | SparseDisjointSet& operator=(SparseDisjointSet& other) = delete;
|
299 |
|
300 | // move assign ctor
|
301 | SparseDisjointSet& operator=(SparseDisjointSet&& other) = delete;
|
302 |
|
303 | /**
|
304 | * For the given dense value, return the associated sparse value
|
305 | * Undefined behaviour if dense value not in set
|
306 | * @param in the supplied dense value
|
307 | * @return the sparse value from the denseToSparseMap
|
308 | */
|
309 | inline const SparseDomain toSparse(const parent_t in) const {
|
310 | return denseToSparseMap.get(in);
|
311 | };
|
312 |
|
313 | /* a wrapper to enable checking in the sparse set - however also adds them if not already existing */
|
314 | inline bool sameSet(SparseDomain x, SparseDomain y) {
|
315 | return ds.sameSet(toDense(x), toDense(y));
|
316 | };
|
317 | /* finds the node in the underlying disjoint set, adding the node if non-existent */
|
318 | inline SparseDomain findNode(SparseDomain x) {
|
319 | return toSparse(ds.findNode(toDense(x)));
|
320 | };
|
321 | /* union the nodes, add if not existing */
|
322 | inline void unionNodes(SparseDomain x, SparseDomain y) {
|
323 | ds.unionNodes(toDense(x), toDense(y));
|
324 | };
|
325 |
|
326 | inline std::size_t size() {
|
327 | return ds.size();
|
328 | };
|
329 |
|
330 | /**
|
331 | * Remove all elements from this disjoint set
|
332 | */
|
333 | void clear() {
|
334 | ds.clear();
|
335 | sparseToDenseMap.clear();
|
336 | denseToSparseMap.clear();
|
337 | }
|
338 |
|
339 | /* wrapper for node creation */
|
340 | inline void makeNode(SparseDomain val) {
|
341 | // dense has the behaviour of creating if not exists.
|
342 | toDense(val);
|
343 | };
|
344 |
|
345 | /* whether the supplied node exists */
|
346 | inline bool nodeExists(const SparseDomain val) const {
|
347 | return sparseToDenseMap.contains({val, -1});
|
348 | };
|
349 |
|
350 | inline bool contains(SparseDomain v1, SparseDomain v2) {
|
351 | if (nodeExists(v1) && nodeExists(v2)) {
|
352 | return sameSet(v1, v2);
|
353 | }
|
354 | return false;
|
355 | }
|
356 | };
|
357 | } // namespace souffle
|