1 | /*
|
2 | * Souffle - A Datalog Compiler
|
3 | * Copyright (c) 2017 The Souffle Developers. All rights reserved
|
4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
5 | * - https://opensource.org/licenses/UPL
|
6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
7 | */
|
8 |
|
9 | /************************************************************************
|
10 | *
|
11 | * @file EquivalenceRelation.h
|
12 | *
|
13 | * Defines a binary relation interface to be used with Souffle as a relational store.
|
14 | * Pairs inserted into this relation implicitly store a reflexive, symmetric, and transitive relation
|
15 | * with each other.
|
16 | *
|
17 | ***********************************************************************/
|
18 |
|
19 | #pragma once
|
20 |
|
#include "souffle/RamTypes.h"
#include "souffle/datastructure/LambdaBTree.h"
#include "souffle/datastructure/PiggyList.h"
#include "souffle/datastructure/UnionFind.h"
#include "souffle/utility/ContainerUtil.h"
#include "souffle/utility/ParallelUtil.h"
#include <atomic>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
|
40 |
|
41 | namespace souffle {
|
42 | template <typename TupleType>
|
43 | class EquivalenceRelation {
|
44 | using value_type = typename TupleType::value_type;
|
45 |
|
// Mapping from each set representative to the list of elements in its disjoint set.
// This is only a cache, used so that the relation can be iterated over.
|
48 | using StatesList = souffle::PiggyList<value_type>;
|
49 | using StatesBucket = StatesList*;
|
50 | using StorePair = std::pair<value_type, StatesBucket>;
|
51 | using StatesMap = souffle::LambdaBTreeSet<StorePair, std::function<StatesBucket(StorePair&)>,
|
52 | souffle::EqrelMapComparator<StorePair>>;
|
53 |
|
54 | public:
|
55 | using element_type = TupleType;
|
56 |
|
57 | EquivalenceRelation() : statesMapStale(false){};
|
58 | ~EquivalenceRelation() {
|
59 | emptyPartition();
|
60 | }
|
61 |
|
/**
 * Operation hints are meant to speed up operations by exploiting temporal
 * locality. This class keeps no hint state at all; the type only exists
 * because the hint-taking overloads of this class require it.
 */
struct operation_hints {
    /** Resets all hints (e.g. after nodes are deleted); nothing to do as no state is kept. */
    void clear() {}
};
|
72 |
|
73 | /**
|
74 | * Insert the two values symbolically as a binary relation
|
75 | * @param x node to be added/paired
|
76 | * @param y node to be added/paired
|
77 | * @return true if the pair is new to the data structure
|
78 | */
|
79 | bool insert(value_type x, value_type y) {
|
80 | operation_hints z;
|
81 | return insert(x, y, z);
|
82 | };
|
83 |
|
84 | /**
|
85 | * Insert the tuple symbolically.
|
86 | * @param tuple The tuple to be inserted
|
87 | * @return true if the tuple is new to the data structure
|
88 | */
|
89 | bool insert(const TupleType& tuple) {
|
90 | operation_hints hints;
|
91 | return insert(tuple[0], tuple[1], hints);
|
92 | };
|
93 |
|
94 | /**
|
95 | * Insert the two values symbolically as a binary relation
|
96 | * @param x node to be added/paired
|
97 | * @param y node to be added/paired
|
98 | * @param z the hints to where the pair should be inserted (not applicable atm)
|
99 | * @return true if the pair is new to the data structure
|
100 | */
|
101 | bool insert(value_type x, value_type y, operation_hints) {
|
102 | // indicate that iterators will have to generate on request
|
103 | this->statesMapStale.store(true, std::memory_order_relaxed);
|
104 | bool retval = !contains(x, y);
|
105 | sds.unionNodes(x, y);
|
106 | return retval;
|
107 | }
|
108 |
|
109 | /**
|
110 | * inserts all nodes from the other relation into this one
|
111 | * @param other the binary relation from which to add elements from
|
112 | */
|
113 | void insertAll(const EquivalenceRelation<TupleType>& other) {
|
114 | other.genAllDisjointSetLists();
|
115 |
|
116 | // iterate over partitions at a time
|
117 | for (auto&& [rep, pl] : other.equivalencePartition) {
|
118 | const std::size_t ksize = pl->size();
|
119 | for (std::size_t i = 0; i < ksize; ++i) {
|
120 | this->sds.unionNodes(rep, pl->get(i));
|
121 | }
|
122 | }
|
123 | // invalidate iterators unconditionally
|
124 | this->statesMapStale.store(true, std::memory_order_relaxed);
|
125 | }
|
126 |
|
127 | /**
|
128 | * Extend this relation with another relation, expanding this equivalence
|
129 | * relation and inserting it into the other relation.
|
130 | *
|
131 | * The supplied relation is the old knowledge, whilst this relation only
|
132 | * contains explicitly new knowledge. After this operation the "implicitly
|
133 | * new tuples" are now explicitly inserted this relation, and all of the new
|
134 | * tuples in this relation are inserted into the old relation.
|
135 | */
|
136 | void extendAndInsert(EquivalenceRelation<TupleType>& other) {
|
137 | if (other.size() == 0 && this->size() == 0) return;
|
138 |
|
139 | std::unordered_set<value_type> repsCovered;
|
140 |
|
141 | // This vector holds all of the elements of this equivalence relation
|
142 | // that aren't yet in other, which get inserted after extending this
|
143 | // relation by other. These operations are interleaved for maximum
|
144 | // efficiency - either extend or inserting first would make the other
|
145 | // operation unnecessarily slow.
|
146 | std::vector<std::pair<value_type, value_type>> toInsert;
|
147 | auto size = std::distance(this->sds.sparseToDenseMap.begin(), this->sds.sparseToDenseMap.end());
|
148 | toInsert.reserve(size);
|
149 |
|
150 | // find all the disjoint sets that need to be added to this relation
|
151 | // that exist in other (and exist in this)
|
152 | {
|
153 | auto it = this->sds.sparseToDenseMap.begin();
|
154 | auto end = this->sds.sparseToDenseMap.end();
|
155 | value_type el;
|
156 | for (; it != end; ++it) {
|
157 | std::tie(el, std::ignore) = *it;
|
158 | if (other.containsElement(el)) {
|
159 | value_type rep = other.sds.findNode(el);
|
160 | if (repsCovered.count(rep) == 0) {
|
161 | repsCovered.emplace(rep);
|
162 | }
|
163 | }
|
164 | toInsert.emplace_back(el, this->sds.findNode(el));
|
165 | }
|
166 | }
|
167 | assert(size >= 0);
|
168 | assert(toInsert.size() == (std::size_t)size);
|
169 |
|
170 | // add the intersecting dj sets into this one
|
171 | {
|
172 | value_type el;
|
173 | value_type rep;
|
174 | auto it = other.sds.sparseToDenseMap.begin();
|
175 | auto end = other.sds.sparseToDenseMap.end();
|
176 | for (; it != end; ++it) {
|
177 | std::tie(el, std::ignore) = *it;
|
178 | rep = other.sds.findNode(el);
|
179 | if (repsCovered.count(rep) != 0) {
|
180 | this->insert(el, rep);
|
181 | }
|
182 | }
|
183 | }
|
184 |
|
185 | // Insert all new tuples from this relation into the old relation
|
186 | {
|
187 | value_type el;
|
188 | value_type rep;
|
189 | for (std::pair<value_type, value_type> p : toInsert) {
|
190 | std::tie(el, rep) = p;
|
191 | other.insert(el, rep);
|
192 | }
|
193 | }
|
194 | }
|
195 |
|
196 | /**
|
197 | * Returns whether there exists a pair with these two nodes
|
198 | * @param x front of pair
|
199 | * @param y back of pair
|
200 | */
|
201 | bool contains(value_type x, value_type y) const {
|
202 | return sds.contains(x, y);
|
203 | }
|
204 |
|
205 | /**
|
206 | * Returns whether there exists given tuple.
|
207 | * @param tuple The tuple to search for.
|
208 | */
|
209 | bool contains(const TupleType& tuple, operation_hints&) const {
|
210 | return contains(tuple[0], tuple[1]);
|
211 | };
|
212 |
|
213 | bool contains(const TupleType& tuple) const {
|
214 | return contains(tuple[0], tuple[1]);
|
215 | };
|
216 |
|
217 | void emptyPartition() const {
|
218 | // delete the beautiful values inside (they're raw ptrs, so they need to be.)
|
219 | for (auto& pair : equivalencePartition) {
|
220 | delete pair.second;
|
221 | }
|
222 | // invalidate it my dude
|
223 | this->statesMapStale.store(true, std::memory_order_relaxed);
|
224 |
|
225 | equivalencePartition.clear();
|
226 | }
|
227 |
|
228 | /**
|
229 | * Empty the relation
|
230 | */
|
231 | void clear() {
|
232 | statesLock.lock();
|
233 |
|
234 | sds.clear();
|
235 | emptyPartition();
|
236 |
|
237 | statesLock.unlock();
|
238 | }
|
239 |
|
240 | /**
|
241 | * Size of relation
|
242 | * @return the sum of the number of pairs per disjoint set
|
243 | */
|
244 | std::size_t size() const {
|
245 | genAllDisjointSetLists();
|
246 |
|
247 | statesLock.lock_shared();
|
248 |
|
249 | std::size_t retVal = 0;
|
250 | for (auto& e : this->equivalencePartition) {
|
251 | const std::size_t s = e.second->size();
|
252 | retVal += s * s;
|
253 | }
|
254 |
|
255 | statesLock.unlock_shared();
|
256 | return retVal;
|
257 | }
|
258 |
|
259 | // an almighty iterator for several types of iteration.
|
260 | // Unfortunately, subclassing isn't an option with souffle
|
261 | // - we don't deal with pointers (so no virtual)
|
262 | // - and a single iter type is expected (see Relation::iterator e.g.) (i think)
|
263 | class iterator {
|
264 | public:
|
265 | typedef std::forward_iterator_tag iterator_category;
|
266 | using value_type = TupleType;
|
267 | using difference_type = ptrdiff_t;
|
268 | using pointer = value_type*;
|
269 | using reference = value_type&;
|
270 |
|
271 | // one iterator for signalling the end (simplifies)
|
272 | explicit iterator(const EquivalenceRelation* br, bool /* signalIsEndIterator */)
|
273 | : br(br), isEndVal(true){};
|
274 |
|
275 | explicit iterator(const EquivalenceRelation* br)
|
276 | : br(br), ityp(IterType::ALL), djSetMapListIt(br->equivalencePartition.begin()),
|
277 | djSetMapListEnd(br->equivalencePartition.end()) {
|
278 | // no need to fast forward if this iterator is empty
|
279 | if (djSetMapListIt == djSetMapListEnd) {
|
280 | isEndVal = true;
|
281 | return;
|
282 | }
|
283 | // grab the pointer to the list, and make it our current list
|
284 | djSetList = (*djSetMapListIt).second;
|
285 | assert(djSetList->size() != 0);
|
286 |
|
287 | updateAnterior();
|
288 | updatePosterior();
|
289 | }
|
290 |
|
291 | // WITHIN: iterator for everything within the same DJset (used for EquivalenceRelation.partition())
|
292 | explicit iterator(const EquivalenceRelation* br, const StatesBucket within)
|
293 | : br(br), ityp(IterType::WITHIN), djSetList(within) {
|
294 | // empty dj set
|
295 | if (djSetList->size() == 0) {
|
296 | isEndVal = true;
|
297 | }
|
298 |
|
299 | updateAnterior();
|
300 | updatePosterior();
|
301 | }
|
302 |
|
303 | // ANTERIOR: iterator that yields all (former, _) \in djset(former) (djset(former) === within)
|
304 | explicit iterator(const EquivalenceRelation* br, const typename TupleType::value_type former,
|
305 | const StatesBucket within)
|
306 | : br(br), ityp(IterType::ANTERIOR), djSetList(within) {
|
307 | if (djSetList->size() == 0) {
|
308 | isEndVal = true;
|
309 | }
|
310 |
|
311 | setAnterior(former);
|
312 | updatePosterior();
|
313 | }
|
314 |
|
315 | // ANTPOST: iterator that yields all (former, latter) \in djset(former), (djset(former) ==
|
316 | // djset(latter) == within)
|
317 | explicit iterator(const EquivalenceRelation* br, const typename TupleType::value_type former,
|
318 | typename TupleType::value_type latter, const StatesBucket within)
|
319 | : br(br), ityp(IterType::ANTPOST), djSetList(within) {
|
320 | if (djSetList->size() == 0) {
|
321 | isEndVal = true;
|
322 | }
|
323 |
|
324 | setAnterior(former);
|
325 | setPosterior(latter);
|
326 | }
|
327 |
|
328 | /** explicit set first half of cPair */
|
329 | inline void setAnterior(const typename TupleType::value_type a) {
|
330 | this->cPair[0] = a;
|
331 | }
|
332 |
|
333 | /** quick update to whatever the current index is pointing to */
|
334 | inline void updateAnterior() {
|
335 | this->cPair[0] = this->djSetList->get(this->cAnteriorIndex);
|
336 | }
|
337 |
|
338 | /** explicit set second half of cPair */
|
339 | inline void setPosterior(const typename TupleType::value_type b) {
|
340 | this->cPair[1] = b;
|
341 | }
|
342 |
|
343 | /** quick update to whatever the current index is pointing to */
|
344 | inline void updatePosterior() {
|
345 | this->cPair[1] = this->djSetList->get(this->cPosteriorIndex);
|
346 | }
|
347 |
|
348 | // copy ctor
|
349 | iterator(const iterator& other) = default;
|
350 | // move ctor
|
351 | iterator(iterator&& other) = default;
|
352 | // assign iter
|
353 | iterator& operator=(const iterator& other) = default;
|
354 |
|
355 | bool operator==(const iterator& other) const {
|
356 | if (isEndVal && other.isEndVal) return br == other.br;
|
357 | return isEndVal == other.isEndVal && cPair == other.cPair;
|
358 | }
|
359 |
|
360 | bool operator!=(const iterator& other) const {
|
361 | return !((*this) == other);
|
362 | }
|
363 |
|
364 | const TupleType& operator*() const {
|
365 | return cPair;
|
366 | }
|
367 |
|
368 | const TupleType* operator->() const {
|
369 | return &cPair;
|
370 | }
|
371 |
|
372 | /* pre-increment */
|
373 | iterator& operator++() {
|
374 | if (isEndVal) {
|
375 | throw std::out_of_range("error: incrementing an out of range iterator");
|
376 | }
|
377 |
|
378 | switch (ityp) {
|
379 | case IterType::ALL:
|
380 | // move posterior along one
|
381 | // see if we can't move the posterior along
|
382 | if (++cPosteriorIndex == djSetList->size()) {
|
383 | // move anterior along one
|
384 | // see if we can't move the anterior along one
|
385 | if (++cAnteriorIndex == djSetList->size()) {
|
386 | // move the djset it along one
|
387 | // see if we can't move it along one (we're at the end)
|
388 | if (++djSetMapListIt == djSetMapListEnd) {
|
389 | isEndVal = true;
|
390 | return *this;
|
391 | }
|
392 |
|
393 | // we can't iterate along this djset if it is empty
|
394 | djSetList = (*djSetMapListIt).second;
|
395 | if (djSetList->size() == 0) {
|
396 | throw std::out_of_range("error: encountered a zero size djset");
|
397 | }
|
398 |
|
399 | // update our cAnterior and cPosterior
|
400 | cAnteriorIndex = 0;
|
401 | cPosteriorIndex = 0;
|
402 | updateAnterior();
|
403 | updatePosterior();
|
404 | }
|
405 |
|
406 | // we moved our anterior along one
|
407 | updateAnterior();
|
408 |
|
409 | cPosteriorIndex = 0;
|
410 | updatePosterior();
|
411 | }
|
412 | // we just moved our posterior along one
|
413 | updatePosterior();
|
414 |
|
415 | break;
|
416 | case IterType::ANTERIOR:
|
417 | // step posterior along one, and if we can't, then we're done.
|
418 | if (++cPosteriorIndex == djSetList->size()) {
|
419 | isEndVal = true;
|
420 | return *this;
|
421 | }
|
422 | updatePosterior();
|
423 |
|
424 | break;
|
425 | case IterType::ANTPOST:
|
426 | // fixed anterior and posterior literally only points to one, so if we increment, its the
|
427 | // end
|
428 | isEndVal = true;
|
429 | break;
|
430 | case IterType::WITHIN:
|
431 | // move posterior along one
|
432 | // see if we can't move the posterior along
|
433 | if (++cPosteriorIndex == djSetList->size()) {
|
434 | // move anterior along one
|
435 | // see if we can't move the anterior along one
|
436 | if (++cAnteriorIndex == djSetList->size()) {
|
437 | isEndVal = true;
|
438 | return *this;
|
439 | }
|
440 |
|
441 | // we moved our anterior along one
|
442 | updateAnterior();
|
443 |
|
444 | cPosteriorIndex = 0;
|
445 | updatePosterior();
|
446 | }
|
447 | // we just moved our posterior along one
|
448 | updatePosterior();
|
449 | break;
|
450 | }
|
451 |
|
452 | return *this;
|
453 | }
|
454 |
|
455 | private:
|
456 | const EquivalenceRelation* br = nullptr;
|
457 | // special tombstone value to notify that this iter represents the end
|
458 | bool isEndVal = false;
|
459 |
|
460 | // all the different types of iterator this can be
|
461 | enum IterType { ALL, ANTERIOR, ANTPOST, WITHIN };
|
462 | IterType ityp;
|
463 |
|
464 | TupleType cPair;
|
465 |
|
466 | // the disjoint set that we're currently iterating through
|
467 | StatesBucket djSetList;
|
468 | typename StatesMap::iterator djSetMapListIt;
|
469 | typename StatesMap::iterator djSetMapListEnd;
|
470 |
|
471 | // used for ALL, and POSTERIOR (just a current index in the cList)
|
472 | std::size_t cAnteriorIndex = 0;
|
473 | // used for ALL, and ANTERIOR (just a current index in the cList)
|
474 | std::size_t cPosteriorIndex = 0;
|
475 | };
|
476 |
|
477 | public:
|
478 | /**
|
479 | * iterator pointing to the beginning of the tuples, with no restrictions
|
480 | * @return the iterator that corresponds to the beginning of the binary relation
|
481 | */
|
482 | iterator begin() const {
|
483 | genAllDisjointSetLists();
|
484 | return iterator(this);
|
485 | }
|
486 |
|
487 | /**
|
488 | * iterator pointing to the end of the tuples
|
489 | * @return the iterator which represents the end of the binary rel
|
490 | */
|
491 | iterator end() const {
|
492 | return iterator(this, true);
|
493 | }
|
494 |
|
495 | /**
|
496 | * Obtains a range of elements matching the prefix of the given entry up to
|
497 | * levels elements.
|
498 | *
|
499 | * @tparam levels the length of the requested matching prefix
|
500 | * @param entry the entry to be looking for
|
501 | * @return the corresponding range of matching elements
|
502 | */
|
503 | template <unsigned levels>
|
504 | range<iterator> getBoundaries(const TupleType& entry) const {
|
505 | operation_hints ctxt;
|
506 | return getBoundaries<levels>(entry, ctxt);
|
507 | }
|
508 |
|
509 | /**
|
510 | * Obtains a range of elements matching the prefix of the given entry up to
|
511 | * levels elements. A operation context may be provided to exploit temporal
|
512 | * locality.
|
513 | *
|
514 | * @tparam levels the length of the requested matching prefix
|
515 | * @param entry the entry to be looking for
|
516 | * @param ctxt the operation context to be utilized
|
517 | * @return the corresponding range of matching elements
|
518 | */
|
519 | template <unsigned levels>
|
520 | range<iterator> getBoundaries(const TupleType& entry, operation_hints&) const {
|
521 | // if nothing is bound => just use begin and end
|
522 | if (levels == 0) return make_range(begin(), end());
|
523 |
|
524 | // as disjoint set is exactly two args (equiv relation)
|
525 | // we only need to handle these cases
|
526 |
|
527 | if (levels == 1) {
|
528 | // need to test if the entry actually exists
|
529 | if (!sds.nodeExists(entry[0])) return make_range(end(), end());
|
530 |
|
531 | // return an iterator over all (entry[0], _)
|
532 | return make_range(anteriorIt(entry[0]), end());
|
533 | }
|
534 |
|
535 | if (levels == 2) {
|
536 | // need to test if the entry actually exists
|
537 | if (!sds.contains(entry[0], entry[1])) return make_range(end(), end());
|
538 |
|
539 | // if so return an iterator containing exactly that node
|
540 | return make_range(antpostit(entry[0], entry[1]), end());
|
541 | }
|
542 |
|
543 | std::cerr << "invalid state, cannot search for >2 arg start point in getBoundaries, in 2 arg tuple "
|
544 | "store\n";
|
545 | throw "invalid state, cannot search for >2 arg start point in getBoundaries, in 2 arg tuple store";
|
546 |
|
547 | return make_range(end(), end());
|
548 | }
|
549 |
|
550 | /**
|
551 | * Act similar to getBoundaries. But non-static.
|
552 | * This function should be used ONLY by interpreter,
|
553 | * and its behavior is tightly coupling with InterpreterIndex.
|
554 | * Do Not rely on this interface outside the interpreter.
|
555 | *
|
556 | * @param entry the entry to be looking for
|
557 | * @return the corresponding range of matching elements
|
558 | */
|
559 | iterator lower_bound(const TupleType& entry, operation_hints&) const {
|
560 | if (entry[0] == MIN_RAM_SIGNED && entry[1] == MIN_RAM_SIGNED) {
|
561 | // Return an iterator over all tuples.
|
562 | return begin();
|
563 | }
|
564 |
|
565 | if (entry[0] != MIN_RAM_SIGNED && entry[1] == MIN_RAM_SIGNED) {
|
566 | // Return an iterator over all (entry[0], _)
|
567 |
|
568 | if (!sds.nodeExists(entry[0])) {
|
569 | return end();
|
570 | }
|
571 | return anteriorIt(entry[0]);
|
572 | }
|
573 |
|
574 | if (entry[0] != MIN_RAM_SIGNED && entry[1] != MIN_RAM_SIGNED) {
|
575 | // Return an iterator point to the exact same node.
|
576 |
|
577 | if (!sds.contains(entry[0], entry[1])) {
|
578 | return end();
|
579 | }
|
580 | return antpostit(entry[0], entry[1]);
|
581 | }
|
582 |
|
583 | return end();
|
584 | }
|
585 |
|
586 | iterator lower_bound(const TupleType& entry) const {
|
587 | operation_hints hints;
|
588 | return lower_bound(entry, hints);
|
589 | }
|
590 |
|
591 | /**
|
592 | * This function is only here in order to unify interfaces in InterpreterIndex.
|
593 | * Unlike the name suggestes, it omit the arguments and simply return the end
|
594 | * iterator of the relation.
|
595 | *
|
596 | * @param omitted
|
597 | * @return the end iterator.
|
598 | */
|
599 | iterator upper_bound(const TupleType&, operation_hints&) const {
|
600 | return end();
|
601 | }
|
602 |
|
603 | iterator upper_bound(const TupleType& entry) const {
|
604 | operation_hints hints;
|
605 | return upper_bound(entry, hints);
|
606 | }
|
607 |
|
608 | /**
|
609 | * Check emptiness.
|
610 | */
|
611 | bool empty() const {
|
612 | return this->size() == 0;
|
613 | }
|
614 |
|
615 | /**
|
616 | * Creates an iterator that generates all pairs (A, X)
|
617 | * for a given A, and X are elements within A's disjoint set.
|
618 | * @param anteriorVal: The first value of the tuple to be generated for
|
619 | * @return the iterator representing this.
|
620 | */
|
621 | iterator anteriorIt(value_type anteriorVal) const {
|
622 | genAllDisjointSetLists();
|
623 |
|
624 | // locate the blocklist that the anterior val resides in
|
625 | auto found = equivalencePartition.find({sds.findNode(anteriorVal), nullptr});
|
626 | assert(found != equivalencePartition.end() && "iterator called on partition that doesn't exist");
|
627 |
|
628 | return iterator(static_cast<const EquivalenceRelation*>(this),
|
629 | static_cast<const value_type>(anteriorVal), static_cast<const StatesBucket>((*found).second));
|
630 | }
|
631 |
|
632 | /**
|
633 | * Creates an iterator that generates the pair (A, B)
|
634 | * for a given A and B. If A and B don't exist, or aren't in the same set,
|
635 | * then the end() iterator is returned.
|
636 | * @param anteriorVal: the A value of the tuple
|
637 | * @param posteriorVal: the B value of the tuple
|
638 | * @return the iterator representing this
|
639 | */
|
640 | iterator antpostit(value_type anteriorVal, value_type posteriorVal) const {
|
641 | // obv if they're in diff sets, then iteration for this pair just ends.
|
642 | if (!sds.sameSet(anteriorVal, posteriorVal)) return end();
|
643 |
|
644 | genAllDisjointSetLists();
|
645 |
|
646 | // locate the blocklist that the val resides in
|
647 | auto found = equivalencePartition.find({sds.findNode(posteriorVal), nullptr});
|
648 | assert(found != equivalencePartition.end() && "iterator called on partition that doesn't exist");
|
649 |
|
650 | return iterator(this, anteriorVal, posteriorVal, (*found).second);
|
651 | }
|
652 |
|
653 | /**
|
654 | * Begin an iterator over all pairs within a single disjoint set - This is used for partition().
|
655 | * @param rep the representative of (or element within) a disjoint set of which to generate all pairs
|
656 | * @return an iterator that will generate all pairs within the disjoint set
|
657 | */
|
658 | iterator closure(value_type rep) const {
|
659 | genAllDisjointSetLists();
|
660 |
|
661 | // locate the blocklist that the val resides in
|
662 | auto found = equivalencePartition.find({sds.findNode(rep), nullptr});
|
663 | return iterator(this, (*found).second);
|
664 | }
|
665 |
|
666 | /**
|
667 | * Generate an approximate number of iterators for parallel iteration
|
668 | * The iterators returned are not necessarily equal in size, but in practise are approximately similarly
|
669 | * sized
|
670 | * Depending on the structure of the data, there can be more or less partitions returned than requested.
|
671 | * @param chunks the number of requested partitions
|
672 | * @return a list of the iterators as ranges
|
673 | */
|
674 | std::vector<souffle::range<iterator>> partition(std::size_t chunks) const {
|
675 | // generate all reps
|
676 | genAllDisjointSetLists();
|
677 |
|
678 | std::size_t numPairs = this->size();
|
679 | if (numPairs == 0) return {};
|
680 | if (numPairs == 1 || chunks <= 1) return {souffle::make_range(begin(), end())};
|
681 |
|
682 | // if there's more dj sets than requested chunks, then just return an iter per dj set
|
683 | std::vector<souffle::range<iterator>> ret;
|
684 | if (chunks <= equivalencePartition.size()) {
|
685 | for (auto& p : equivalencePartition) {
|
686 | ret.push_back(souffle::make_range(closure(p.first), end()));
|
687 | }
|
688 | return ret;
|
689 | }
|
690 |
|
691 | // keep it simple stupid
|
692 | // just go through and if the size of the binrel is > numpairs/chunks, then generate an anteriorIt for
|
693 | // each
|
694 | const std::size_t perchunk = numPairs / chunks;
|
695 | for (const auto& itp : equivalencePartition) {
|
696 | const std::size_t s = itp.second->size();
|
697 | if (s * s > perchunk) {
|
698 | for (const auto& i : *itp.second) {
|
699 | ret.push_back(souffle::make_range(anteriorIt(i), end()));
|
700 | }
|
701 | } else {
|
702 | ret.push_back(souffle::make_range(closure(itp.first), end()));
|
703 | }
|
704 | }
|
705 |
|
706 | return ret;
|
707 | }
|
708 |
|
709 | iterator find(const TupleType&, operation_hints&) const {
|
710 | throw std::runtime_error("error: find() is not compatible with equivalence relations");
|
711 | return begin();
|
712 | }
|
713 |
|
714 | iterator find(const TupleType& t) const {
|
715 | operation_hints context;
|
716 | return find(t, context);
|
717 | }
|
718 |
|
719 | protected:
|
720 | bool containsElement(value_type e) const {
|
721 | return this->sds.nodeExists(e);
|
722 | }
|
723 |
|
724 | private:
|
725 | // marked as mutable due to difficulties with the const enforcement via the Relation API
|
726 | // const operations *may* safely change internal state (i.e. collapse djset forest)
|
727 | mutable souffle::SparseDisjointSet<value_type> sds;
|
728 |
|
729 | // read/write lock on equivalencePartition
|
730 | mutable std::shared_mutex statesLock;
|
731 |
|
732 | mutable StatesMap equivalencePartition;
|
733 | // whether the cache is stale
|
734 | mutable std::atomic<bool> statesMapStale;
|
735 |
|
736 | /**
|
737 | * Generate a cache of the sets such that they can be iterated over efficiently.
|
738 | * Each set is partitioned into a PiggyList.
|
739 | */
|
740 | void genAllDisjointSetLists() const {
|
741 | statesLock.lock();
|
742 |
|
743 | // no need to generate again, already done.
|
744 | if (!this->statesMapStale.load(std::memory_order_acquire)) {
|
745 | statesLock.unlock();
|
746 | return;
|
747 | }
|
748 |
|
749 | // btree version
|
750 | emptyPartition();
|
751 |
|
752 | std::size_t dSetSize = this->sds.ds.a_blocks.size();
|
753 | for (std::size_t i = 0; i < dSetSize; ++i) {
|
754 | typename TupleType::value_type sparseVal = this->sds.toSparse(i);
|
755 | parent_t rep = this->sds.findNode(sparseVal);
|
756 |
|
757 | StorePair p = {static_cast<value_type>(rep), nullptr};
|
758 | StatesList* mapList = equivalencePartition.insert(p, [&](StorePair& sp) {
|
759 | auto* r = new StatesList(1);
|
760 | sp.second = r;
|
761 | return r;
|
762 | });
|
763 | mapList->append(sparseVal);
|
764 | }
|
765 |
|
766 | statesMapStale.store(false, std::memory_order_release);
|
767 | statesLock.unlock();
|
768 | }
|
769 | };
|
770 | } // namespace souffle
|