| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2017 The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file EquivalenceRelation.h
 | 
| 12 |  *
 | 
| 13 |  * Defines a binary relation interface to be used with Souffle as a relational store.
 | 
| 14 |  * Pairs inserted into this relation implicitly store a reflexive, symmetric, and transitive relation
 | 
| 15 |  * with each other.
 | 
| 16 |  *
 | 
| 17 |  ***********************************************************************/
 | 
| 18 | 
 | 
| 19 | #pragma once
 | 
| 20 | 
 | 
| 21 | #include "souffle/RamTypes.h"
 | 
| 22 | #include "souffle/datastructure/LambdaBTree.h"
 | 
| 23 | #include "souffle/datastructure/PiggyList.h"
 | 
| 24 | #include "souffle/datastructure/UnionFind.h"
 | 
| 25 | #include "souffle/utility/ContainerUtil.h"
 | 
| 26 | #include "souffle/utility/ParallelUtil.h"
 | 
| 27 | #include <atomic>
 | 
| 28 | #include <cassert>
 | 
| 29 | #include <cstddef>
 | 
| 30 | #include <functional>
 | 
| 31 | #include <iostream>
 | 
| 32 | #include <iterator>
 | 
| 33 | #include <set>
 | 
| 34 | #include <shared_mutex>
 | 
| 35 | #include <stdexcept>
 | 
| 36 | #include <tuple>
 | 
| 37 | #include <unordered_set>
 | 
| 38 | #include <utility>
 | 
| 39 | #include <vector>
 | 
| 40 | 
 | 
