1 | /*
|
2 | * Souffle - A Datalog Compiler
|
3 | * Copyright (c) 2017, The Souffle Developers. All rights reserved
|
4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
5 | * - https://opensource.org/licenses/UPL
|
6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
7 | */
|
8 |
|
9 | /************************************************************************
|
10 | *
|
11 | * @file ExplainProvenance.h
|
12 | *
|
13 | * Abstract class for implementing an instance of the explain interface for provenance
|
14 | *
|
15 | ***********************************************************************/
|
16 |
|
17 | #pragma once
|
18 |
|
19 | #include "souffle/RamTypes.h"
|
20 | #include "souffle/SouffleInterface.h"
|
21 | #include "souffle/SymbolTable.h"
|
22 | #include "souffle/utility/MiscUtil.h"
|
23 | #include "souffle/utility/StringUtil.h"
|
24 | #include "souffle/utility/tinyformat.h"
|
25 | #include <algorithm>
|
26 | #include <cassert>
|
27 | #include <cstdio>
|
28 | #include <map>
|
29 | #include <memory>
|
30 | #include <sstream>
|
31 | #include <string>
|
32 | #include <utility>
|
33 | #include <vector>
|
34 |
|
35 | namespace souffle {
|
36 | class TreeNode;
|
37 |
|
38 | /** Equivalence class for variables in query command */
|
39 | class Equivalence {
|
40 | public:
|
41 | /** Destructor */
|
42 | ~Equivalence() = default;
|
43 |
|
44 | /**
|
45 | * Constructor for Equvialence class
|
46 | * @param t, type of the variable
|
47 | * @param s, symbol of the variable
|
48 | * @param idx, first occurence of the variable
|
49 | * */
|
50 | Equivalence(char t, std::string s, std::pair<std::size_t, std::size_t> idx)
|
51 | : type(t), symbol(std::move(s)) {
|
52 | indices.push_back(idx);
|
53 | }
|
54 |
|
55 | /** Copy constructor */
|
56 | Equivalence(const Equivalence& o) = default;
|
57 |
|
58 | /** Copy assignment operator */
|
59 | Equivalence& operator=(const Equivalence& o) = default;
|
60 |
|
61 | /** Add index at the end of indices vector */
|
62 | void push_back(std::pair<std::size_t, std::size_t> idx) {
|
63 | indices.push_back(idx);
|
64 | }
|
65 |
|
66 | /** Verify if elements at the indices are equivalent in the given product */
|
67 | bool verify(const std::vector<tuple>& product) const {
|
68 | for (std::size_t i = 1; i < indices.size(); ++i) {
|
69 | if (product[indices[i].first][indices[i].second] !=
|
70 | product[indices[i - 1].first][indices[i - 1].second]) {
|
71 | return false;
|
72 | }
|
73 | }
|
74 | return true;
|
75 | }
|
76 |
|
77 | /** Extract index of the first occurrence of the varible */
|
78 | const std::pair<std::size_t, std::size_t>& getFirstIdx() const {
|
79 | return indices[0];
|
80 | }
|
81 |
|
82 | /** Get indices of equivalent variables */
|
83 | const std::vector<std::pair<std::size_t, std::size_t>>& getIndices() const {
|
84 | return indices;
|
85 | }
|
86 |
|
87 | /** Get type of the variable of the equivalence class,
|
88 | * 'i' for RamSigned, 's' for symbol
|
89 | * 'u' for RamUnsigned, 'f' for RamFloat
|
90 | */
|
91 | char getType() const {
|
92 | return type;
|
93 | }
|
94 |
|
95 | /** Get the symbol of variable */
|
96 | const std::string& getSymbol() const {
|
97 | return symbol;
|
98 | }
|
99 |
|
100 | private:
|
101 | char type;
|
102 | std::string symbol;
|
103 | std::vector<std::pair<std::size_t, std::size_t>> indices;
|
104 | };
|
105 |
|
106 | /** Constant constraints for values in query command */
|
107 | class ConstConstraint {
|
108 | public:
|
109 | /** Constructor */
|
110 | ConstConstraint() = default;
|
111 |
|
112 | /** Destructor */
|
113 | ~ConstConstraint() = default;
|
114 |
|
115 | /** Add constant constraint at the end of constConstrs vector */
|
116 | void push_back(std::pair<std::pair<std::size_t, std::size_t>, RamDomain> constr) {
|
117 | constConstrs.push_back(constr);
|
118 | }
|
119 |
|
120 | /** Verify if the query product satisfies constant constraint */
|
121 | bool verify(const std::vector<tuple>& product) const {
|
122 | return std::all_of(constConstrs.begin(), constConstrs.end(), [&product](auto constr) {
|
123 | return product[constr.first.first][constr.first.second] == constr.second;
|
124 | });
|
125 | }
|
126 |
|
127 | /** Get the constant constraint vector */
|
128 | std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>>& getConstraints() {
|
129 | return constConstrs;
|
130 | }
|
131 |
|
132 | const std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>>& getConstraints() const {
|
133 | return constConstrs;
|
134 | }
|
135 |
|
136 | private:
|
137 | std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>> constConstrs;
|
138 | };
|
139 |
|
/**
 * Utility function to split a string on a delimiter.
 *
 * Follows std::getline semantics: consecutive delimiters yield empty items,
 * and a single trailing delimiter does not produce a trailing empty item.
 *
 * @param s the string to split
 * @param delim the delimiter character
 * @param times maximum number of splits to perform; any negative value
 *        (default -1) means no limit. Once the limit is reached, the
 *        unsplit rest of the string, if non-empty, is appended as the
 *        final element.
 * @return the resulting pieces, in order
 */
inline std::vector<std::string> split(const std::string& s, char delim, int times = -1) {
    std::vector<std::string> v;
    std::stringstream ss(s);
    std::string item;

    while ((times > 0 || times <= -1) && std::getline(ss, item, delim)) {
        v.push_back(item);
        // Only count down a real limit; decrementing the "unlimited" marker
        // is pointless and could in principle overflow.
        if (times > 0) {
            times--;
        }
    }

    if (ss.peek() != EOF) {
        // Take the remainder straight from the input. Reading it with a
        // newline-delimited getline would silently drop everything after an
        // embedded '\n'.
        const std::streamoff consumed = ss.tellg();
        v.push_back(s.substr(static_cast<std::size_t>(consumed)));
    }

    return v;
}
|
159 |
|
160 | class ExplainProvenance {
|
161 | public:
|
162 | ExplainProvenance(SouffleProgram& prog) : prog(prog), symTable(prog.getSymbolTable()) {}
|
163 | virtual ~ExplainProvenance() = default;
|
164 |
|
165 | virtual void setup() = 0;
|
166 |
|
167 | virtual Own<TreeNode> explain(
|
168 | std::string relName, std::vector<std::string> tuple, std::size_t depthLimit) = 0;
|
169 |
|
170 | virtual Own<TreeNode> explainSubproof(std::string relName, RamDomain label, std::size_t depthLimit) = 0;
|
171 |
|
172 | virtual std::vector<std::string> explainNegationGetVariables(
|
173 | std::string relName, std::vector<std::string> args, std::size_t ruleNum) = 0;
|
174 |
|
175 | virtual Own<TreeNode> explainNegation(std::string relName, std::size_t ruleNum,
|
176 | const std::vector<std::string>& tuple, std::map<std::string, std::string>& bodyVariables) = 0;
|
177 |
|
178 | virtual std::string getRule(std::string relName, std::size_t ruleNum) = 0;
|
179 |
|
180 | virtual std::vector<std::string> getRules(const std::string& relName) = 0;
|
181 |
|
182 | virtual std::string measureRelation(std::string relName) = 0;
|
183 |
|
184 | virtual void printRulesJSON(std::ostream& os) = 0;
|
185 |
|
186 | /**
|
187 | * Process query with given arguments
|
188 | * @param rels, vector of relation, argument pairs
|
189 | * */
|
190 | virtual void queryProcess(const std::vector<std::pair<std::string, std::vector<std::string>>>& rels) = 0;
|
191 |
|
192 | protected:
|
193 | SouffleProgram& prog;
|
194 | SymbolTable& symTable;
|
195 |
|
196 | std::vector<RamDomain> argsToNums(
|
197 | const std::string& relName, const std::vector<std::string>& args) const {
|
198 | std::vector<RamDomain> nums;
|
199 |
|
200 | auto rel = prog.getRelation(relName);
|
201 | if (rel == nullptr) {
|
202 | return nums;
|
203 | }
|
204 |
|
205 | for (std::size_t i = 0; i < args.size(); i++) {
|
206 | nums.push_back(valueRead(rel->getAttrType(i)[0], args[i]));
|
207 | }
|
208 |
|
209 | return nums;
|
210 | }
|
211 |
|
212 | /**
|
213 | * Decode arguments from their ram representation and return as strings.
|
214 | **/
|
215 | std::vector<std::string> decodeArguments(
|
216 | const std::string& relName, const std::vector<RamDomain>& nums) const {
|
217 | std::vector<std::string> args;
|
218 |
|
219 | auto rel = prog.getRelation(relName);
|
220 | if (rel == nullptr) {
|
221 | return args;
|
222 | }
|
223 |
|
224 | for (std::size_t i = 0; i < nums.size(); i++) {
|
225 | args.push_back(valueShow(rel->getAttrType(i)[0], nums[i]));
|
226 | }
|
227 |
|
228 | return args;
|
229 | }
|
230 |
|
231 | std::string valueShow(const char type, const RamDomain value) const {
|
232 | switch (type) {
|
233 | case 'i': return tfm::format("%d", ramBitCast<RamSigned>(value));
|
234 | case 'u': return tfm::format("%d", ramBitCast<RamUnsigned>(value));
|
235 | case 'f': return tfm::format("%f", ramBitCast<RamFloat>(value));
|
236 | case 's': return tfm::format("\"%s\"", symTable.decode(value));
|
237 | case 'r': return tfm::format("record #%d", value);
|
238 | default: fatal("unhandled type attr code");
|
239 | }
|
240 | }
|
241 |
|
242 | RamDomain valueRead(const char type, const std::string& value) const {
|
243 | switch (type) {
|
244 | case 'i': return ramBitCast(RamSignedFromString(value));
|
245 | case 'u': return ramBitCast(RamUnsignedFromString(value));
|
246 | case 'f': return ramBitCast(RamFloatFromString(value));
|
247 | case 's':
|
248 | assert(2 <= value.size() && value[0] == '"' && value.back() == '"');
|
249 | return symTable.encode(value.substr(1, value.size() - 2));
|
250 | case 'r': fatal("not implemented");
|
251 | default: fatal("unhandled type attr code");
|
252 | }
|
253 | }
|
254 | };
|
255 |
|
256 | } // end of namespace souffle
|