| 1 | /* | 
| 2 | * Souffle - A Datalog Compiler | 
| 3 | * Copyright (c) 2017, The Souffle Developers. All rights reserved | 
| 4 | * Licensed under the Universal Permissive License v 1.0 as shown at: | 
| 5 | * - https://opensource.org/licenses/UPL | 
| 6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt | 
| 7 | */ | 
| 8 |  | 
| 9 | /************************************************************************ | 
| 10 | * | 
| 11 | * @file ExplainProvenanceImpl.h | 
| 12 | * | 
| 13 | * Implementation of abstract class in ExplainProvenance.h for guided Impl provenance | 
| 14 | * | 
| 15 | ***********************************************************************/ | 
| 16 |  | 
| 17 | #pragma once | 
| 18 |  | 
| 19 | #include "souffle/BinaryConstraintOps.h" | 
| 20 | #include "souffle/RamTypes.h" | 
| 21 | #include "souffle/SouffleInterface.h" | 
| 22 | #include "souffle/SymbolTable.h" | 
| 23 | #include "souffle/provenance/ExplainProvenance.h" | 
| 24 | #include "souffle/provenance/ExplainTree.h" | 
| 25 | #include "souffle/utility/ContainerUtil.h" | 
| 26 | #include "souffle/utility/MiscUtil.h" | 
| 27 | #include "souffle/utility/StreamUtil.h" | 
| 28 | #include "souffle/utility/StringUtil.h" | 
| 29 | #include <algorithm> | 
| 30 | #include <cassert> | 
| 31 | #include <chrono> | 
| 32 | #include <cstdio> | 
| 33 | #include <iostream> | 
| 34 | #include <map> | 
| 35 | #include <memory> | 
| 36 | #include <regex> | 
| 37 | #include <sstream> | 
| 38 | #include <string> | 
| 39 | #include <tuple> | 
| 40 | #include <type_traits> | 
| 41 | #include <utility> | 
| 42 | #include <vector> | 
| 43 |  | 
| 44 | namespace souffle { | 
| 45 |  | 
| 46 | using namespace stream_write_qualified_char_as_number; | 
| 47 |  | 
| 48 | class ExplainProvenanceImpl : public ExplainProvenance { | 
| 49 | using arity_type = Relation::arity_type; | 
| 50 |  | 
| 51 | public: | 
| 52 | ExplainProvenanceImpl(SouffleProgram& prog) : ExplainProvenance(prog) { | 
| 53 | setup(); | 
| 54 | } | 
| 55 |  | 
| 56 | void setup() override { | 
| 57 | // for each clause, store a mapping from the head relation name to body relation names | 
| 58 | for (auto& rel : prog.getAllRelations()) { | 
| 59 | std::string name = rel->getName(); | 
| 60 |  | 
| 61 | // only process info relations | 
| 62 | if (name.find("@info") == std::string::npos) { | 
| 63 | continue; | 
| 64 | } | 
| 65 |  | 
| 66 | // find all the info tuples | 
| 67 | for (auto& tuple : *rel) { | 
| 68 | std::vector<std::string> bodyLiterals; | 
| 69 |  | 
| 70 | // first field is rule number | 
| 71 | RamDomain ruleNum; | 
| 72 | tuple >> ruleNum; | 
| 73 |  | 
| 74 | // middle fields are body literals | 
| 75 | for (std::size_t i = 1; i + 1 < rel->getArity(); i++) { | 
| 76 | std::string bodyLit; | 
| 77 | tuple >> bodyLit; | 
| 78 | bodyLiterals.push_back(bodyLit); | 
| 79 | } | 
| 80 |  | 
| 81 | // last field is the rule itself | 
| 82 | std::string rule; | 
| 83 | tuple >> rule; | 
| 84 |  | 
| 85 | std::string relName = name.substr(0, name.find(".@info")); | 
| 86 | info.insert({std::make_pair(relName, ruleNum), bodyLiterals}); | 
| 87 | rules.insert({std::make_pair(relName, ruleNum), rule}); | 
| 88 | } | 
| 89 | } | 
| 90 | } | 
| 91 |  | 
| 92 | Own<TreeNode> explain(std::string relName, std::vector<RamDomain> tuple, int ruleNum, int levelNum, | 
| 93 | std::size_t depthLimit) { | 
| 94 | std::stringstream joinedArgs; | 
| 95 | joinedArgs << join(decodeArguments(relName, tuple), ", "); | 
| 96 | auto joinedArgsStr = joinedArgs.str(); | 
| 97 |  | 
| 98 | // if fact | 
| 99 | if (levelNum == 0) { | 
| 100 | return mk<LeafNode>(relName + "(" + joinedArgsStr + ")"); | 
| 101 | } | 
| 102 |  | 
| 103 | assert(contains(info, std::make_pair(relName, ruleNum)) && "invalid rule for tuple"); | 
| 104 |  | 
| 105 | // if depth limit exceeded | 
| 106 | if (depthLimit <= 1) { | 
| 107 | tuple.push_back(ruleNum); | 
| 108 | tuple.push_back(levelNum); | 
| 109 |  | 
| 110 | // find if subproof exists already | 
| 111 | std::size_t idx = 0; | 
| 112 | auto it = std::find(subproofs.begin(), subproofs.end(), tuple); | 
| 113 | if (it != subproofs.end()) { | 
| 114 | idx = it - subproofs.begin(); | 
| 115 | } else { | 
| 116 | subproofs.push_back(tuple); | 
| 117 | idx = subproofs.size() - 1; | 
| 118 | } | 
| 119 |  | 
| 120 | return mk<LeafNode>("subproof " + relName + "(" + std::to_string(idx) + ")"); | 
| 121 | } | 
| 122 |  | 
| 123 | tuple.push_back(levelNum); | 
| 124 |  | 
| 125 | auto internalNode = | 
| 126 | mk<InnerNode>(relName + "(" + joinedArgsStr + ")", "(R" + std::to_string(ruleNum) + ")"); | 
| 127 |  | 
| 128 | // make return vector pointer | 
| 129 | std::vector<RamDomain> ret; | 
| 130 |  | 
| 131 | // execute subroutine to get subproofs | 
| 132 | prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_subproof", tuple, ret); | 
| 133 |  | 
| 134 | // recursively get nodes for subproofs | 
| 135 | std::size_t tupleCurInd = 0; | 
| 136 | auto bodyRelations = info.at(std::make_pair(relName, ruleNum)); | 
| 137 |  | 
| 138 | // start from begin + 1 because the first element represents the head atom | 
| 139 | for (auto it = bodyRelations.begin() + 1; it < bodyRelations.end(); it++) { | 
| 140 | std::string bodyLiteral = *it; | 
| 141 | // split bodyLiteral since it contains relation name plus arguments | 
| 142 | std::string bodyRel = splitString(bodyLiteral, ',')[0]; | 
| 143 |  | 
| 144 | // check whether the current atom is a constraint | 
| 145 | assert(bodyRel.size() > 0 && "body of a relation should have positive length"); | 
| 146 | bool isConstraint = contains(constraintList, bodyRel); | 
| 147 |  | 
| 148 | // handle negated atom names | 
| 149 | auto bodyRelAtomName = bodyRel; | 
| 150 | if (bodyRel[0] == '!' && bodyRel != "!=") { | 
| 151 | bodyRelAtomName = bodyRel.substr(1); | 
| 152 | } | 
| 153 |  | 
| 154 | // traverse subroutine return | 
| 155 | std::size_t arity; | 
| 156 | std::size_t auxiliaryArity; | 
| 157 | if (isConstraint) { | 
| 158 | // we only handle binary constraints, and assume arity is 4 to account for hidden provenance | 
| 159 | // annotations | 
| 160 | arity = 4; | 
| 161 | auxiliaryArity = 2; | 
| 162 | } else { | 
| 163 | arity = prog.getRelation(bodyRelAtomName)->getArity(); | 
| 164 | auxiliaryArity = prog.getRelation(bodyRelAtomName)->getAuxiliaryArity(); | 
| 165 | } | 
| 166 | auto tupleEnd = tupleCurInd + arity; | 
| 167 |  | 
| 168 | // store current tuple | 
| 169 | std::vector<RamDomain> subproofTuple; | 
| 170 |  | 
| 171 | for (; tupleCurInd < tupleEnd - auxiliaryArity; tupleCurInd++) { | 
| 172 | subproofTuple.push_back(ret[tupleCurInd]); | 
| 173 | } | 
| 174 |  | 
| 175 | int subproofRuleNum = ret[tupleCurInd]; | 
| 176 | int subproofLevelNum = ret[tupleCurInd + 1]; | 
| 177 |  | 
| 178 | tupleCurInd += 2; | 
| 179 |  | 
| 180 | // for a negation, display the corresponding tuple and do not recurse | 
| 181 | if (bodyRel[0] == '!' && bodyRel != "!=") { | 
| 182 | std::stringstream joinedTuple; | 
| 183 | joinedTuple << join(decodeArguments(bodyRelAtomName, subproofTuple), ", "); | 
| 184 | auto joinedTupleStr = joinedTuple.str(); | 
| 185 | internalNode->add_child(mk<LeafNode>(bodyRel + "(" + joinedTupleStr + ")")); | 
| 186 | internalNode->setSize(internalNode->getSize() + 1); | 
| 187 | // for a binary constraint, display the corresponding values and do not recurse | 
| 188 | } else if (isConstraint) { | 
| 189 | std::stringstream joinedConstraint; | 
| 190 |  | 
| 191 | // FIXME: We need type info in order to figure out how to print arguments. | 
| 192 | BinaryConstraintOp rawBinOp = toBinaryConstraintOp(bodyRel); | 
| 193 | if (isOrderedBinaryConstraintOp(rawBinOp)) { | 
| 194 | joinedConstraint << subproofTuple[0] << " " << bodyRel << " " << subproofTuple[1]; | 
| 195 | } else { | 
| 196 | joinedConstraint << bodyRel << "(\"" << symTable.decode(subproofTuple[0]) << "\", \"" | 
| 197 | << symTable.decode(subproofTuple[1]) << "\")"; | 
| 198 | } | 
| 199 |  | 
| 200 | internalNode->add_child(mk<LeafNode>(joinedConstraint.str())); | 
| 201 | internalNode->setSize(internalNode->getSize() + 1); | 
| 202 | // otherwise, for a normal tuple, recurse | 
| 203 | } else { | 
| 204 | auto child = | 
| 205 | explain(bodyRel, subproofTuple, subproofRuleNum, subproofLevelNum, depthLimit - 1); | 
| 206 | internalNode->setSize(internalNode->getSize() + child->getSize()); | 
| 207 | internalNode->add_child(std::move(child)); | 
| 208 | } | 
| 209 |  | 
| 210 | tupleCurInd = tupleEnd; | 
| 211 | } | 
| 212 |  | 
| 213 | return internalNode; | 
| 214 | } | 
| 215 |  | 
| 216 | Own<TreeNode> explain( | 
| 217 | std::string relName, std::vector<std::string> args, std::size_t depthLimit) override { | 
| 218 | auto tuple = argsToNums(relName, args); | 
| 219 | if (tuple.empty()) { | 
| 220 | return mk<LeafNode>("Relation not found"); | 
| 221 | } | 
| 222 |  | 
| 223 | std::tuple<int, int> tupleInfo = findTuple(relName, tuple); | 
| 224 |  | 
| 225 | int ruleNum = std::get<0>(tupleInfo); | 
| 226 | int levelNum = std::get<1>(tupleInfo); | 
| 227 |  | 
| 228 | if (ruleNum < 0 || levelNum == -1) { | 
| 229 | return mk<LeafNode>("Tuple not found"); | 
| 230 | } | 
| 231 |  | 
| 232 | return explain(relName, tuple, ruleNum, levelNum, depthLimit); | 
| 233 | } | 
| 234 |  | 
| 235 | Own<TreeNode> explainSubproof( | 
| 236 | std::string relName, RamDomain subproofNum, std::size_t depthLimit) override { | 
| 237 | if (subproofNum >= (int)subproofs.size()) { | 
| 238 | return mk<LeafNode>("Subproof not found"); | 
| 239 | } | 
| 240 |  | 
| 241 | auto tup = subproofs[subproofNum]; | 
| 242 |  | 
| 243 | auto rel = prog.getRelation(relName); | 
| 244 |  | 
| 245 | assert(rel->getAuxiliaryArity() == 2 && "unexpected auxiliary arity in provenance context"); | 
| 246 |  | 
| 247 | RamDomain ruleNum; | 
| 248 | ruleNum = tup[rel->getArity() - 2]; | 
| 249 |  | 
| 250 | RamDomain levelNum; | 
| 251 | levelNum = tup[rel->getArity() - 1]; | 
| 252 |  | 
| 253 | tup.erase(tup.begin() + rel->getArity() - 2, tup.end()); | 
| 254 |  | 
| 255 | return explain(relName, tup, ruleNum, levelNum, depthLimit); | 
| 256 | } | 
| 257 |  | 
| 258 | std::vector<std::string> explainNegationGetVariables( | 
| 259 | std::string relName, std::vector<std::string> args, std::size_t ruleNum) override { | 
| 260 | std::vector<std::string> variables; | 
| 261 |  | 
| 262 | // check that the tuple actually doesn't exist | 
| 263 | std::tuple<int, int> foundTuple = findTuple(relName, argsToNums(relName, args)); | 
| 264 | if (std::get<0>(foundTuple) != -1 || std::get<1>(foundTuple) != -1) { | 
| 265 | // return a sentinel value | 
| 266 | return std::vector<std::string>({"@"}); | 
| 267 | } | 
| 268 |  | 
| 269 | // atom meta information stored for the current rule | 
| 270 | auto atoms = info[std::make_pair(relName, ruleNum)]; | 
| 271 |  | 
| 272 | // the info stores the set of atoms, if there is only 1 atom, then it must be the head, so it must be | 
| 273 | // a fact | 
| 274 | if (atoms.size() <= 1) { | 
| 275 | return std::vector<std::string>({"@fact"}); | 
| 276 | } | 
| 277 |  | 
| 278 | // atoms[0] represents variables in the head atom | 
| 279 | auto headVariables = splitString(atoms[0], ','); | 
| 280 |  | 
| 281 | auto isVariable = [&](std::string arg) { | 
| 282 | return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_"); | 
| 283 | }; | 
| 284 |  | 
| 285 | // check that head variable bindings make sense, i.e. for a head like a(x, x), make sure both x are | 
| 286 | // the same value | 
| 287 | std::map<std::string, std::string> headVariableMapping; | 
| 288 | for (std::size_t i = 0; i < headVariables.size(); i++) { | 
| 289 | if (!isVariable(headVariables[i])) { | 
| 290 | continue; | 
| 291 | } | 
| 292 |  | 
| 293 | if (headVariableMapping.find(headVariables[i]) == headVariableMapping.end()) { | 
| 294 | headVariableMapping[headVariables[i]] = args[i]; | 
| 295 | } else { | 
| 296 | if (headVariableMapping[headVariables[i]] != args[i]) { | 
| 297 | return std::vector<std::string>({"@non_matching"}); | 
| 298 | } | 
| 299 | } | 
| 300 | } | 
| 301 |  | 
| 302 | // get body variables | 
| 303 | std::vector<std::string> uniqueBodyVariables; | 
| 304 | for (auto it = atoms.begin() + 1; it < atoms.end(); it++) { | 
| 305 | auto atomRepresentation = splitString(*it, ','); | 
| 306 |  | 
| 307 | // atomRepresentation.begin() + 1 because the first element is the relation name of the atom | 
| 308 | // which is not relevant for finding variables | 
| 309 | for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) { | 
| 310 | if (!isVariable(*atomIt)) { | 
| 311 | continue; | 
| 312 | } | 
| 313 |  | 
| 314 | if (!contains(uniqueBodyVariables, *atomIt) && !contains(headVariables, *atomIt)) { | 
| 315 | uniqueBodyVariables.push_back(*atomIt); | 
| 316 | } | 
| 317 | } | 
| 318 | } | 
| 319 |  | 
| 320 | return uniqueBodyVariables; | 
| 321 | } | 
| 322 |  | 
| 323 | Own<TreeNode> explainNegation(std::string relName, std::size_t ruleNum, | 
| 324 | const std::vector<std::string>& tuple, | 
| 325 | std::map<std::string, std::string>& bodyVariables) override { | 
| 326 | // construct a vector of unique variables that occur in the rule | 
| 327 | std::vector<std::string> uniqueVariables; | 
| 328 |  | 
| 329 | // we also need to know the type of each variable | 
| 330 | std::map<std::string, char> variableTypes; | 
| 331 |  | 
| 332 | // atom meta information stored for the current rule | 
| 333 | auto atoms = info.at(std::make_pair(relName, ruleNum)); | 
| 334 |  | 
| 335 | // atoms[0] represents variables in the head atom | 
| 336 | auto headVariables = splitString(atoms[0], ','); | 
| 337 |  | 
| 338 | uniqueVariables.insert(uniqueVariables.end(), headVariables.begin(), headVariables.end()); | 
| 339 |  | 
| 340 | auto isVariable = [&](std::string arg) { | 
| 341 | return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_"); | 
| 342 | }; | 
| 343 |  | 
| 344 | // get body variables | 
| 345 | for (auto it = atoms.begin() + 1; it < atoms.end(); it++) { | 
| 346 | auto atomRepresentation = splitString(*it, ','); | 
| 347 |  | 
| 348 | // atomRepresentation.begin() + 1 because the first element is the relation name of the atom | 
| 349 | // which is not relevant for finding variables | 
| 350 | for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) { | 
| 351 | if (!contains(uniqueVariables, *atomIt) && !contains(headVariables, *atomIt)) { | 
| 352 | // ignore non-variables | 
| 353 | if (!isVariable(*atomIt)) { | 
| 354 | continue; | 
| 355 | } | 
| 356 |  | 
| 357 | uniqueVariables.push_back(*atomIt); | 
| 358 |  | 
| 359 | if (!contains(constraintList, atomRepresentation[0])) { | 
| 360 | // store type of variable | 
| 361 | auto currentRel = prog.getRelation(atomRepresentation[0]); | 
| 362 | assert(currentRel != nullptr && | 
| 363 | ("relation " + atomRepresentation[0] + " doesn't exist").c_str()); | 
| 364 | variableTypes[*atomIt] = | 
| 365 | *currentRel->getAttrType(atomIt - atomRepresentation.begin() - 1); | 
| 366 | } else if (atomIt->find("agg_") != std::string::npos) { | 
| 367 | variableTypes[*atomIt] = 'i'; | 
| 368 | } | 
| 369 | } | 
| 370 | } | 
| 371 | } | 
| 372 |  | 
| 373 | std::vector<RamDomain> args; | 
| 374 |  | 
| 375 | std::size_t varCounter = 0; | 
| 376 |  | 
| 377 | // construct arguments to pass in to the subroutine | 
| 378 | // - this contains the variable bindings selected by the user | 
| 379 |  | 
| 380 | // add number representation of tuple | 
| 381 | auto tupleNums = argsToNums(relName, tuple); | 
| 382 | args.insert(args.end(), tupleNums.begin(), tupleNums.end()); | 
| 383 | varCounter += tuple.size(); | 
| 384 |  | 
| 385 | while (varCounter < uniqueVariables.size()) { | 
| 386 | auto var = uniqueVariables[varCounter]; | 
| 387 | auto varValue = bodyVariables[var]; | 
| 388 | if (variableTypes[var] == 's') { | 
| 389 | if (varValue.size() >= 2 && varValue[0] == '"' && varValue[varValue.size() - 1] == '"') { | 
| 390 | auto originalStr = varValue.substr(1, varValue.size() - 2); | 
| 391 | args.push_back(symTable.encode(originalStr)); | 
| 392 | } else { | 
| 393 | // assume no quotation marks | 
| 394 | args.push_back(symTable.encode(varValue)); | 
| 395 | } | 
| 396 | } else { | 
| 397 | args.push_back(std::stoi(varValue)); | 
| 398 | } | 
| 399 |  | 
| 400 | varCounter++; | 
| 401 | } | 
| 402 |  | 
| 403 | // set up return and error vectors for subroutine calling | 
| 404 | std::vector<RamDomain> ret; | 
| 405 |  | 
| 406 | // execute subroutine to get subproofs | 
| 407 | prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_negation_subproof", args, ret); | 
| 408 |  | 
| 409 | // ensure the subroutine returns the correct number of results | 
| 410 | assert(ret.size() == atoms.size() - 1); | 
| 411 |  | 
| 412 | // construct tree nodes | 
| 413 | std::stringstream joinedArgsStr; | 
| 414 | joinedArgsStr << join(tuple, ","); | 
| 415 | auto internalNode = mk<InnerNode>( | 
| 416 | relName + "(" + joinedArgsStr.str() + ")", "(R" + std::to_string(ruleNum) + ")"); | 
| 417 |  | 
| 418 | // store the head tuple in bodyVariables so we can print | 
| 419 | for (std::size_t i = 0; i < headVariables.size(); i++) { | 
| 420 | bodyVariables[headVariables[i]] = tuple[i]; | 
| 421 | } | 
| 422 |  | 
| 423 | // traverse return vector and construct child nodes | 
| 424 | // making sure we display existent and non-existent tuples correctly | 
| 425 | int literalCounter = 1; | 
| 426 | for (RamDomain returnCounter : ret) { | 
| 427 | // check what the next contained atom is | 
| 428 | bool atomExists = true; | 
| 429 | if (returnCounter == 0) { | 
| 430 | atomExists = false; | 
| 431 | } | 
| 432 |  | 
| 433 | // get the relation of the current atom | 
| 434 | auto atomRepresentation = splitString(atoms[literalCounter], ','); | 
| 435 | std::string bodyRel = atomRepresentation[0]; | 
| 436 |  | 
| 437 | // check whether the current atom is a constraint | 
| 438 | bool isConstraint = contains(constraintList, bodyRel); | 
| 439 |  | 
| 440 | // handle negated atom names | 
| 441 | auto bodyRelAtomName = bodyRel; | 
| 442 | if (bodyRel[0] == '!' && bodyRel != "!=") { | 
| 443 | bodyRelAtomName = bodyRel.substr(1); | 
| 444 | } | 
| 445 |  | 
| 446 | // construct a label for a node containing a literal (either constraint or atom) | 
| 447 | std::stringstream childLabel; | 
| 448 | if (isConstraint) { | 
| 449 | // for a binary constraint, display the corresponding values and do not recurse | 
| 450 | assert(atomRepresentation.size() == 3 && "not a binary constraint"); | 
| 451 |  | 
| 452 | childLabel << bodyVariables[atomRepresentation[1]] << " " << bodyRel << " " | 
| 453 | << bodyVariables[atomRepresentation[2]]; | 
| 454 | } else { | 
| 455 | childLabel << bodyRel << "("; | 
| 456 | for (std::size_t i = 1; i < atomRepresentation.size(); i++) { | 
| 457 | // if it's a non-variable, print either _ for unnamed, or constant value | 
| 458 | if (!isVariable(atomRepresentation[i])) { | 
| 459 | childLabel << atomRepresentation[i]; | 
| 460 | } else { | 
| 461 | childLabel << bodyVariables[atomRepresentation[i]]; | 
| 462 | } | 
| 463 | if (i < atomRepresentation.size() - 1) { | 
| 464 | childLabel << ", "; | 
| 465 | } | 
| 466 | } | 
| 467 | childLabel << ")"; | 
| 468 | } | 
| 469 |  | 
| 470 | // build a marker for existence of body atoms | 
| 471 | if (atomExists) { | 
| 472 | childLabel << " ✓"; | 
| 473 | } else { | 
| 474 | childLabel << " x"; | 
| 475 | } | 
| 476 |  | 
| 477 | internalNode->add_child(mk<LeafNode>(childLabel.str())); | 
| 478 | internalNode->setSize(internalNode->getSize() + 1); | 
| 479 |  | 
| 480 | literalCounter++; | 
| 481 | } | 
| 482 |  | 
| 483 | return internalNode; | 
| 484 | } | 
| 485 |  | 
| 486 | std::string getRule(std::string relName, std::size_t ruleNum) override { | 
| 487 | auto key = make_pair(relName, ruleNum); | 
| 488 |  | 
| 489 | auto rule = rules.find(key); | 
| 490 | if (rule == rules.end()) { | 
| 491 | return "Rule not found"; | 
| 492 | } else { | 
| 493 | return rule->second; | 
| 494 | } | 
| 495 | } | 
| 496 |  | 
| 497 | std::vector<std::string> getRules(const std::string& relName) override { | 
| 498 | std::vector<std::string> relRules; | 
| 499 | // go through all rules | 
| 500 | for (auto& rule : rules) { | 
| 501 | if (rule.first.first == relName) { | 
| 502 | relRules.push_back(rule.second); | 
| 503 | } | 
| 504 | } | 
| 505 |  | 
| 506 | return relRules; | 
| 507 | } | 
| 508 |  | 
| 509 | std::string measureRelation(std::string relName) override { | 
| 510 | auto rel = prog.getRelation(relName); | 
| 511 |  | 
| 512 | if (rel == nullptr) { | 
| 513 | return "No relation found\n"; | 
| 514 | } | 
| 515 |  | 
| 516 | auto size = rel->size(); | 
| 517 | int skip = size / 10; | 
| 518 |  | 
| 519 | if (skip == 0) { | 
| 520 | skip = 1; | 
| 521 | } | 
| 522 |  | 
| 523 | std::stringstream ss; | 
| 524 |  | 
| 525 | auto before_time = std::chrono::high_resolution_clock::now(); | 
| 526 |  | 
| 527 | int numTuples = 0; | 
| 528 | int proc = 0; | 
| 529 | for (auto& tuple : *rel) { | 
| 530 | auto tupleStart = std::chrono::high_resolution_clock::now(); | 
| 531 |  | 
| 532 | if (numTuples % skip != 0) { | 
| 533 | numTuples++; | 
| 534 | continue; | 
| 535 | } | 
| 536 |  | 
| 537 | std::vector<RamDomain> currentTuple; | 
| 538 | for (arity_type i = 0; i < rel->getPrimaryArity(); i++) { | 
| 539 | RamDomain n; | 
| 540 | if (*rel->getAttrType(i) == 's') { | 
| 541 | std::string s; | 
| 542 | tuple >> s; | 
| 543 | n = lookupExisting(s); | 
| 544 | } else if (*rel->getAttrType(i) == 'f') { | 
| 545 | RamFloat element; | 
| 546 | tuple >> element; | 
| 547 | n = ramBitCast(element); | 
| 548 | } else if (*rel->getAttrType(i) == 'u') { | 
| 549 | RamUnsigned element; | 
| 550 | tuple >> element; | 
| 551 | n = ramBitCast(element); | 
| 552 | } else { | 
| 553 | tuple >> n; | 
| 554 | } | 
| 555 |  | 
| 556 | currentTuple.push_back(n); | 
| 557 | } | 
| 558 |  | 
| 559 | RamDomain ruleNum; | 
| 560 | tuple >> ruleNum; | 
| 561 |  | 
| 562 | RamDomain levelNum; | 
| 563 | tuple >> levelNum; | 
| 564 |  | 
| 565 | std::cout << "Tuples expanded: " | 
| 566 | << explain(relName, currentTuple, ruleNum, levelNum, 10000)->getSize(); | 
| 567 |  | 
| 568 | numTuples++; | 
| 569 | proc++; | 
| 570 |  | 
| 571 | auto tupleEnd = std::chrono::high_resolution_clock::now(); | 
| 572 | auto tupleDuration = | 
| 573 | std::chrono::duration_cast<std::chrono::duration<double>>(tupleEnd - tupleStart); | 
| 574 |  | 
| 575 | std::cout << ", Time: " << tupleDuration.count() << "\n"; | 
| 576 | } | 
| 577 |  | 
| 578 | auto after_time = std::chrono::high_resolution_clock::now(); | 
| 579 | auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(after_time - before_time); | 
| 580 |  | 
| 581 | ss << "total: " << proc << " "; | 
| 582 | ss << duration.count() << std::endl; | 
| 583 |  | 
| 584 | return ss.str(); | 
| 585 | } | 
| 586 |  | 
| 587 | void printRulesJSON(std::ostream& os) override { | 
| 588 | os << "\"rules\": [\n"; | 
| 589 | bool first = true; | 
| 590 | for (auto const& cur : rules) { | 
| 591 | if (first) { | 
| 592 | first = false; | 
| 593 | } else { | 
| 594 | os << ",\n"; | 
| 595 | } | 
| 596 | os << "\t{ \"rule-number\": \"(R" << cur.first.second << ")\", \"rule\": \"" | 
| 597 | << stringify(cur.second) << "\"}"; | 
| 598 | } | 
| 599 | os << "\n]\n"; | 
| 600 | } | 
| 601 |  | 
| 602 | void queryProcess(const std::vector<std::pair<std::string, std::vector<std::string>>>& rels) override { | 
| 603 | std::regex varRegex("[a-zA-Z_][a-zA-Z_0-9]*", std::regex_constants::extended); | 
| 604 | std::regex symbolRegex("\"([^\"]*)\"", std::regex_constants::extended); | 
| 605 | std::regex numberRegex("[0-9]+", std::regex_constants::extended); | 
| 606 |  | 
| 607 | std::smatch argsMatcher; | 
| 608 |  | 
| 609 | // map for variable name and corresponding equivalence class | 
| 610 | std::map<std::string, Equivalence> nameToEquivalence; | 
| 611 |  | 
| 612 | // const constraints that solution must satisfy | 
| 613 | ConstConstraint constConstraints; | 
| 614 |  | 
| 615 | // relations of tuples containing variables | 
| 616 | std::vector<Relation*> varRels; | 
| 617 |  | 
| 618 | // counter for adding element to varRels | 
| 619 | std::size_t idx = 0; | 
| 620 |  | 
| 621 | // parse arguments in each relation Tuple | 
| 622 | for (const auto& rel : rels) { | 
| 623 | Relation* relation = prog.getRelation(rel.first); | 
| 624 | // number/symbol index for constant arguments in tuple | 
| 625 | std::vector<RamDomain> constTuple; | 
| 626 | // relation does not exist | 
| 627 | if (relation == nullptr) { | 
| 628 | std::cout << "Relation <" << rel.first << "> does not exist" << std::endl; | 
| 629 | return; | 
| 630 | } | 
| 631 | // arity error | 
| 632 | if (relation->getPrimaryArity() != rel.second.size()) { | 
| 633 | std::cout << "<" + rel.first << "> has arity of " << relation->getPrimaryArity() << std::endl; | 
| 634 | return; | 
| 635 | } | 
| 636 |  | 
| 637 | // check if args contain variable | 
| 638 | bool containVar = false; | 
| 639 | for (std::size_t j = 0; j < rel.second.size(); ++j) { | 
| 640 | // arg is a variable | 
| 641 | if (std::regex_match(rel.second[j], argsMatcher, varRegex)) { | 
| 642 | containVar = true; | 
| 643 | auto nameToEquivalenceIter = nameToEquivalence.find(argsMatcher[0]); | 
| 644 | // if variable has not shown up before, create an equivalence class for add it to | 
| 645 | // nameToEquivalence map, otherwise add its indices to corresponding equivalence class | 
| 646 | if (nameToEquivalenceIter == nameToEquivalence.end()) { | 
| 647 | nameToEquivalence.insert( | 
| 648 | {argsMatcher[0], Equivalence(*(relation->getAttrType(j)), argsMatcher[0], | 
| 649 | std::make_pair(idx, j))}); | 
| 650 | } else { | 
| 651 | nameToEquivalenceIter->second.push_back(std::make_pair(idx, j)); | 
| 652 | } | 
| 653 | continue; | 
| 654 | } | 
| 655 |  | 
| 656 | RamDomain rd; | 
| 657 | switch (*(relation->getAttrType(j))) { | 
| 658 | case 's': | 
| 659 | if (!std::regex_match(rel.second[j], argsMatcher, symbolRegex)) { | 
| 660 | std::cout << argsMatcher.str(0) << " does not match type defined in relation" | 
| 661 | << std::endl; | 
| 662 | return; | 
| 663 | } | 
| 664 | rd = prog.getSymbolTable().encode(argsMatcher[1]); | 
| 665 | break; | 
| 666 | case 'f': | 
| 667 | if (!canBeParsedAsRamFloat(rel.second[j])) { | 
| 668 | std::cout << rel.second[j] << " does not match type defined in relation" | 
| 669 | << std::endl; | 
| 670 | return; | 
| 671 | } | 
| 672 | rd = ramBitCast(RamFloatFromString(rel.second[j])); | 
| 673 | break; | 
| 674 | case 'i': | 
| 675 | if (!canBeParsedAsRamSigned(rel.second[j])) { | 
| 676 | std::cout << rel.second[j] << " does not match type defined in relation" | 
| 677 | << std::endl; | 
| 678 | return; | 
| 679 | } | 
| 680 | rd = ramBitCast(RamSignedFromString(rel.second[j])); | 
| 681 | break; | 
| 682 | case 'u': | 
| 683 | if (!canBeParsedAsRamUnsigned(rel.second[j])) { | 
| 684 | std::cout << rel.second[j] << " does not match type defined in relation" | 
| 685 | << std::endl; | 
| 686 | return; | 
| 687 | } | 
| 688 | rd = ramBitCast(RamUnsignedFromString(rel.second[j])); | 
| 689 | break; | 
| 690 | default: continue; | 
| 691 | } | 
| 692 |  | 
| 693 | constConstraints.push_back(std::make_pair(std::make_pair(idx, j), rd)); | 
| 694 | if (!containVar) { | 
| 695 | constTuple.push_back(rd); | 
| 696 | } | 
| 697 | } | 
| 698 |  | 
| 699 | // if tuple does not contain any variable, check if existence of the tuple | 
| 700 | if (!containVar) { | 
| 701 | bool tupleExist = containsTuple(relation, constTuple); | 
| 702 |  | 
| 703 | // if relation contains this tuple, remove all related constraints | 
| 704 | if (tupleExist) { | 
| 705 | constConstraints.getConstraints().erase(constConstraints.getConstraints().end() - | 
| 706 | relation->getArity() + | 
| 707 | relation->getAuxiliaryArity(), | 
| 708 | constConstraints.getConstraints().end()); | 
| 709 | // otherwise, there is no solution for given query | 
| 710 | } else { | 
| 711 | std::cout << "false." << std::endl; | 
| 712 | std::cout << "Tuple " << rel.first << "("; | 
| 713 | for (std::size_t l = 0; l < rel.second.size() - 1; ++l) { | 
| 714 | std::cout << rel.second[l] << ", "; | 
| 715 | } | 
| 716 | std::cout << rel.second.back() << ") does not exist" << std::endl; | 
| 717 | return; | 
| 718 | } | 
| 719 | } else { | 
| 720 | varRels.push_back(relation); | 
| 721 | ++idx; | 
| 722 | } | 
| 723 | } | 
| 724 |  | 
| 725 | // if varRels size is 0, all given tuples only contain constant args and exist, no variable to | 
| 726 | // decode, Output true and return | 
| 727 | if (varRels.size() == 0) { | 
| 728 | std::cout << "true." << std::endl; | 
| 729 | return; | 
| 730 | } | 
| 731 |  | 
| 732 | // find solution for parameterised query | 
| 733 | findQuerySolution(varRels, nameToEquivalence, constConstraints); | 
| 734 | } | 
| 735 |  | 
| 736 | private: | 
// body literal strings of each rule, keyed by (relation name, rule number); filled by setup()
std::map<std::pair<std::string, std::size_t>, std::vector<std::string>> info;
// rule text of each rule, keyed by (relation name, rule number); filled by setup()
std::map<std::pair<std::string, std::size_t>, std::string> rules;
// tuples (with rule and level number appended) recorded by explain() when the depth limit is
// reached; the index into this vector is the number shown in "subproof" labels
std::vector<std::vector<RamDomain>> subproofs;
// literal names that are treated as binary constraints rather than relation atoms
std::vector<std::string> constraintList = {
        "=", "!=", "<", "<=", ">=", ">", "match", "contains", "not_match", "not_contains"};
