| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2017 The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file UnionFind.h
 | 
| 12 |  *
 | 
| 13 |  * Defines a union-find data-structure
 | 
| 14 |  *
 | 
| 15 |  ***********************************************************************/
 | 
| 16 | 
 | 
| 17 | #pragma once
 | 
| 18 | 
 | 
| 19 | #include "souffle/datastructure/LambdaBTree.h"
 | 
| 20 | #include "souffle/datastructure/PiggyList.h"
 | 
| 21 | #include "souffle/utility/MiscUtil.h"
 | 
| 22 | #include <atomic>
 | 
| 23 | #include <cstddef>
 | 
| 24 | #include <cstdint>
 | 
| 25 | #include <functional>
 | 
| 26 | #include <utility>
 | 
| 27 | 
 | 
| 28 | namespace souffle {
 | 
| 29 | 
 | 
// Branch-predictor hints. Macros ignore the enclosing namespace and leak into every translation
// unit that includes this header, so only define them if nobody else already has.
#ifndef unlikely
#define unlikely(x) __builtin_expect((x), 0)
#endif
#ifndef likely
#define likely(x) __builtin_expect((x), 1)
#endif

// rank of a union-find tree, stored in the low split_size bits of a block_t
using rank_t = uint8_t;
/* Node index / parent pointer. Technically uint56_t, but that doesn't exist: the top split_size bits
 * are shifted out when a parent is packed into a block_t, so indices must stay below 2^56. */
using parent_t = uint64_t;

// number of low-order bits of a block_t that hold the rank
constexpr uint8_t split_size = 8u;

// block_t packs the parent index into the upper 56 bits and the rank into the lower 8 bits
using block_t = uint64_t;
// block_t & rank_mask extracts the rank
constexpr block_t rank_mask = (block_t{1} << split_size) - 1;

static_assert(split_size == 8 * sizeof(rank_t), "rank_t must exactly fill the rank bits of a block_t");
 | 
| 45 | 
 | 
| 46 | /**
 | 
| 47 |  * Structure that emulates a Disjoint Set, i.e. a data structure that supports efficient union-find operations
 | 
| 48 |  */
 | 
| 49 | class DisjointSet {
 | 
| 50 |     template <typename TupleType>
 | 
| 51 |     friend class EquivalenceRelation;
 | 
| 52 | 
 | 
| 53 |     PiggyList<std::atomic<block_t>> a_blocks;
 | 
| 54 | 
 | 
| 55 | public:
 | 
| 56 |     DisjointSet() = default;
 | 
| 57 | 
 | 
| 58 |     // copy ctor
 | 
| 59 |     DisjointSet(DisjointSet& other) = delete;
 | 
| 60 |     // move ctor
 | 
| 61 |     DisjointSet(DisjointSet&& other) = delete;
 | 
| 62 | 
 | 
| 63 |     // copy assign ctor
 | 
| 64 |     DisjointSet& operator=(DisjointSet& ds) = delete;
 | 
| 65 |     // move assign ctor
 | 
| 66 |     DisjointSet& operator=(DisjointSet&& ds) = delete;
 | 
| 67 | 
 | 
| 68 |     /**
 | 
| 69 |      * Return the number of elements in this disjoint set (not the number of pairs)
 | 
| 70 |      */
 | 
| 71 |     inline std::size_t size() {
 | 
| 72 |         auto sz = a_blocks.size();
 | 
| 73 |         return sz;
 | 
| 74 |     };
 | 
| 75 | 
 | 
| 76 |     /**
 | 
| 77 |      * Yield reference to the node by its node index
 | 
| 78 |      * @param node node to be searched
 | 
| 79 |      * @return the parent block of the specified node
 | 
| 80 |      */
 | 
| 81 |     inline std::atomic<block_t>& get(parent_t node) const {
 | 
| 82 |         auto& ret = a_blocks.get(node);
 | 
| 83 |         return ret;
 | 
| 84 |     };
 | 
| 85 | 
 | 
| 86 |     /**
 | 
| 87 |      * Equivalent to the find() function in union/find
 | 
| 88 |      * Find the highest ancestor of the provided node - flattening as we go
 | 
| 89 |      * @param x the node to find the parent of, whilst flattening its set-tree
 | 
| 90 |      * @return The parent of x
 | 
| 91 |      */
 | 
| 92 |     parent_t findNode(parent_t x) {
 | 
| 93 |         // while x's parent is not itself
 | 
| 94 |         while (x != b2p(get(x))) {
 | 
| 95 |             block_t xState = get(x);
 | 
| 96 |             // yield x's parent's parent
 | 
| 97 |             parent_t newParent = b2p(get(b2p(xState)));
 | 
| 98 |             // construct block out of the original rank and the new parent
 | 
| 99 |             block_t newState = pr2b(newParent, b2r(xState));
 | 
| 100 | 
 | 
| 101 |             this->get(x).compare_exchange_strong(xState, newState);
 | 
| 102 | 
 | 
| 103 |             x = newParent;
 | 
| 104 |         }
 | 
| 105 |         return x;
 | 
| 106 |     }
 | 
| 107 | 
 | 
| 108 | private:
 | 
| 109 |     /**
 | 
| 110 |      * Update the root of the tree of which x is, to have y as the base instead
 | 
| 111 |      * @param x : old root
 | 
| 112 |      * @param oldrank : old root rank
 | 
| 113 |      * @param y : new root
 | 
| 114 |      * @param newrank : new root rank
 | 
| 115 |      * @return Whether the update succeeded (fails if another root update/union has been perfomed in the
 | 
| 116 |      * interim)
 | 
| 117 |      */
 | 
| 118 |     bool updateRoot(const parent_t x, const rank_t oldrank, const parent_t y, const rank_t newrank) {
 | 
| 119 |         block_t oldState = get(x);
 | 
| 120 |         parent_t nextN = b2p(oldState);
 | 
| 121 |         rank_t rankN = b2r(oldState);
 | 
| 122 | 
 | 
| 123 |         if (nextN != x || rankN != oldrank) return false;
 | 
| 124 |         // set the parent and rank of the new record
 | 
| 125 |         block_t newVal = pr2b(y, newrank);
 | 
| 126 | 
 | 
| 127 |         return this->get(x).compare_exchange_strong(oldState, newVal);
 | 
| 128 |     }
 | 
| 129 | 
 | 
| 130 | public:
 | 
| 131 |     /**
 | 
| 132 |      * Clears the DisjointSet of all nodes
 | 
| 133 |      * Invalidates all iterators
 | 
| 134 |      */
 | 
| 135 |     void clear() {
 | 
| 136 |         a_blocks.clear();
 | 
| 137 |     }
 | 
| 138 | 
 | 
| 139 |     /**
 | 
| 140 |      * Check whether the two indices are in the same set
 | 
| 141 |      * @param x node to be checked
 | 
| 142 |      * @param y node to be checked
 | 
| 143 |      * @return where the two indices are in the same set
 | 
| 144 |      */
 | 
| 145 |     bool sameSet(parent_t x, parent_t y) {
 | 
| 146 |         while (true) {
 | 
| 147 |             x = findNode(x);
 | 
| 148 |             y = findNode(y);
 | 
| 149 |             if (x == y) return true;
 | 
| 150 |             // if x's parent is itself, they are not the same set
 | 
| 151 |             if (b2p(get(x)) == x) return false;
 | 
| 152 |         }
 | 
| 153 |     }
 | 
| 154 | 
 | 
| 155 |     /**
 | 
| 156 |      * Union the two specified index nodes
 | 
| 157 |      * @param x node to be unioned
 | 
| 158 |      * @param y node to be unioned
 | 
| 159 |      */
 | 
| 160 |     void unionNodes(parent_t x, parent_t y) {
 | 
| 161 |         while (true) {
 | 
| 162 |             x = findNode(x);
 | 
| 163 |             y = findNode(y);
 | 
| 164 | 
 | 
| 165 |             // no need to union if both already in same set
 | 
| 166 |             if (x == y) return;
 | 
| 167 | 
 | 
| 168 |             rank_t xrank = b2r(get(x));
 | 
| 169 |             rank_t yrank = b2r(get(y));
 | 
| 170 | 
 | 
| 171 |             // if x comes before y (better rank or earlier & equal node)
 | 
| 172 |             if (xrank > yrank || ((xrank == yrank) && x > y)) {
 | 
| 173 |                 std::swap(x, y);
 | 
| 174 |                 std::swap(xrank, yrank);
 | 
| 175 |             }
 | 
| 176 |             // join the trees together
 | 
| 177 |             // perhaps we can optimise the use of compare_exchange_strong here, as we're in a pessimistic loop
 | 
| 178 |             if (!updateRoot(x, xrank, y, yrank)) {
 | 
| 179 |                 continue;
 | 
| 180 |             }
 | 
| 181 |             // make sure that the ranks are orderable
 | 
| 182 |             if (xrank == yrank) {
 | 
| 183 |                 updateRoot(y, yrank, y, yrank + 1);
 | 
| 184 |             }
 | 
| 185 |             break;
 | 
| 186 |         }
 | 
| 187 |     }
 | 
| 188 | 
 | 
| 189 |     /**
 | 
| 190 |      * Create a node with its parent as itself, rank 0
 | 
| 191 |      * @return the newly created block
 | 
| 192 |      */
 | 
| 193 |     inline block_t makeNode() {
 | 
| 194 |         // make node and find out where we've added it
 | 
| 195 |         std::size_t nodeDetails = a_blocks.createNode();
 | 
| 196 | 
 | 
| 197 |         a_blocks.get(nodeDetails).store(pr2b(nodeDetails, 0));
 | 
| 198 | 
 | 
| 199 |         return a_blocks.get(nodeDetails).load();
 | 
| 200 |     };
 | 
| 201 | 
 | 
| 202 |     /**
 | 
| 203 |      * Extract parent from block
 | 
| 204 |      * @param inblock the block to be masked
 | 
| 205 |      * @return The parent_t contained in the upper half of block_t
 | 
| 206 |      */
 | 
| 207 |     static inline parent_t b2p(const block_t inblock) {
 | 
| 208 |         return (parent_t)(inblock >> split_size);
 | 
| 209 |     };
 | 
| 210 | 
 | 
| 211 |     /**
 | 
| 212 |      * Extract rank from block
 | 
| 213 |      * @param inblock the block to be masked
 | 
| 214 |      * @return the rank_t contained in the lower half of block_t
 | 
| 215 |      */
 | 
| 216 |     static inline rank_t b2r(const block_t inblock) {
 | 
| 217 |         return (rank_t)(inblock & rank_mask);
 | 
| 218 |     };
 | 
| 219 | 
 | 
| 220 |     /**
 | 
| 221 |      * Yield a block given parent and rank
 | 
| 222 |      * @param parent the top half bits
 | 
| 223 |      * @param rank the lower half bits
 | 
| 224 |      * @return the resultant block after merge
 | 
| 225 |      */
 | 
| 226 |     static inline block_t pr2b(const parent_t parent, const rank_t rank) {
 | 
| 227 |         return (((block_t)parent) << split_size) | rank;
 | 
| 228 |     };
 | 
| 229 | };
 | 
| 230 | 
 | 
/**
 * Three-way comparator over pair-like values that orders by the `first` member only.
 * The `second` member is ignored, so two pairs with the same key compare equal.
 * Only operator< on `first` is required.
 */
template <typename StorePair>
struct EqrelMapComparator {
    /**
     * @return -1 if a.first < b.first, 1 if b.first < a.first, otherwise 0
     */
    int operator()(const StorePair& a, const StorePair& b) const {
        if (a.first < b.first) {
            return -1;
        }
        if (b.first < a.first) {
            return 1;
        }
        return 0;
    }

    /** @return whether a orders strictly before b (by `first`) */
    bool less(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) < 0;
    }

    /** @return whether a and b have equivalent `first` members */
    bool equal(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) == 0;
    }
};
 | 
