OILS / vendor / souffle / provenance / ExplainProvenanceImpl.h View on Github | oilshell.org

928 lines, 539 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2017, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file ExplainProvenanceImpl.h
12 *
13 * Implementation of abstract class in ExplainProvenance.h for guided Impl provenance
14 *
15 ***********************************************************************/
16
17#pragma once
18
19#include "souffle/BinaryConstraintOps.h"
20#include "souffle/RamTypes.h"
21#include "souffle/SouffleInterface.h"
22#include "souffle/SymbolTable.h"
23#include "souffle/provenance/ExplainProvenance.h"
24#include "souffle/provenance/ExplainTree.h"
25#include "souffle/utility/ContainerUtil.h"
26#include "souffle/utility/MiscUtil.h"
27#include "souffle/utility/StreamUtil.h"
28#include "souffle/utility/StringUtil.h"
29#include <algorithm>
30#include <cassert>
31#include <chrono>
32#include <cstdio>
33#include <iostream>
34#include <map>
35#include <memory>
36#include <regex>
37#include <sstream>
38#include <string>
39#include <tuple>
40#include <type_traits>
41#include <utility>
42#include <vector>
43
44namespace souffle {
45
46using namespace stream_write_qualified_char_as_number;
47
48class ExplainProvenanceImpl : public ExplainProvenance {
49 using arity_type = Relation::arity_type;
50
51public:
52 ExplainProvenanceImpl(SouffleProgram& prog) : ExplainProvenance(prog) {
53 setup();
54 }
55
56 void setup() override {
57 // for each clause, store a mapping from the head relation name to body relation names
58 for (auto& rel : prog.getAllRelations()) {
59 std::string name = rel->getName();
60
61 // only process info relations
62 if (name.find("@info") == std::string::npos) {
63 continue;
64 }
65
66 // find all the info tuples
67 for (auto& tuple : *rel) {
68 std::vector<std::string> bodyLiterals;
69
70 // first field is rule number
71 RamDomain ruleNum;
72 tuple >> ruleNum;
73
74 // middle fields are body literals
75 for (std::size_t i = 1; i + 1 < rel->getArity(); i++) {
76 std::string bodyLit;
77 tuple >> bodyLit;
78 bodyLiterals.push_back(bodyLit);
79 }
80
81 // last field is the rule itself
82 std::string rule;
83 tuple >> rule;
84
85 std::string relName = name.substr(0, name.find(".@info"));
86 info.insert({std::make_pair(relName, ruleNum), bodyLiterals});
87 rules.insert({std::make_pair(relName, ruleNum), rule});
88 }
89 }
90 }
91
92 Own<TreeNode> explain(std::string relName, std::vector<RamDomain> tuple, int ruleNum, int levelNum,
93 std::size_t depthLimit) {
94 std::stringstream joinedArgs;
95 joinedArgs << join(decodeArguments(relName, tuple), ", ");
96 auto joinedArgsStr = joinedArgs.str();
97
98 // if fact
99 if (levelNum == 0) {
100 return mk<LeafNode>(relName + "(" + joinedArgsStr + ")");
101 }
102
103 assert(contains(info, std::make_pair(relName, ruleNum)) && "invalid rule for tuple");
104
105 // if depth limit exceeded
106 if (depthLimit <= 1) {
107 tuple.push_back(ruleNum);
108 tuple.push_back(levelNum);
109
110 // find if subproof exists already
111 std::size_t idx = 0;
112 auto it = std::find(subproofs.begin(), subproofs.end(), tuple);
113 if (it != subproofs.end()) {
114 idx = it - subproofs.begin();
115 } else {
116 subproofs.push_back(tuple);
117 idx = subproofs.size() - 1;
118 }
119
120 return mk<LeafNode>("subproof " + relName + "(" + std::to_string(idx) + ")");
121 }
122
123 tuple.push_back(levelNum);
124
125 auto internalNode =
126 mk<InnerNode>(relName + "(" + joinedArgsStr + ")", "(R" + std::to_string(ruleNum) + ")");
127
128 // make return vector pointer
129 std::vector<RamDomain> ret;
130
131 // execute subroutine to get subproofs
132 prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_subproof", tuple, ret);
133
134 // recursively get nodes for subproofs
135 std::size_t tupleCurInd = 0;
136 auto bodyRelations = info.at(std::make_pair(relName, ruleNum));
137
138 // start from begin + 1 because the first element represents the head atom
139 for (auto it = bodyRelations.begin() + 1; it < bodyRelations.end(); it++) {
140 std::string bodyLiteral = *it;
141 // split bodyLiteral since it contains relation name plus arguments
142 std::string bodyRel = splitString(bodyLiteral, ',')[0];
143
144 // check whether the current atom is a constraint
145 assert(bodyRel.size() > 0 && "body of a relation should have positive length");
146 bool isConstraint = contains(constraintList, bodyRel);
147
148 // handle negated atom names
149 auto bodyRelAtomName = bodyRel;
150 if (bodyRel[0] == '!' && bodyRel != "!=") {
151 bodyRelAtomName = bodyRel.substr(1);
152 }
153
154 // traverse subroutine return
155 std::size_t arity;
156 std::size_t auxiliaryArity;
157 if (isConstraint) {
158 // we only handle binary constraints, and assume arity is 4 to account for hidden provenance
159 // annotations
160 arity = 4;
161 auxiliaryArity = 2;
162 } else {
163 arity = prog.getRelation(bodyRelAtomName)->getArity();
164 auxiliaryArity = prog.getRelation(bodyRelAtomName)->getAuxiliaryArity();
165 }
166 auto tupleEnd = tupleCurInd + arity;
167
168 // store current tuple
169 std::vector<RamDomain> subproofTuple;
170
171 for (; tupleCurInd < tupleEnd - auxiliaryArity; tupleCurInd++) {
172 subproofTuple.push_back(ret[tupleCurInd]);
173 }
174
175 int subproofRuleNum = ret[tupleCurInd];
176 int subproofLevelNum = ret[tupleCurInd + 1];
177
178 tupleCurInd += 2;
179
180 // for a negation, display the corresponding tuple and do not recurse
181 if (bodyRel[0] == '!' && bodyRel != "!=") {
182 std::stringstream joinedTuple;
183 joinedTuple << join(decodeArguments(bodyRelAtomName, subproofTuple), ", ");
184 auto joinedTupleStr = joinedTuple.str();
185 internalNode->add_child(mk<LeafNode>(bodyRel + "(" + joinedTupleStr + ")"));
186 internalNode->setSize(internalNode->getSize() + 1);
187 // for a binary constraint, display the corresponding values and do not recurse
188 } else if (isConstraint) {
189 std::stringstream joinedConstraint;
190
191 // FIXME: We need type info in order to figure out how to print arguments.
192 BinaryConstraintOp rawBinOp = toBinaryConstraintOp(bodyRel);
193 if (isOrderedBinaryConstraintOp(rawBinOp)) {
194 joinedConstraint << subproofTuple[0] << " " << bodyRel << " " << subproofTuple[1];
195 } else {
196 joinedConstraint << bodyRel << "(\"" << symTable.decode(subproofTuple[0]) << "\", \""
197 << symTable.decode(subproofTuple[1]) << "\")";
198 }
199
200 internalNode->add_child(mk<LeafNode>(joinedConstraint.str()));
201 internalNode->setSize(internalNode->getSize() + 1);
202 // otherwise, for a normal tuple, recurse
203 } else {
204 auto child =
205 explain(bodyRel, subproofTuple, subproofRuleNum, subproofLevelNum, depthLimit - 1);
206 internalNode->setSize(internalNode->getSize() + child->getSize());
207 internalNode->add_child(std::move(child));
208 }
209
210 tupleCurInd = tupleEnd;
211 }
212
213 return internalNode;
214 }
215
216 Own<TreeNode> explain(
217 std::string relName, std::vector<std::string> args, std::size_t depthLimit) override {
218 auto tuple = argsToNums(relName, args);
219 if (tuple.empty()) {
220 return mk<LeafNode>("Relation not found");
221 }
222
223 std::tuple<int, int> tupleInfo = findTuple(relName, tuple);
224
225 int ruleNum = std::get<0>(tupleInfo);
226 int levelNum = std::get<1>(tupleInfo);
227
228 if (ruleNum < 0 || levelNum == -1) {
229 return mk<LeafNode>("Tuple not found");
230 }
231
232 return explain(relName, tuple, ruleNum, levelNum, depthLimit);
233 }
234
235 Own<TreeNode> explainSubproof(
236 std::string relName, RamDomain subproofNum, std::size_t depthLimit) override {
237 if (subproofNum >= (int)subproofs.size()) {
238 return mk<LeafNode>("Subproof not found");
239 }
240
241 auto tup = subproofs[subproofNum];
242
243 auto rel = prog.getRelation(relName);
244
245 assert(rel->getAuxiliaryArity() == 2 && "unexpected auxiliary arity in provenance context");
246
247 RamDomain ruleNum;
248 ruleNum = tup[rel->getArity() - 2];
249
250 RamDomain levelNum;
251 levelNum = tup[rel->getArity() - 1];
252
253 tup.erase(tup.begin() + rel->getArity() - 2, tup.end());
254
255 return explain(relName, tup, ruleNum, levelNum, depthLimit);
256 }
257
258 std::vector<std::string> explainNegationGetVariables(
259 std::string relName, std::vector<std::string> args, std::size_t ruleNum) override {
260 std::vector<std::string> variables;
261
262 // check that the tuple actually doesn't exist
263 std::tuple<int, int> foundTuple = findTuple(relName, argsToNums(relName, args));
264 if (std::get<0>(foundTuple) != -1 || std::get<1>(foundTuple) != -1) {
265 // return a sentinel value
266 return std::vector<std::string>({"@"});
267 }
268
269 // atom meta information stored for the current rule
270 auto atoms = info[std::make_pair(relName, ruleNum)];
271
272 // the info stores the set of atoms, if there is only 1 atom, then it must be the head, so it must be
273 // a fact
274 if (atoms.size() <= 1) {
275 return std::vector<std::string>({"@fact"});
276 }
277
278 // atoms[0] represents variables in the head atom
279 auto headVariables = splitString(atoms[0], ',');
280
281 auto isVariable = [&](std::string arg) {
282 return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_");
283 };
284
285 // check that head variable bindings make sense, i.e. for a head like a(x, x), make sure both x are
286 // the same value
287 std::map<std::string, std::string> headVariableMapping;
288 for (std::size_t i = 0; i < headVariables.size(); i++) {
289 if (!isVariable(headVariables[i])) {
290 continue;
291 }
292
293 if (headVariableMapping.find(headVariables[i]) == headVariableMapping.end()) {
294 headVariableMapping[headVariables[i]] = args[i];
295 } else {
296 if (headVariableMapping[headVariables[i]] != args[i]) {
297 return std::vector<std::string>({"@non_matching"});
298 }
299 }
300 }
301
302 // get body variables
303 std::vector<std::string> uniqueBodyVariables;
304 for (auto it = atoms.begin() + 1; it < atoms.end(); it++) {
305 auto atomRepresentation = splitString(*it, ',');
306
307 // atomRepresentation.begin() + 1 because the first element is the relation name of the atom
308 // which is not relevant for finding variables
309 for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
310 if (!isVariable(*atomIt)) {
311 continue;
312 }
313
314 if (!contains(uniqueBodyVariables, *atomIt) && !contains(headVariables, *atomIt)) {
315 uniqueBodyVariables.push_back(*atomIt);
316 }
317 }
318 }
319
320 return uniqueBodyVariables;
321 }
322
323 Own<TreeNode> explainNegation(std::string relName, std::size_t ruleNum,
324 const std::vector<std::string>& tuple,
325 std::map<std::string, std::string>& bodyVariables) override {
326 // construct a vector of unique variables that occur in the rule
327 std::vector<std::string> uniqueVariables;
328
329 // we also need to know the type of each variable
330 std::map<std::string, char> variableTypes;
331
332 // atom meta information stored for the current rule
333 auto atoms = info.at(std::make_pair(relName, ruleNum));
334
335 // atoms[0] represents variables in the head atom
336 auto headVariables = splitString(atoms[0], ',');
337
338 uniqueVariables.insert(uniqueVariables.end(), headVariables.begin(), headVariables.end());
339
340 auto isVariable = [&](std::string arg) {
341 return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_");
342 };
343
344 // get body variables
345 for (auto it = atoms.begin() + 1; it < atoms.end(); it++) {
346 auto atomRepresentation = splitString(*it, ',');
347
348 // atomRepresentation.begin() + 1 because the first element is the relation name of the atom
349 // which is not relevant for finding variables
350 for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
351 if (!contains(uniqueVariables, *atomIt) && !contains(headVariables, *atomIt)) {
352 // ignore non-variables
353 if (!isVariable(*atomIt)) {
354 continue;
355 }
356
357 uniqueVariables.push_back(*atomIt);
358
359 if (!contains(constraintList, atomRepresentation[0])) {
360 // store type of variable
361 auto currentRel = prog.getRelation(atomRepresentation[0]);
362 assert(currentRel != nullptr &&
363 ("relation " + atomRepresentation[0] + " doesn't exist").c_str());
364 variableTypes[*atomIt] =
365 *currentRel->getAttrType(atomIt - atomRepresentation.begin() - 1);
366 } else if (atomIt->find("agg_") != std::string::npos) {
367 variableTypes[*atomIt] = 'i';
368 }
369 }
370 }
371 }
372
373 std::vector<RamDomain> args;
374
375 std::size_t varCounter = 0;
376
377 // construct arguments to pass in to the subroutine
378 // - this contains the variable bindings selected by the user
379
380 // add number representation of tuple
381 auto tupleNums = argsToNums(relName, tuple);
382 args.insert(args.end(), tupleNums.begin(), tupleNums.end());
383 varCounter += tuple.size();
384
385 while (varCounter < uniqueVariables.size()) {
386 auto var = uniqueVariables[varCounter];
387 auto varValue = bodyVariables[var];
388 if (variableTypes[var] == 's') {
389 if (varValue.size() >= 2 && varValue[0] == '"' && varValue[varValue.size() - 1] == '"') {
390 auto originalStr = varValue.substr(1, varValue.size() - 2);
391 args.push_back(symTable.encode(originalStr));
392 } else {
393 // assume no quotation marks
394 args.push_back(symTable.encode(varValue));
395 }
396 } else {
397 args.push_back(std::stoi(varValue));
398 }
399
400 varCounter++;
401 }
402
403 // set up return and error vectors for subroutine calling
404 std::vector<RamDomain> ret;
405
406 // execute subroutine to get subproofs
407 prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_negation_subproof", args, ret);
408
409 // ensure the subroutine returns the correct number of results
410 assert(ret.size() == atoms.size() - 1);
411
412 // construct tree nodes
413 std::stringstream joinedArgsStr;
414 joinedArgsStr << join(tuple, ",");
415 auto internalNode = mk<InnerNode>(
416 relName + "(" + joinedArgsStr.str() + ")", "(R" + std::to_string(ruleNum) + ")");
417
418 // store the head tuple in bodyVariables so we can print
419 for (std::size_t i = 0; i < headVariables.size(); i++) {
420 bodyVariables[headVariables[i]] = tuple[i];
421 }
422
423 // traverse return vector and construct child nodes
424 // making sure we display existent and non-existent tuples correctly
425 int literalCounter = 1;
426 for (RamDomain returnCounter : ret) {
427 // check what the next contained atom is
428 bool atomExists = true;
429 if (returnCounter == 0) {
430 atomExists = false;
431 }
432
433 // get the relation of the current atom
434 auto atomRepresentation = splitString(atoms[literalCounter], ',');
435 std::string bodyRel = atomRepresentation[0];
436
437 // check whether the current atom is a constraint
438 bool isConstraint = contains(constraintList, bodyRel);
439
440 // handle negated atom names
441 auto bodyRelAtomName = bodyRel;
442 if (bodyRel[0] == '!' && bodyRel != "!=") {
443 bodyRelAtomName = bodyRel.substr(1);
444 }
445
446 // construct a label for a node containing a literal (either constraint or atom)
447 std::stringstream childLabel;
448 if (isConstraint) {
449 // for a binary constraint, display the corresponding values and do not recurse
450 assert(atomRepresentation.size() == 3 && "not a binary constraint");
451
452 childLabel << bodyVariables[atomRepresentation[1]] << " " << bodyRel << " "
453 << bodyVariables[atomRepresentation[2]];
454 } else {
455 childLabel << bodyRel << "(";
456 for (std::size_t i = 1; i < atomRepresentation.size(); i++) {
457 // if it's a non-variable, print either _ for unnamed, or constant value
458 if (!isVariable(atomRepresentation[i])) {
459 childLabel << atomRepresentation[i];
460 } else {
461 childLabel << bodyVariables[atomRepresentation[i]];
462 }
463 if (i < atomRepresentation.size() - 1) {
464 childLabel << ", ";
465 }
466 }
467 childLabel << ")";
468 }
469
470 // build a marker for existence of body atoms
471 if (atomExists) {
472 childLabel << " ✓";
473 } else {
474 childLabel << " x";
475 }
476
477 internalNode->add_child(mk<LeafNode>(childLabel.str()));
478 internalNode->setSize(internalNode->getSize() + 1);
479
480 literalCounter++;
481 }
482
483 return internalNode;
484 }
485
486 std::string getRule(std::string relName, std::size_t ruleNum) override {
487 auto key = make_pair(relName, ruleNum);
488
489 auto rule = rules.find(key);
490 if (rule == rules.end()) {
491 return "Rule not found";
492 } else {
493 return rule->second;
494 }
495 }
496
497 std::vector<std::string> getRules(const std::string& relName) override {
498 std::vector<std::string> relRules;
499 // go through all rules
500 for (auto& rule : rules) {
501 if (rule.first.first == relName) {
502 relRules.push_back(rule.second);
503 }
504 }
505
506 return relRules;
507 }
508
509 std::string measureRelation(std::string relName) override {
510 auto rel = prog.getRelation(relName);
511
512 if (rel == nullptr) {
513 return "No relation found\n";
514 }
515
516 auto size = rel->size();
517 int skip = size / 10;
518
519 if (skip == 0) {
520 skip = 1;
521 }
522
523 std::stringstream ss;
524
525 auto before_time = std::chrono::high_resolution_clock::now();
526
527 int numTuples = 0;
528 int proc = 0;
529 for (auto& tuple : *rel) {
530 auto tupleStart = std::chrono::high_resolution_clock::now();
531
532 if (numTuples % skip != 0) {
533 numTuples++;
534 continue;
535 }
536
537 std::vector<RamDomain> currentTuple;
538 for (arity_type i = 0; i < rel->getPrimaryArity(); i++) {
539 RamDomain n;
540 if (*rel->getAttrType(i) == 's') {
541 std::string s;
542 tuple >> s;
543 n = lookupExisting(s);
544 } else if (*rel->getAttrType(i) == 'f') {
545 RamFloat element;
546 tuple >> element;
547 n = ramBitCast(element);
548 } else if (*rel->getAttrType(i) == 'u') {
549 RamUnsigned element;
550 tuple >> element;
551 n = ramBitCast(element);
552 } else {
553 tuple >> n;
554 }
555
556 currentTuple.push_back(n);
557 }
558
559 RamDomain ruleNum;
560 tuple >> ruleNum;
561
562 RamDomain levelNum;
563 tuple >> levelNum;
564
565 std::cout << "Tuples expanded: "
566 << explain(relName, currentTuple, ruleNum, levelNum, 10000)->getSize();
567
568 numTuples++;
569 proc++;
570
571 auto tupleEnd = std::chrono::high_resolution_clock::now();
572 auto tupleDuration =
573 std::chrono::duration_cast<std::chrono::duration<double>>(tupleEnd - tupleStart);
574
575 std::cout << ", Time: " << tupleDuration.count() << "\n";
576 }
577
578 auto after_time = std::chrono::high_resolution_clock::now();
579 auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(after_time - before_time);
580
581 ss << "total: " << proc << " ";
582 ss << duration.count() << std::endl;
583
584 return ss.str();
585 }
586
587 void printRulesJSON(std::ostream& os) override {
588 os << "\"rules\": [\n";
589 bool first = true;
590 for (auto const& cur : rules) {
591 if (first) {
592 first = false;
593 } else {
594 os << ",\n";
595 }
596 os << "\t{ \"rule-number\": \"(R" << cur.first.second << ")\", \"rule\": \""
597 << stringify(cur.second) << "\"}";
598 }
599 os << "\n]\n";
600 }
601
602 void queryProcess(const std::vector<std::pair<std::string, std::vector<std::string>>>& rels) override {
603 std::regex varRegex("[a-zA-Z_][a-zA-Z_0-9]*", std::regex_constants::extended);
604 std::regex symbolRegex("\"([^\"]*)\"", std::regex_constants::extended);
605 std::regex numberRegex("[0-9]+", std::regex_constants::extended);
606
607 std::smatch argsMatcher;
608
609 // map for variable name and corresponding equivalence class
610 std::map<std::string, Equivalence> nameToEquivalence;
611
612 // const constraints that solution must satisfy
613 ConstConstraint constConstraints;
614
615 // relations of tuples containing variables
616 std::vector<Relation*> varRels;
617
618 // counter for adding element to varRels
619 std::size_t idx = 0;
620
621 // parse arguments in each relation Tuple
622 for (const auto& rel : rels) {
623 Relation* relation = prog.getRelation(rel.first);
624 // number/symbol index for constant arguments in tuple
625 std::vector<RamDomain> constTuple;
626 // relation does not exist
627 if (relation == nullptr) {
628 std::cout << "Relation <" << rel.first << "> does not exist" << std::endl;
629 return;
630 }
631 // arity error
632 if (relation->getPrimaryArity() != rel.second.size()) {
633 std::cout << "<" + rel.first << "> has arity of " << relation->getPrimaryArity() << std::endl;
634 return;
635 }
636
637 // check if args contain variable
638 bool containVar = false;
639 for (std::size_t j = 0; j < rel.second.size(); ++j) {
640 // arg is a variable
641 if (std::regex_match(rel.second[j], argsMatcher, varRegex)) {
642 containVar = true;
643 auto nameToEquivalenceIter = nameToEquivalence.find(argsMatcher[0]);
644 // if variable has not shown up before, create an equivalence class for add it to
645 // nameToEquivalence map, otherwise add its indices to corresponding equivalence class
646 if (nameToEquivalenceIter == nameToEquivalence.end()) {
647 nameToEquivalence.insert(
648 {argsMatcher[0], Equivalence(*(relation->getAttrType(j)), argsMatcher[0],
649 std::make_pair(idx, j))});
650 } else {
651 nameToEquivalenceIter->second.push_back(std::make_pair(idx, j));
652 }
653 continue;
654 }
655
656 RamDomain rd;
657 switch (*(relation->getAttrType(j))) {
658 case 's':
659 if (!std::regex_match(rel.second[j], argsMatcher, symbolRegex)) {
660 std::cout << argsMatcher.str(0) << " does not match type defined in relation"
661 << std::endl;
662 return;
663 }
664 rd = prog.getSymbolTable().encode(argsMatcher[1]);
665 break;
666 case 'f':
667 if (!canBeParsedAsRamFloat(rel.second[j])) {
668 std::cout << rel.second[j] << " does not match type defined in relation"
669 << std::endl;
670 return;
671 }
672 rd = ramBitCast(RamFloatFromString(rel.second[j]));
673 break;
674 case 'i':
675 if (!canBeParsedAsRamSigned(rel.second[j])) {
676 std::cout << rel.second[j] << " does not match type defined in relation"
677 << std::endl;
678 return;
679 }
680 rd = ramBitCast(RamSignedFromString(rel.second[j]));
681 break;
682 case 'u':
683 if (!canBeParsedAsRamUnsigned(rel.second[j])) {
684 std::cout << rel.second[j] << " does not match type defined in relation"
685 << std::endl;
686 return;
687 }
688 rd = ramBitCast(RamUnsignedFromString(rel.second[j]));
689 break;
690 default: continue;
691 }
692
693 constConstraints.push_back(std::make_pair(std::make_pair(idx, j), rd));
694 if (!containVar) {
695 constTuple.push_back(rd);
696 }
697 }
698
699 // if tuple does not contain any variable, check if existence of the tuple
700 if (!containVar) {
701 bool tupleExist = containsTuple(relation, constTuple);
702
703 // if relation contains this tuple, remove all related constraints
704 if (tupleExist) {
705 constConstraints.getConstraints().erase(constConstraints.getConstraints().end() -
706 relation->getArity() +
707 relation->getAuxiliaryArity(),
708 constConstraints.getConstraints().end());
709 // otherwise, there is no solution for given query
710 } else {
711 std::cout << "false." << std::endl;
712 std::cout << "Tuple " << rel.first << "(";
713 for (std::size_t l = 0; l < rel.second.size() - 1; ++l) {
714 std::cout << rel.second[l] << ", ";
715 }
716 std::cout << rel.second.back() << ") does not exist" << std::endl;
717 return;
718 }
719 } else {
720 varRels.push_back(relation);
721 ++idx;
722 }
723 }
724
725 // if varRels size is 0, all given tuples only contain constant args and exist, no variable to
726 // decode, Output true and return
727 if (varRels.size() == 0) {
728 std::cout << "true." << std::endl;
729 return;
730 }
731
732 // find solution for parameterised query
733 findQuerySolution(varRels, nameToEquivalence, constConstraints);
734 }
735
736private:
737 std::map<std::pair<std::string, std::size_t>, std::vector<std::string>> info;
738 std::map<std::pair<std::string, std::size_t>, std::string> rules;
739 std::vector<std::vector<RamDomain>> subproofs;
740 std::vector<std::string> constraintList = {
741 "=", "!=", "<", "<=", ">=", ">", "match", "contains", "not_match", "not_contains"};
742
743 RamDomain lookupExisting(const std::string& symbol) {
744 auto Res = symTable.findOrInsert(symbol);
745 if (Res.second) {
746 fatal("Error string did not exist before call to `SymbolTable::findOrInsert`: `%s`", symbol);
747 }
748 return Res.first;
749 }
750
751 std::tuple<int, int> findTuple(const std::string& relName, std::vector<RamDomain> tup) {
752 auto rel = prog.getRelation(relName);
753
754 if (rel == nullptr) {
755 return std::make_tuple(-1, -1);
756 }
757
758 // find correct tuple
759 for (auto& tuple : *rel) {
760 bool match = true;
761 std::vector<RamDomain> currentTuple;
762
763 for (arity_type i = 0; i < rel->getPrimaryArity(); i++) {
764 RamDomain n;
765 if (*rel->getAttrType(i) == 's') {
766 std::string s;
767 tuple >> s;
768 n = lookupExisting(s);
769 } else if (*rel->getAttrType(i) == 'f') {
770 RamFloat element;
771 tuple >> element;
772 n = ramBitCast(element);
773 } else if (*rel->getAttrType(i) == 'u') {
774 RamUnsigned element;
775 tuple >> element;
776 n = ramBitCast(element);
777 } else {
778 tuple >> n;
779 }
780
781 currentTuple.push_back(n);
782
783 if (n != tup[i]) {
784 match = false;
785 break;
786 }
787 }
788
789 if (match) {
790 RamDomain ruleNum;
791 tuple >> ruleNum;
792
793 RamDomain levelNum;
794 tuple >> levelNum;
795
796 return std::make_tuple(ruleNum, levelNum);
797 }
798 }
799
800 // if no tuple exists
801 return std::make_tuple(-1, -1);
802 }
803
804 /*
805 * Find solution for parameterised query satisfying constant constraints and equivalence constraints
806 * @param varRels, reference to vector of relation of tuple contains at least one variable in its
807 * arguments
808 * @param nameToEquivalence, reference to variable name and corresponding equivalence class
809 * @param constConstraints, reference to const constraints must be satisfied
810 * */
811 void findQuerySolution(const std::vector<Relation*>& varRels,
812 const std::map<std::string, Equivalence>& nameToEquivalence,
813 const ConstConstraint& constConstraints) {
814 // vector of iterators for relations in varRels
815 std::vector<Relation::iterator> varRelationIterators;
816 for (auto relation : varRels) {
817 varRelationIterators.push_back(relation->begin());
818 }
819
820 std::size_t solutionCount = 0;
821 std::stringstream solution;
822
823 // iterate through the vector of iterators to find solution
824 while (true) {
825 bool isSolution = true;
826
827 // vector contains the tuples the iterators currently points to
828 std::vector<tuple> element;
829 for (auto it : varRelationIterators) {
830 element.push_back(*it);
831 }
832 // check if tuple satisfies variable equivalence
833 for (auto var : nameToEquivalence) {
834 if (!var.second.verify(element)) {
835 isSolution = false;
836 break;
837 }
838 }
839 if (isSolution) {
840 // check if tuple satisfies constant constraints
841 isSolution = constConstraints.verify(element);
842 }
843
844 if (isSolution) {
845 // print previous solution (if any)
846 if (solutionCount != 0) {
847 std::cout << solution.str() << std::endl;
848 }
849 solution.str(std::string()); // reset solution and process
850
851 std::size_t c = 0;
852 for (auto&& var : nameToEquivalence) {
853 auto idx = var.second.getFirstIdx();
854 auto raw = element[idx.first][idx.second];
855
856 solution << var.second.getSymbol() << " = ";
857 switch (var.second.getType()) {
858 case 'i': solution << ramBitCast<RamSigned>(raw); break;
859 case 'f': solution << ramBitCast<RamFloat>(raw); break;
860 case 'u': solution << ramBitCast<RamUnsigned>(raw); break;
861 case 's': solution << prog.getSymbolTable().decode(raw); break;
862 default: fatal("invalid type: `%c`", var.second.getType());
863 }
864
865 if (++c < nameToEquivalence.size()) {
866 solution << ", ";
867 }
868 }
869
870 solutionCount++;
871 // query has more than one solution; query whether to find next solution or stop
872 if (1 < solutionCount) {
873 for (std::string input; getline(std::cin, input);) {
874 if (input == ";") break; // print next solution?
875 if (input == ".") return; // break from query?
876
877 std::cout << "use ; to find next solution, use . to break from current query\n";
878 }
879 }
880 }
881
882 // increment the iterators
883 std::size_t i = varRels.size() - 1;
884 bool terminate = true;
885 for (auto it = varRelationIterators.rbegin(); it != varRelationIterators.rend(); ++it) {
886 if ((++(*it)) != varRels[i]->end()) {
887 terminate = false;
888 break;
889 } else {
890 (*it) = varRels[i]->begin();
891 --i;
892 }
893 }
894
895 if (terminate) {
896 // if there is no solution, output false
897 if (solutionCount == 0) {
898 std::cout << "false." << std::endl;
899 // otherwise print the last solution
900 } else {
901 std::cout << solution.str() << "." << std::endl;
902 }
903 break;
904 }
905 }
906 }
907
908 // check if constTuple exists in relation
909 bool containsTuple(Relation* relation, const std::vector<RamDomain>& constTuple) {
910 bool tupleExist = false;
911 for (auto it = relation->begin(); it != relation->end(); ++it) {
912 bool eq = true;
913 for (std::size_t j = 0; j < constTuple.size(); ++j) {
914 if (constTuple[j] != (*it)[j]) {
915 eq = false;
916 break;
917 }
918 }
919 if (eq) {
920 tupleExist = true;
921 break;
922 }
923 }
924 return tupleExist;
925 }
926};
927
928} // end of namespace souffle