| 742 |  | 
| 743 | RamDomain lookupExisting(const std::string& symbol) { | 
| 744 | auto Res = symTable.findOrInsert(symbol); | 
| 745 | if (Res.second) { | 
| 746 | fatal("Error string did not exist before call to `SymbolTable::findOrInsert`: `%s`", symbol); | 
| 747 | } | 
| 748 | return Res.first; | 
| 749 | } | 
| 750 |  | 
| 751 | std::tuple<int, int> findTuple(const std::string& relName, std::vector<RamDomain> tup) { | 
| 752 | auto rel = prog.getRelation(relName); | 
| 753 |  | 
| 754 | if (rel == nullptr) { | 
| 755 | return std::make_tuple(-1, -1); | 
| 756 | } | 
| 757 |  | 
| 758 | // find correct tuple | 
| 759 | for (auto& tuple : *rel) { | 
| 760 | bool match = true; | 
| 761 | std::vector<RamDomain> currentTuple; | 
| 762 |  | 
| 763 | for (arity_type i = 0; i < rel->getPrimaryArity(); i++) { | 
| 764 | RamDomain n; | 
| 765 | if (*rel->getAttrType(i) == 's') { | 
| 766 | std::string s; | 
| 767 | tuple >> s; | 
| 768 | n = lookupExisting(s); | 
| 769 | } else if (*rel->getAttrType(i) == 'f') { | 
| 770 | RamFloat element; | 
| 771 | tuple >> element; | 
| 772 | n = ramBitCast(element); | 
| 773 | } else if (*rel->getAttrType(i) == 'u') { | 
| 774 | RamUnsigned element; | 
| 775 | tuple >> element; | 
| 776 | n = ramBitCast(element); | 
| 777 | } else { | 
| 778 | tuple >> n; | 
| 779 | } | 
| 780 |  | 
| 781 | currentTuple.push_back(n); | 
| 782 |  | 
| 783 | if (n != tup[i]) { | 
| 784 | match = false; | 
| 785 | break; | 
| 786 | } | 
| 787 | } | 
| 788 |  | 
| 789 | if (match) { | 
| 790 | RamDomain ruleNum; | 
| 791 | tuple >> ruleNum; | 
| 792 |  | 
| 793 | RamDomain levelNum; | 
| 794 | tuple >> levelNum; | 
| 795 |  | 
| 796 | return std::make_tuple(ruleNum, levelNum); | 
| 797 | } | 
| 798 | } | 
| 799 |  | 
| 800 | // if no tuple exists | 
| 801 | return std::make_tuple(-1, -1); | 
| 802 | } | 
| 803 |  | 
/*
 * Find solutions of a parameterised query that satisfy both the constant constraints and the
 * variable equivalence constraints.
 * @param varRels relations of the query tuples that contain at least one variable in their
 * arguments
 * @param nameToEquivalence map from each variable name to its equivalence class
 * @param constConstraints constant constraints that every solution must satisfy
 * */