| 251 | 
 | 
| 252 | template <typename SparseDomain>
 | 
| 253 | class SparseDisjointSet {
 | 
| 254 |     DisjointSet ds;
 | 
| 255 | 
 | 
| 256 |     template <typename TupleType>
 | 
| 257 |     friend class EquivalenceRelation;
 | 
| 258 | 
 | 
| 259 |     using PairStore = std::pair<SparseDomain, parent_t>;
 | 
| 260 |     using SparseMap =
 | 
| 261 |             LambdaBTreeSet<PairStore, std::function<parent_t(PairStore&)>, EqrelMapComparator<PairStore>>;
 | 
| 262 |     using DenseMap = RandomInsertPiggyList<SparseDomain>;
 | 
| 263 | 
 | 
| 264 |     typename SparseMap::operation_hints last_ins;
 | 
| 265 | 
 | 
| 266 |     SparseMap sparseToDenseMap;
 | 
| 267 |     // mapping from union-find val to souffle, union-find encoded as index
 | 
| 268 |     DenseMap denseToSparseMap;
 | 
| 269 | 
 | 
| 270 | public:
 | 
| 271 |     /**
 | 
| 272 |      * Retrieve dense encoding, adding it in if non-existent
 | 
| 273 |      * @param in the sparse value
 | 
| 274 |      * @return the corresponding dense value
 | 
| 275 |      */
 | 
| 276 |     parent_t toDense(const SparseDomain in) {
 | 
| 277 |         // insert into the mapping - if the key doesn't exist (in), the function will be called
 | 
| 278 |         // and a dense value will be created for it
 | 
| 279 |         PairStore p = {in, -1};
 | 
| 280 |         return sparseToDenseMap.insert(p, [&](PairStore& p) {
 | 
| 281 |             parent_t c2 = DisjointSet::b2p(this->ds.makeNode());
 | 
| 282 |             this->denseToSparseMap.insertAt(c2, p.first);
 | 
| 283 |             p.second = c2;
 | 
| 284 |             return c2;
 | 
| 285 |         });
 | 
| 286 |     }
 | 
| 287 | 
 | 
| 288 | public:
 | 
| 289 |     SparseDisjointSet() = default;
 | 
| 290 | 
 | 
| 291 |     // copy ctor
 | 
| 292 |     SparseDisjointSet(SparseDisjointSet& other) = delete;
 | 
| 293 | 
 | 
| 294 |     // move ctor
 | 
| 295 |     SparseDisjointSet(SparseDisjointSet&& other) = delete;
 | 
| 296 | 
 | 
| 297 |     // copy assign ctor
 | 
| 298 |     SparseDisjointSet& operator=(SparseDisjointSet& other) = delete;
 | 
| 299 | 
 | 
| 300 |     // move assign ctor
 | 
| 301 |     SparseDisjointSet& operator=(SparseDisjointSet&& other) = delete;
 | 
| 302 | 
 | 
| 303 |     /**
 | 
| 304 |      * For the given dense value, return the associated sparse value
 | 
| 305 |      *   Undefined behaviour if dense value not in set
 | 
| 306 |      * @param in the supplied dense value
 | 
| 307 |      * @return the sparse value from the denseToSparseMap
 | 
| 308 |      */
 | 
| 309 |     inline const SparseDomain toSparse(const parent_t in) const {
 | 
| 310 |         return denseToSparseMap.get(in);
 | 
| 311 |     };
 | 
| 312 | 
 | 
| 313 |     /* a wrapper to enable checking in the sparse set - however also adds them if not already existing */
 | 
| 314 |     inline bool sameSet(SparseDomain x, SparseDomain y) {
 | 
| 315 |         return ds.sameSet(toDense(x), toDense(y));
 | 
| 316 |     };
 | 
| 317 |     /* finds the node in the underlying disjoint set, adding the node if non-existent */
 | 
| 318 |     inline SparseDomain findNode(SparseDomain x) {
 | 
| 319 |         return toSparse(ds.findNode(toDense(x)));
 | 
| 320 |     };
 | 
| 321 |     /* union the nodes, add if not existing */
 | 
| 322 |     inline void unionNodes(SparseDomain x, SparseDomain y) {
 | 
| 323 |         ds.unionNodes(toDense(x), toDense(y));
 | 
| 324 |     };
 | 
| 325 | 
 | 
| 326 |     inline std::size_t size() {
 | 
| 327 |         return ds.size();
 | 
| 328 |     };
 | 
| 329 | 
 | 
| 330 |     /**
 | 
| 331 |      * Remove all elements from this disjoint set
 | 
| 332 |      */
 | 
| 333 |     void clear() {
 | 
| 334 |         ds.clear();
 | 
| 335 |         sparseToDenseMap.clear();
 | 
| 336 |         denseToSparseMap.clear();
 | 
| 337 |     }
 | 
| 338 | 
 | 
| 339 |     /* wrapper for node creation */
 | 
| 340 |     inline void makeNode(SparseDomain val) {
 | 
| 341 |         // dense has the behaviour of creating if not exists.
 | 
| 342 |         toDense(val);
 | 
| 343 |     };
 | 
| 344 | 
 | 
| 345 |     /* whether the supplied node exists */
 | 
| 346 |     inline bool nodeExists(const SparseDomain val) const {
 | 
| 347 |         return sparseToDenseMap.contains({val, -1});
 | 
| 348 |     };
 | 
| 349 | 
 | 
| 350 |     inline bool contains(SparseDomain v1, SparseDomain v2) {
 | 
| 351 |         if (nodeExists(v1) && nodeExists(v2)) {
 | 
| 352 |             return sameSet(v1, v2);
 | 
| 353 |         }
 | 
| 354 |         return false;
 | 
| 355 |     }
 | 
| 356 | };
 | 
| 357 | }  // namespace souffle
 |