| 41 | namespace souffle {
 | 
| 42 | template <typename TupleType>
 | 
| 43 | class EquivalenceRelation {
 | 
| 44 |     using value_type = typename TupleType::value_type;
 | 
| 45 | 
 | 
| 46 |     // mapping from representative to disjoint set
 | 
| 47 |     // just a cache, essentially, used for iteration over
 | 
| 48 |     using StatesList = souffle::PiggyList<value_type>;
 | 
| 49 |     using StatesBucket = StatesList*;
 | 
| 50 |     using StorePair = std::pair<value_type, StatesBucket>;
 | 
| 51 |     using StatesMap = souffle::LambdaBTreeSet<StorePair, std::function<StatesBucket(StorePair&)>,
 | 
| 52 |             souffle::EqrelMapComparator<StorePair>>;
 | 
| 53 | 
 | 
| 54 | public:
 | 
| 55 |     using element_type = TupleType;
 | 
| 56 | 
 | 
| 57 |     EquivalenceRelation() : statesMapStale(false){};
 | 
| 58 |     ~EquivalenceRelation() {
 | 
| 59 |         emptyPartition();
 | 
| 60 |     }
 | 
| 61 | 
 | 
/**
 * Placeholder for the operation hints other relation stores use to exploit
 * temporal locality between consecutive operations. This store has nothing to
 * gain from them; the type exists only because the common relation interface
 * expects it.
 */
struct operation_hints {
    /** Reset all hints (e.g. after nodes are deleted); nothing to do here. */
    void clear() {}
};
 | 
| 72 | 
 | 
| 73 |     /**
 | 
| 74 |      * Insert the two values symbolically as a binary relation
 | 
| 75 |      * @param x node to be added/paired
 | 
| 76 |      * @param y node to be added/paired
 | 
| 77 |      * @return true if the pair is new to the data structure
 | 
| 78 |      */
 | 
| 79 |     bool insert(value_type x, value_type y) {
 | 
| 80 |         operation_hints z;
 | 
| 81 |         return insert(x, y, z);
 | 
| 82 |     };
 | 
| 83 | 
 | 
| 84 |     /**
 | 
| 85 |      * Insert the tuple symbolically.
 | 
| 86 |      * @param tuple The tuple to be inserted
 | 
| 87 |      * @return true if the tuple is new to the data structure
 | 
| 88 |      */
 | 
| 89 |     bool insert(const TupleType& tuple) {
 | 
| 90 |         operation_hints hints;
 | 
| 91 |         return insert(tuple[0], tuple[1], hints);
 | 
| 92 |     };
 | 
| 93 | 
 | 
| 94 |     /**
 | 
| 95 |      * Insert the two values symbolically as a binary relation
 | 
| 96 |      * @param x node to be added/paired
 | 
| 97 |      * @param y node to be added/paired
 | 
| 98 |      * @param z the hints to where the pair should be inserted (not applicable atm)
 | 
| 99 |      * @return true if the pair is new to the data structure
 | 
| 100 |      */
 | 
| 101 |     bool insert(value_type x, value_type y, operation_hints) {
 | 
| 102 |         // indicate that iterators will have to generate on request
 | 
| 103 |         this->statesMapStale.store(true, std::memory_order_relaxed);
 | 
| 104 |         bool retval = !contains(x, y);
 | 
| 105 |         sds.unionNodes(x, y);
 | 
| 106 |         return retval;
 | 
| 107 |     }
 | 
| 108 | 
 | 
| 109 |     /**
 | 
| 110 |      * inserts all nodes from the other relation into this one
 | 
| 111 |      * @param other the binary relation from which to add elements from
 | 
| 112 |      */
 | 
| 113 |     void insertAll(const EquivalenceRelation<TupleType>& other) {
 | 
| 114 |         other.genAllDisjointSetLists();
 | 
| 115 | 
 | 
| 116 |         // iterate over partitions at a time
 | 
| 117 |         for (auto&& [rep, pl] : other.equivalencePartition) {
 | 
| 118 |             const std::size_t ksize = pl->size();
 | 
| 119 |             for (std::size_t i = 0; i < ksize; ++i) {
 | 
| 120 |                 this->sds.unionNodes(rep, pl->get(i));
 | 
| 121 |             }
 | 
| 122 |         }
 | 
| 123 |         // invalidate iterators unconditionally
 | 
| 124 |         this->statesMapStale.store(true, std::memory_order_relaxed);
 | 
| 125 |     }
 | 
| 126 | 
 | 
| 127 |     /**
 | 
| 128 |      * Extend this relation with another relation, expanding this equivalence
 | 
| 129 |      * relation and inserting it into the other relation.
 | 
| 130 |      *
 | 
| 131 |      * The supplied relation is the old knowledge, whilst this relation only
 | 
| 132 |      * contains explicitly new knowledge. After this operation the "implicitly
 | 
| 133 |      * new tuples" are now explicitly inserted this relation, and all of the new
 | 
| 134 |      * tuples in this relation are inserted into the old relation.
 | 
| 135 |      */
 | 
| 136 |     void extendAndInsert(EquivalenceRelation<TupleType>& other) {
 | 
| 137 |         if (other.size() == 0 && this->size() == 0) return;
 | 
| 138 | 
 | 
| 139 |         std::unordered_set<value_type> repsCovered;
 | 
| 140 | 
 | 
| 141 |         // This vector holds all of the elements of this equivalence relation
 | 
| 142 |         // that aren't yet in other, which get inserted after extending this
 | 
| 143 |         // relation by other. These operations are interleaved for maximum
 | 
| 144 |         // efficiency - either extend or inserting first would make the other
 | 
| 145 |         // operation unnecessarily slow.
 | 
| 146 |         std::vector<std::pair<value_type, value_type>> toInsert;
 | 
| 147 |         auto size = std::distance(this->sds.sparseToDenseMap.begin(), this->sds.sparseToDenseMap.end());
 | 
| 148 |         toInsert.reserve(size);
 | 
| 149 | 
 | 
| 150 |         // find all the disjoint sets that need to be added to this relation
 | 
| 151 |         // that exist in other (and exist in this)
 | 
| 152 |         {
 | 
| 153 |             auto it = this->sds.sparseToDenseMap.begin();
 | 
| 154 |             auto end = this->sds.sparseToDenseMap.end();
 | 
| 155 |             value_type el;
 | 
| 156 |             for (; it != end; ++it) {
 | 
| 157 |                 std::tie(el, std::ignore) = *it;
 | 
| 158 |                 if (other.containsElement(el)) {
 | 
| 159 |                     value_type rep = other.sds.findNode(el);
 | 
| 160 |                     if (repsCovered.count(rep) == 0) {
 | 
| 161 |                         repsCovered.emplace(rep);
 | 
| 162 |                     }
 | 
| 163 |                 }
 | 
| 164 |                 toInsert.emplace_back(el, this->sds.findNode(el));
 | 
| 165 |             }
 | 
| 166 |         }
 | 
| 167 |         assert(size >= 0);
 | 
| 168 |         assert(toInsert.size() == (std::size_t)size);
 | 
| 169 | 
 | 
| 170 |         // add the intersecting dj sets into this one
 | 
| 171 |         {
 | 
| 172 |             value_type el;
 | 
| 173 |             value_type rep;
 | 
| 174 |             auto it = other.sds.sparseToDenseMap.begin();
 | 
| 175 |             auto end = other.sds.sparseToDenseMap.end();
 | 
| 176 |             for (; it != end; ++it) {
 | 
| 177 |                 std::tie(el, std::ignore) = *it;
 | 
| 178 |                 rep = other.sds.findNode(el);
 | 
| 179 |                 if (repsCovered.count(rep) != 0) {
 | 
| 180 |                     this->insert(el, rep);
 | 
| 181 |                 }
 | 
| 182 |             }
 | 
| 183 |         }
 | 
| 184 | 
 | 
| 185 |         // Insert all new tuples from this relation into the old relation
 | 
| 186 |         {
 | 
| 187 |             value_type el;
 | 
| 188 |             value_type rep;
 | 
| 189 |             for (std::pair<value_type, value_type> p : toInsert) {
 | 
| 190 |                 std::tie(el, rep) = p;
 | 
| 191 |                 other.insert(el, rep);
 | 
| 192 |             }
 | 
| 193 |         }
 | 
| 194 |     }
 | 
| 195 | 
 | 
| 196 |     /**
 | 
| 197 |      * Returns whether there exists a pair with these two nodes
 | 
| 198 |      * @param x front of pair
 | 
| 199 |      * @param y back of pair
 | 
| 200 |      */
 | 
| 201 |     bool contains(value_type x, value_type y) const {
 | 
| 202 |         return sds.contains(x, y);
 | 
| 203 |     }
 | 
| 204 | 
 | 
| 205 |     /**
 | 
| 206 |      * Returns whether there exists given tuple.
 | 
| 207 |      * @param tuple The tuple to search for.
 | 
| 208 |      */
 | 
| 209 |     bool contains(const TupleType& tuple, operation_hints&) const {
 | 
| 210 |         return contains(tuple[0], tuple[1]);
 | 
| 211 |     };
 | 
| 212 | 
 | 
| 213 |     bool contains(const TupleType& tuple) const {
 | 
| 214 |         return contains(tuple[0], tuple[1]);
 | 
| 215 |     };
 | 
| 216 | 
 | 
| 217 |     void emptyPartition() const {
 | 
| 218 |         // delete the beautiful values inside (they're raw ptrs, so they need to be.)
 | 
| 219 |         for (auto& pair : equivalencePartition) {
 | 
| 220 |             delete pair.second;
 | 
| 221 |         }
 | 
| 222 |         // invalidate it my dude
 | 
| 223 |         this->statesMapStale.store(true, std::memory_order_relaxed);
 | 
| 224 | 
 | 
| 225 |         equivalencePartition.clear();
 | 
| 226 |     }
 | 
| 227 | 
 | 
| 228 |     /**
 | 
| 229 |      * Empty the relation
 | 
| 230 |      */
 | 
| 231 |     void clear() {
 | 
| 232 |         statesLock.lock();
 | 
| 233 | 
 | 
| 234 |         sds.clear();
 | 
| 235 |         emptyPartition();
 | 
| 236 | 
 | 
| 237 |         statesLock.unlock();
 | 
| 238 |     }
 | 
| 239 | 
 | 
| 240 |     /**
 | 
| 241 |      * Size of relation
 | 
| 242 |      * @return the sum of the number of pairs per disjoint set
 | 
| 243 |      */
 | 
| 244 |     std::size_t size() const {
 | 
| 245 |         genAllDisjointSetLists();
 | 
| 246 | 
 | 
| 247 |         statesLock.lock_shared();
 | 
| 248 | 
 | 
| 249 |         std::size_t retVal = 0;
 | 
| 250 |         for (auto& e : this->equivalencePartition) {
 | 
| 251 |             const std::size_t s = e.second->size();
 | 
| 252 |             retVal += s * s;
 | 
| 253 |         }
 | 
| 254 | 
 | 
| 255 |         statesLock.unlock_shared();
 | 
| 256 |         return retVal;
 | 
| 257 |     }
 | 
| 258 | 
 | 
| 259 |     // an almighty iterator for several types of iteration.
 | 
| 260 |     // Unfortunately, subclassing isn't an option with souffle
 | 
| 261 |     //   - we don't deal with pointers (so no virtual)
 | 
| 262 |     //   - and a single iter type is expected (see Relation::iterator e.g.) (i think)
 | 
| 263 |     class iterator {
 | 
| 264 |     public:
 | 
| 265 |         typedef std::forward_iterator_tag iterator_category;
 | 
| 266 |         using value_type = TupleType;
 | 
| 267 |         using difference_type = ptrdiff_t;
 | 
| 268 |         using pointer = value_type*;
 | 
| 269 |         using reference = value_type&;
 | 
| 270 | 
 | 
| 271 |         // one iterator for signalling the end (simplifies)
 | 
| 272 |         explicit iterator(const EquivalenceRelation* br, bool /* signalIsEndIterator */)
 | 
| 273 |                 : br(br), isEndVal(true){};
 | 
| 274 | 
 | 
| 275 |         explicit iterator(const EquivalenceRelation* br)
 | 
| 276 |                 : br(br), ityp(IterType::ALL), djSetMapListIt(br->equivalencePartition.begin()),
 | 
| 277 |                   djSetMapListEnd(br->equivalencePartition.end()) {
 | 
| 278 |             // no need to fast forward if this iterator is empty
 | 
| 279 |             if (djSetMapListIt == djSetMapListEnd) {
 | 
| 280 |                 isEndVal = true;
 | 
| 281 |                 return;
 | 
| 282 |             }
 | 
| 283 |             // grab the pointer to the list, and make it our current list
 | 
| 284 |             djSetList = (*djSetMapListIt).second;
 | 
| 285 |             assert(djSetList->size() != 0);
 | 
| 286 | 
 | 
| 287 |             updateAnterior();
 | 
| 288 |             updatePosterior();
 | 
| 289 |         }
 | 
| 290 | 
 | 
| 291 |         // WITHIN: iterator for everything within the same DJset (used for EquivalenceRelation.partition())
 | 
| 292 |         explicit iterator(const EquivalenceRelation* br, const StatesBucket within)
 | 
| 293 |                 : br(br), ityp(IterType::WITHIN), djSetList(within) {
 | 
| 294 |             // empty dj set
 | 
| 295 |             if (djSetList->size() == 0) {
 | 
| 296 |                 isEndVal = true;
 | 
| 297 |             }
 | 
| 298 | 
 | 
| 299 |             updateAnterior();
 | 
| 300 |             updatePosterior();
 | 
| 301 |         }
 | 
| 302 | 
 | 
| 303 |         // ANTERIOR: iterator that yields all (former, _) \in djset(former) (djset(former) === within)
 | 
| 304 |         explicit iterator(const EquivalenceRelation* br, const typename TupleType::value_type former,
 | 
| 305 |                 const StatesBucket within)
 | 
| 306 |                 : br(br), ityp(IterType::ANTERIOR), djSetList(within) {
 | 
| 307 |             if (djSetList->size() == 0) {
 | 
| 308 |                 isEndVal = true;
 | 
| 309 |             }
 | 
| 310 | 
 | 
| 311 |             setAnterior(former);
 | 
| 312 |             updatePosterior();
 | 
| 313 |         }
 | 
| 314 | 
 | 
| 315 |         // ANTPOST: iterator that yields all (former, latter) \in djset(former), (djset(former) ==
 | 
| 316 |         // djset(latter) == within)
 | 
| 317 |         explicit iterator(const EquivalenceRelation* br, const typename TupleType::value_type former,
 | 
| 318 |                 typename TupleType::value_type latter, const StatesBucket within)
 | 
| 319 |                 : br(br), ityp(IterType::ANTPOST), djSetList(within) {
 | 
| 320 |             if (djSetList->size() == 0) {
 | 
| 321 |                 isEndVal = true;
 | 
| 322 |             }
 | 
| 323 | 
 | 
| 324 |             setAnterior(former);
 | 
| 325 |             setPosterior(latter);
 | 
| 326 |         }
 | 
| 327 | 
 | 
| 328 |         /** explicit set first half of cPair */
 | 
| 329 |         inline void setAnterior(const typename TupleType::value_type a) {
 | 
| 330 |             this->cPair[0] = a;
 | 
| 331 |         }
 | 
| 332 | 
 | 
| 333 |         /** quick update to whatever the current index is pointing to */
 | 
| 334 |         inline void updateAnterior() {
 | 
| 335 |             this->cPair[0] = this->djSetList->get(this->cAnteriorIndex);
 | 
| 336 |         }
 | 
| 337 | 
 | 
| 338 |         /** explicit set second half of cPair */
 | 
| 339 |         inline void setPosterior(const typename TupleType::value_type b) {
 | 
| 340 |             this->cPair[1] = b;
 | 
| 341 |         }
 | 
| 342 | 
 | 
| 343 |         /** quick update to whatever the current index is pointing to */
 | 
| 344 |         inline void updatePosterior() {
 | 
| 345 |             this->cPair[1] = this->djSetList->get(this->cPosteriorIndex);
 | 
| 346 |         }
 | 
| 347 | 
 | 
| 348 |         // copy ctor
 | 
| 349 |         iterator(const iterator& other) = default;
 | 
| 350 |         // move ctor
 | 
| 351 |         iterator(iterator&& other) = default;
 | 
| 352 |         // assign iter
 | 
| 353 |         iterator& operator=(const iterator& other) = default;
 | 
| 354 | 
 | 
| 355 |         bool operator==(const iterator& other) const {
 | 
| 356 |             if (isEndVal && other.isEndVal) return br == other.br;
 | 
| 357 |             return isEndVal == other.isEndVal && cPair == other.cPair;
 | 
| 358 |         }
 | 
| 359 | 
 | 
| 360 |         bool operator!=(const iterator& other) const {
 | 
| 361 |             return !((*this) == other);
 | 
| 362 |         }
 | 
| 363 | 
 | 
| 364 |         const TupleType& operator*() const {
 | 
| 365 |             return cPair;
 | 
| 366 |         }
 | 
| 367 | 
 | 
| 368 |         const TupleType* operator->() const {
 | 
| 369 |             return &cPair;
 | 
| 370 |         }
 | 
| 371 | 
 | 
| 372 |         /* pre-increment */
 | 
| 373 |         iterator& operator++() {
 | 
| 374 |             if (isEndVal) {
 | 
| 375 |                 throw std::out_of_range("error: incrementing an out of range iterator");
 | 
| 376 |             }
 | 
| 377 | 
 | 
| 378 |             switch (ityp) {
 | 
| 379 |                 case IterType::ALL:
 | 
| 380 |                     // move posterior along one
 | 
| 381 |                     // see if we can't move the posterior along
 | 
| 382 |                     if (++cPosteriorIndex == djSetList->size()) {
 | 
| 383 |                         // move anterior along one
 | 
| 384 |                         // see if we can't move the anterior along one
 | 
| 385 |                         if (++cAnteriorIndex == djSetList->size()) {
 | 
| 386 |                             // move the djset it along one
 | 
| 387 |                             // see if we can't move it along one (we're at the end)
 | 
| 388 |                             if (++djSetMapListIt == djSetMapListEnd) {
 | 
| 389 |                                 isEndVal = true;
 | 
| 390 |                                 return *this;
 | 
| 391 |                             }
 | 
| 392 | 
 | 
| 393 |                             // we can't iterate along this djset if it is empty
 | 
| 394 |                             djSetList = (*djSetMapListIt).second;
 | 
| 395 |                             if (djSetList->size() == 0) {
 | 
| 396 |                                 throw std::out_of_range("error: encountered a zero size djset");
 | 
| 397 |                             }
 | 
| 398 | 
 | 
| 399 |                             // update our cAnterior and cPosterior
 | 
| 400 |                             cAnteriorIndex = 0;
 | 
| 401 |                             cPosteriorIndex = 0;
 | 
| 402 |                             updateAnterior();
 | 
| 403 |                             updatePosterior();
 | 
| 404 |                         }
 | 
| 405 | 
 | 
| 406 |                         // we moved our anterior along one
 | 
| 407 |                         updateAnterior();
 | 
| 408 | 
 | 
| 409 |                         cPosteriorIndex = 0;
 | 
| 410 |                         updatePosterior();
 | 
| 411 |                     }
 | 
| 412 |                     // we just moved our posterior along one
 | 
| 413 |                     updatePosterior();
 | 
| 414 | 
 | 
| 415 |                     break;
 | 
| 416 |                 case IterType::ANTERIOR:
 | 
| 417 |                     // step posterior along one, and if we can't, then we're done.
 | 
| 418 |                     if (++cPosteriorIndex == djSetList->size()) {
 | 
| 419 |                         isEndVal = true;
 | 
| 420 |                         return *this;
 | 
| 421 |                     }
 | 
| 422 |                     updatePosterior();
 | 
| 423 | 
 | 
| 424 |                     break;
 | 
| 425 |                 case IterType::ANTPOST:
 | 
| 426 |                     // fixed anterior and posterior literally only points to one, so if we increment, its the
 | 
| 427 |                     // end
 | 
| 428 |                     isEndVal = true;
 | 
| 429 |                     break;
 | 
| 430 |                 case IterType::WITHIN:
 | 
| 431 |                     // move posterior along one
 | 
| 432 |                     // see if we can't move the posterior along
 | 
| 433 |                     if (++cPosteriorIndex == djSetList->size()) {
 | 
| 434 |                         // move anterior along one
 | 
| 435 |                         // see if we can't move the anterior along one
 | 
| 436 |                         if (++cAnteriorIndex == djSetList->size()) {
 | 
| 437 |                             isEndVal = true;
 | 
| 438 |                             return *this;
 | 
| 439 |                         }
 | 
| 440 | 
 | 
| 441 |                         // we moved our anterior along one
 | 
| 442 |                         updateAnterior();
 | 
| 443 | 
 | 
| 444 |                         cPosteriorIndex = 0;
 | 
| 445 |                         updatePosterior();
 | 
| 446 |                     }
 | 
| 447 |                     // we just moved our posterior along one
 | 
| 448 |                     updatePosterior();
 | 
| 449 |                     break;
 | 
| 450 |             }
 | 
| 451 | 
 | 
| 452 |             return *this;
 | 
| 453 |         }
 | 
| 454 | 
 | 
| 455 |     private:
 | 
| 456 |         const EquivalenceRelation* br = nullptr;
 | 
| 457 |         // special tombstone value to notify that this iter represents the end
 | 
| 458 |         bool isEndVal = false;
 | 
| 459 | 
 | 
| 460 |         // all the different types of iterator this can be
 | 
| 461 |         enum IterType { ALL, ANTERIOR, ANTPOST, WITHIN };
 | 
| 462 |         IterType ityp;
 | 
| 463 | 
 | 
| 464 |         TupleType cPair;
 | 
| 465 | 
 | 
| 466 |         // the disjoint set that we're currently iterating through
 | 
| 467 |         StatesBucket djSetList;
 | 
| 468 |         typename StatesMap::iterator djSetMapListIt;
 | 
| 469 |         typename StatesMap::iterator djSetMapListEnd;
 | 
| 470 | 
 | 
| 471 |         // used for ALL, and POSTERIOR (just a current index in the cList)
 | 
| 472 |         std::size_t cAnteriorIndex = 0;
 | 
| 473 |         // used for ALL, and ANTERIOR (just a current index in the cList)
 | 
| 474 |         std::size_t cPosteriorIndex = 0;
 | 
| 475 |     };
 | 
| 476 | 
 | 
| 477 | public:
 | 
| 478 |     /**
 | 
| 479 |      * iterator pointing to the beginning of the tuples, with no restrictions
 | 
| 480 |      * @return the iterator that corresponds to the beginning of the binary relation
 | 
| 481 |      */
 | 
| 482 |     iterator begin() const {
 | 
| 483 |         genAllDisjointSetLists();
 | 
| 484 |         return iterator(this);
 | 
| 485 |     }
 | 
| 486 | 
 | 
| 487 |     /**
 | 
| 488 |      * iterator pointing to the end of the tuples
 | 
| 489 |      * @return the iterator which represents the end of the binary rel
 | 
| 490 |      */
 | 
| 491 |     iterator end() const {
 | 
| 492 |         return iterator(this, true);
 | 
| 493 |     }
 | 
| 494 | 
 | 
| 495 |     /**
 | 
| 496 |      * Obtains a range of elements matching the prefix of the given entry up to
 | 
| 497 |      * levels elements.
 | 
| 498 |      *
 | 
| 499 |      * @tparam levels the length of the requested matching prefix
 | 
| 500 |      * @param entry the entry to be looking for
 | 
| 501 |      * @return the corresponding range of matching elements
 | 
| 502 |      */
 | 
| 503 |     template <unsigned levels>
 | 
| 504 |     range<iterator> getBoundaries(const TupleType& entry) const {
 | 
| 505 |         operation_hints ctxt;
 | 
| 506 |         return getBoundaries<levels>(entry, ctxt);
 | 
| 507 |     }
 | 
| 508 | 
 | 
| 509 |     /**
 | 
| 510 |      * Obtains a range of elements matching the prefix of the given entry up to
 | 
| 511 |      * levels elements. A operation context may be provided to exploit temporal
 | 
| 512 |      * locality.
 | 
| 513 |      *
 | 
| 514 |      * @tparam levels the length of the requested matching prefix
 | 
| 515 |      * @param entry the entry to be looking for
 | 
| 516 |      * @param ctxt the operation context to be utilized
 | 
| 517 |      * @return the corresponding range of matching elements
 | 
| 518 |      */
 | 
| 519 |     template <unsigned levels>
 | 
| 520 |     range<iterator> getBoundaries(const TupleType& entry, operation_hints&) const {
 | 
| 521 |         // if nothing is bound => just use begin and end
 | 
| 522 |         if (levels == 0) return make_range(begin(), end());
 | 
| 523 | 
 | 
| 524 |         // as disjoint set is exactly two args (equiv relation)
 | 
| 525 |         // we only need to handle these cases
 | 
| 526 | 
 | 
| 527 |         if (levels == 1) {
 | 
| 528 |             // need to test if the entry actually exists
 | 
| 529 |             if (!sds.nodeExists(entry[0])) return make_range(end(), end());
 | 
| 530 | 
 | 
| 531 |             // return an iterator over all (entry[0], _)
 | 
| 532 |             return make_range(anteriorIt(entry[0]), end());
 | 
| 533 |         }
 | 
| 534 | 
 | 
| 535 |         if (levels == 2) {
 | 
| 536 |             // need to test if the entry actually exists
 | 
| 537 |             if (!sds.contains(entry[0], entry[1])) return make_range(end(), end());
 | 
| 538 | 
 | 
| 539 |             // if so return an iterator containing exactly that node
 | 
| 540 |             return make_range(antpostit(entry[0], entry[1]), end());
 | 
| 541 |         }
 | 
| 542 | 
 | 
| 543 |         std::cerr << "invalid state, cannot search for >2 arg start point in getBoundaries, in 2 arg tuple "
 | 
| 544 |                      "store\n";
 | 
| 545 |         throw "invalid state, cannot search for >2 arg start point in getBoundaries, in 2 arg tuple store";
 | 
| 546 | 
 | 
| 547 |         return make_range(end(), end());
 | 
| 548 |     }
 | 
| 549 | 
 | 
| 550 |     /**
 | 
| 551 |      * Act similar to getBoundaries. But non-static.
 | 
| 552 |      * This function should be used ONLY by interpreter,
 | 
| 553 |      * and its behavior is tightly coupling with InterpreterIndex.
 | 
| 554 |      * Do Not rely on this interface outside the interpreter.
 | 
| 555 |      *
 | 
| 556 |      * @param entry the entry to be looking for
 | 
| 557 |      * @return the corresponding range of matching elements
 | 
| 558 |      */
 | 
| 559 |     iterator lower_bound(const TupleType& entry, operation_hints&) const {
 | 
| 560 |         if (entry[0] == MIN_RAM_SIGNED && entry[1] == MIN_RAM_SIGNED) {
 | 
| 561 |             // Return an iterator over all tuples.
 | 
| 562 |             return begin();
 | 
| 563 |         }
 | 
| 564 | 
 | 
| 565 |         if (entry[0] != MIN_RAM_SIGNED && entry[1] == MIN_RAM_SIGNED) {
 | 
| 566 |             // Return an iterator over all (entry[0], _)
 | 
| 567 | 
 | 
| 568 |             if (!sds.nodeExists(entry[0])) {
 | 
| 569 |                 return end();
 | 
| 570 |             }
 | 
| 571 |             return anteriorIt(entry[0]);
 | 
| 572 |         }
 | 
| 573 | 
 | 
| 574 |         if (entry[0] != MIN_RAM_SIGNED && entry[1] != MIN_RAM_SIGNED) {
 | 
| 575 |             // Return an iterator point to the exact same node.
 | 
| 576 | 
 | 
| 577 |             if (!sds.contains(entry[0], entry[1])) {
 | 
| 578 |                 return end();
 | 
| 579 |             }
 | 
| 580 |             return antpostit(entry[0], entry[1]);
 | 
| 581 |         }
 | 
| 582 | 
 | 
| 583 |         return end();
 | 
| 584 |     }
 | 
| 585 | 
 | 
| 586 |     iterator lower_bound(const TupleType& entry) const {
 | 
| 587 |         operation_hints hints;
 | 
| 588 |         return lower_bound(entry, hints);
 | 
| 589 |     }
 | 
| 590 | 
 | 
| 591 |     /**
 | 
| 592 |      * This function is only here in order to unify interfaces in InterpreterIndex.
 | 
| 593 |      * Unlike the name suggestes, it omit the arguments and simply return the end
 | 
| 594 |      * iterator of the relation.
 | 
| 595 |      *
 | 
| 596 |      * @param omitted
 | 
| 597 |      * @return the end iterator.
 | 
| 598 |      */
 | 
| 599 |     iterator upper_bound(const TupleType&, operation_hints&) const {
 | 
| 600 |         return end();
 | 
| 601 |     }
 | 
| 602 | 
 | 
| 603 |     iterator upper_bound(const TupleType& entry) const {
 | 
| 604 |         operation_hints hints;
 | 
| 605 |         return upper_bound(entry, hints);
 | 
| 606 |     }
 | 
| 607 | 
 | 
| 608 |     /**
 | 
| 609 |      * Check emptiness.
 | 
| 610 |      */
 | 
| 611 |     bool empty() const {
 | 
| 612 |         return this->size() == 0;
 | 
| 613 |     }
 | 
| 614 | 
 | 
| 615 |     /**
 | 
| 616 |      * Creates an iterator that generates all pairs (A, X)
 | 
| 617 |      * for a given A, and X are elements within A's disjoint set.
 | 
| 618 |      * @param anteriorVal: The first value of the tuple to be generated for
 | 
| 619 |      * @return the iterator representing this.
 | 
| 620 |      */
 | 
| 621 |     iterator anteriorIt(value_type anteriorVal) const {
 | 
| 622 |         genAllDisjointSetLists();
 | 
| 623 | 
 | 
| 624 |         // locate the blocklist that the anterior val resides in
 | 
| 625 |         auto found = equivalencePartition.find({sds.findNode(anteriorVal), nullptr});
 | 
| 626 |         assert(found != equivalencePartition.end() && "iterator called on partition that doesn't exist");
 | 
| 627 | 
 | 
| 628 |         return iterator(static_cast<const EquivalenceRelation*>(this),
 | 
| 629 |                 static_cast<const value_type>(anteriorVal), static_cast<const StatesBucket>((*found).second));
 | 
| 630 |     }
 | 
| 631 | 
 | 
| 632 |     /**
 | 
| 633 |      * Creates an iterator that generates the pair (A, B)
 | 
| 634 |      * for a given A and B. If A and B don't exist, or aren't in the same set,
 | 
| 635 |      * then the end() iterator is returned.
 | 
| 636 |      * @param anteriorVal: the A value of the tuple
 | 
| 637 |      * @param posteriorVal: the B value of the tuple
 | 
| 638 |      * @return the iterator representing this
 | 
| 639 |      */
 | 
| 640 |     iterator antpostit(value_type anteriorVal, value_type posteriorVal) const {
 | 
| 641 |         // obv if they're in diff sets, then iteration for this pair just ends.
 | 
| 642 |         if (!sds.sameSet(anteriorVal, posteriorVal)) return end();
 | 
| 643 | 
 | 
| 644 |         genAllDisjointSetLists();
 | 
| 645 | 
 | 
| 646 |         // locate the blocklist that the val resides in
 | 
| 647 |         auto found = equivalencePartition.find({sds.findNode(posteriorVal), nullptr});
 | 
| 648 |         assert(found != equivalencePartition.end() && "iterator called on partition that doesn't exist");
 | 
| 649 | 
 | 
| 650 |         return iterator(this, anteriorVal, posteriorVal, (*found).second);
 | 
| 651 |     }
 | 
| 652 | 
 | 
| 653 |     /**
 | 
| 654 |      * Begin an iterator over all pairs within a single disjoint set - This is used for partition().
 | 
| 655 |      * @param rep the representative of (or element within) a disjoint set of which to generate all pairs
 | 
| 656 |      * @return an iterator that will generate all pairs within the disjoint set
 | 
| 657 |      */
 | 
| 658 |     iterator closure(value_type rep) const {
 | 
| 659 |         genAllDisjointSetLists();
 | 
| 660 | 
 | 
| 661 |         // locate the blocklist that the val resides in
 | 
| 662 |         auto found = equivalencePartition.find({sds.findNode(rep), nullptr});
 | 
| 663 |         return iterator(this, (*found).second);
 | 
| 664 |     }
 | 
| 665 | 
 | 
| 666 |     /**
 | 
| 667 |      * Generate an approximate number of iterators for parallel iteration
 | 
| 668 |      * The iterators returned are not necessarily equal in size, but in practise are approximately similarly
 | 
| 669 |      * sized
 | 
| 670 |      * Depending on the structure of the data, there can be more or less partitions returned than requested.
 | 
| 671 |      * @param chunks the number of requested partitions
 | 
| 672 |      * @return a list of the iterators as ranges
 | 
| 673 |      */
 | 
| 674 |     std::vector<souffle::range<iterator>> partition(std::size_t chunks) const {
 | 
| 675 |         // generate all reps
 | 
| 676 |         genAllDisjointSetLists();
 | 
| 677 | 
 | 
| 678 |         std::size_t numPairs = this->size();
 | 
| 679 |         if (numPairs == 0) return {};
 | 
| 680 |         if (numPairs == 1 || chunks <= 1) return {souffle::make_range(begin(), end())};
 | 
| 681 | 
 | 
| 682 |         // if there's more dj sets than requested chunks, then just return an iter per dj set
 | 
| 683 |         std::vector<souffle::range<iterator>> ret;
 | 
| 684 |         if (chunks <= equivalencePartition.size()) {
 | 
| 685 |             for (auto& p : equivalencePartition) {
 | 
| 686 |                 ret.push_back(souffle::make_range(closure(p.first), end()));
 | 
| 687 |             }
 | 
| 688 |             return ret;
 | 
| 689 |         }
 | 
| 690 | 
 | 
| 691 |         // keep it simple stupid
 | 
| 692 |         // just go through and if the size of the binrel is > numpairs/chunks, then generate an anteriorIt for
 | 
| 693 |         // each
 | 
| 694 |         const std::size_t perchunk = numPairs / chunks;
 | 
| 695 |         for (const auto& itp : equivalencePartition) {
 | 
| 696 |             const std::size_t s = itp.second->size();
 | 
| 697 |             if (s * s > perchunk) {
 | 
| 698 |                 for (const auto& i : *itp.second) {
 | 
| 699 |                     ret.push_back(souffle::make_range(anteriorIt(i), end()));
 | 
| 700 |                 }
 | 
| 701 |             } else {
 | 
| 702 |                 ret.push_back(souffle::make_range(closure(itp.first), end()));
 | 
| 703 |             }
 | 
| 704 |         }
 | 
| 705 | 
 | 
| 706 |         return ret;
 | 
| 707 |     }
 | 
| 708 | 
 | 
| 709 |     iterator find(const TupleType&, operation_hints&) const {
 | 
| 710 |         throw std::runtime_error("error: find() is not compatible with equivalence relations");
 | 
| 711 |         return begin();
 | 
| 712 |     }
 | 
| 713 | 
 | 
| 714 |     iterator find(const TupleType& t) const {
 | 
| 715 |         operation_hints context;
 | 
| 716 |         return find(t, context);
 | 
| 717 |     }
 | 
| 718 | 
 | 
| 719 | protected:
 | 
| 720 |     bool containsElement(value_type e) const {
 | 
| 721 |         return this->sds.nodeExists(e);
 | 
| 722 |     }
 | 
| 723 | 
 | 
| 724 | private:
 | 
// marked as mutable due to difficulties with the const enforcement via the Relation API
// const operations *may* safely change internal state (i.e. collapse djset forest)
mutable souffle::SparseDisjointSet<value_type> sds;

// read/write lock on equivalencePartition; genAllDisjointSetLists() takes it exclusively
// while checking the stale flag and rebuilding the cache
mutable std::shared_mutex statesLock;

// cache: representative value -> PiggyList of the members of that representative's set,
// rebuilt by genAllDisjointSetLists() when statesMapStale is set
mutable StatesMap equivalencePartition;
// whether the cache is stale; genAllDisjointSetLists() clears it (release) after a rebuild
mutable std::atomic<bool> statesMapStale;
 | 
| 735 | 
 | 
| 736 |     /**
 | 
| 737 |      * Generate a cache of the sets such that they can be iterated over efficiently.
 | 
| 738 |      * Each set is partitioned into a PiggyList.
 | 
| 739 |      */
 | 
| 740 |     void genAllDisjointSetLists() const {
 | 
| 741 |         statesLock.lock();
 | 
| 742 | 
 | 
| 743 |         // no need to generate again, already done.
 | 
| 744 |         if (!this->statesMapStale.load(std::memory_order_acquire)) {
 | 
| 745 |             statesLock.unlock();
 | 
| 746 |             return;
 | 
| 747 |         }
 | 
| 748 | 
 | 
| 749 |         // btree version
 | 
| 750 |         emptyPartition();
 | 
| 751 | 
 | 
| 752 |         std::size_t dSetSize = this->sds.ds.a_blocks.size();
 | 
| 753 |         for (std::size_t i = 0; i < dSetSize; ++i) {
 | 
| 754 |             typename TupleType::value_type sparseVal = this->sds.toSparse(i);
 | 
| 755 |             parent_t rep = this->sds.findNode(sparseVal);
 | 
| 756 | 
 | 
| 757 |             StorePair p = {static_cast<value_type>(rep), nullptr};
 | 
| 758 |             StatesList* mapList = equivalencePartition.insert(p, [&](StorePair& sp) {
 | 
| 759 |                 auto* r = new StatesList(1);
 | 
| 760 |                 sp.second = r;
 | 
| 761 |                 return r;
 | 
| 762 |             });
 | 
| 763 |             mapList->append(sparseVal);
 | 
| 764 |         }
 | 
| 765 | 
 | 
| 766 |         statesMapStale.store(false, std::memory_order_release);
 | 
| 767 |         statesLock.unlock();
 | 
| 768 |     }
 | 
| 769 | };
 | 
| 770 | }  // namespace souffle
 |