| 811 | void findQuerySolution(const std::vector<Relation*>& varRels, | 
| 812 | const std::map<std::string, Equivalence>& nameToEquivalence, | 
| 813 | const ConstConstraint& constConstraints) { | 
| 814 | // vector of iterators for relations in varRels | 
| 815 | std::vector<Relation::iterator> varRelationIterators; | 
| 816 | for (auto relation : varRels) { | 
| 817 | varRelationIterators.push_back(relation->begin()); | 
| 818 | } | 
| 819 |  | 
| 820 | std::size_t solutionCount = 0; | 
| 821 | std::stringstream solution; | 
| 822 |  | 
| 823 | // iterate through the vector of iterators to find solution | 
| 824 | while (true) { | 
| 825 | bool isSolution = true; | 
| 826 |  | 
| 827 | // vector contains the tuples the iterators currently points to | 
| 828 | std::vector<tuple> element; | 
| 829 | for (auto it : varRelationIterators) { | 
| 830 | element.push_back(*it); | 
| 831 | } | 
| 832 | // check if tuple satisfies variable equivalence | 
| 833 | for (auto var : nameToEquivalence) { | 
| 834 | if (!var.second.verify(element)) { | 
| 835 | isSolution = false; | 
| 836 | break; | 
| 837 | } | 
| 838 | } | 
| 839 | if (isSolution) { | 
| 840 | // check if tuple satisfies constant constraints | 
| 841 | isSolution = constConstraints.verify(element); | 
| 842 | } | 
| 843 |  | 
| 844 | if (isSolution) { | 
| 845 | // print previous solution (if any) | 
| 846 | if (solutionCount != 0) { | 
| 847 | std::cout << solution.str() << std::endl; | 
| 848 | } | 
| 849 | solution.str(std::string());  // reset solution and process | 
| 850 |  | 
| 851 | std::size_t c = 0; | 
| 852 | for (auto&& var : nameToEquivalence) { | 
| 853 | auto idx = var.second.getFirstIdx(); | 
| 854 | auto raw = element[idx.first][idx.second]; | 
| 855 |  | 
| 856 | solution << var.second.getSymbol() << " = "; | 
| 857 | switch (var.second.getType()) { | 
| 858 | case 'i': solution << ramBitCast<RamSigned>(raw); break; | 
| 859 | case 'f': solution << ramBitCast<RamFloat>(raw); break; | 
| 860 | case 'u': solution << ramBitCast<RamUnsigned>(raw); break; | 
| 861 | case 's': solution << prog.getSymbolTable().decode(raw); break; | 
| 862 | default: fatal("invalid type: `%c`", var.second.getType()); | 
| 863 | } | 
| 864 |  | 
| 865 | if (++c < nameToEquivalence.size()) { | 
| 866 | solution << ", "; | 
| 867 | } | 
| 868 | } | 
| 869 |  | 
| 870 | solutionCount++; | 
| 871 | // query has more than one solution; query whether to find next solution or stop | 
| 872 | if (1 < solutionCount) { | 
| 873 | for (std::string input; getline(std::cin, input);) { | 
| 874 | if (input == ";") break;   // print next solution? | 
| 875 | if (input == ".") return;  // break from query? | 
| 876 |  | 
| 877 | std::cout << "use ; to find next solution, use . to break from current query\n"; | 
| 878 | } | 
| 879 | } | 
| 880 | } | 
| 881 |  | 
| 882 | // increment the iterators | 
| 883 | std::size_t i = varRels.size() - 1; | 
| 884 | bool terminate = true; | 
| 885 | for (auto it = varRelationIterators.rbegin(); it != varRelationIterators.rend(); ++it) { | 
| 886 | if ((++(*it)) != varRels[i]->end()) { | 
| 887 | terminate = false; | 
| 888 | break; | 
| 889 | } else { | 
| 890 | (*it) = varRels[i]->begin(); | 
| 891 | --i; | 
| 892 | } | 
| 893 | } | 
| 894 |  | 
| 895 | if (terminate) { | 
| 896 | // if there is no solution, output false | 
| 897 | if (solutionCount == 0) { | 
| 898 | std::cout << "false." << std::endl; | 
| 899 | // otherwise print the last solution | 
| 900 | } else { | 
| 901 | std::cout << solution.str() << "." << std::endl; | 
| 902 | } | 
| 903 | break; | 
| 904 | } | 
| 905 | } | 
| 906 | } | 
| 907 |  | 
| 908 | // check if constTuple exists in relation | 
| 909 | bool containsTuple(Relation* relation, const std::vector<RamDomain>& constTuple) { | 
| 910 | bool tupleExist = false; | 
| 911 | for (auto it = relation->begin(); it != relation->end(); ++it) { | 
| 912 | bool eq = true; | 
| 913 | for (std::size_t j = 0; j < constTuple.size(); ++j) { | 
| 914 | if (constTuple[j] != (*it)[j]) { | 
| 915 | eq = false; | 
| 916 | break; | 
| 917 | } | 
| 918 | } | 
| 919 | if (eq) { | 
| 920 | tupleExist = true; | 
| 921 | break; | 
| 922 | } | 
| 923 | } | 
| 924 | return tupleExist; | 
| 925 | } | 
| 926 | }; | 
| 927 |  | 
| 928 | }  // end of namespace souffle |