OILS / vendor / souffle / provenance / ExplainProvenance.h View on Github | oilshell.org

256 lines, 133 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2017, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file ExplainProvenance.h
12 *
13 * Abstract class for implementing an instance of the explain interface for provenance
14 *
15 ***********************************************************************/
16
17#pragma once
18
19#include "souffle/RamTypes.h"
20#include "souffle/SouffleInterface.h"
21#include "souffle/SymbolTable.h"
22#include "souffle/utility/MiscUtil.h"
23#include "souffle/utility/StringUtil.h"
24#include "souffle/utility/tinyformat.h"
25#include <algorithm>
26#include <cassert>
27#include <cstdio>
28#include <map>
29#include <memory>
30#include <sstream>
31#include <string>
32#include <utility>
33#include <vector>
34
35namespace souffle {
36class TreeNode;
37
38/** Equivalence class for variables in query command */
39class Equivalence {
40public:
41 /** Destructor */
42 ~Equivalence() = default;
43
44 /**
45 * Constructor for Equvialence class
46 * @param t, type of the variable
47 * @param s, symbol of the variable
48 * @param idx, first occurence of the variable
49 * */
50 Equivalence(char t, std::string s, std::pair<std::size_t, std::size_t> idx)
51 : type(t), symbol(std::move(s)) {
52 indices.push_back(idx);
53 }
54
55 /** Copy constructor */
56 Equivalence(const Equivalence& o) = default;
57
58 /** Copy assignment operator */
59 Equivalence& operator=(const Equivalence& o) = default;
60
61 /** Add index at the end of indices vector */
62 void push_back(std::pair<std::size_t, std::size_t> idx) {
63 indices.push_back(idx);
64 }
65
66 /** Verify if elements at the indices are equivalent in the given product */
67 bool verify(const std::vector<tuple>& product) const {
68 for (std::size_t i = 1; i < indices.size(); ++i) {
69 if (product[indices[i].first][indices[i].second] !=
70 product[indices[i - 1].first][indices[i - 1].second]) {
71 return false;
72 }
73 }
74 return true;
75 }
76
77 /** Extract index of the first occurrence of the varible */
78 const std::pair<std::size_t, std::size_t>& getFirstIdx() const {
79 return indices[0];
80 }
81
82 /** Get indices of equivalent variables */
83 const std::vector<std::pair<std::size_t, std::size_t>>& getIndices() const {
84 return indices;
85 }
86
87 /** Get type of the variable of the equivalence class,
88 * 'i' for RamSigned, 's' for symbol
89 * 'u' for RamUnsigned, 'f' for RamFloat
90 */
91 char getType() const {
92 return type;
93 }
94
95 /** Get the symbol of variable */
96 const std::string& getSymbol() const {
97 return symbol;
98 }
99
100private:
101 char type;
102 std::string symbol;
103 std::vector<std::pair<std::size_t, std::size_t>> indices;
104};
105
106/** Constant constraints for values in query command */
107class ConstConstraint {
108public:
109 /** Constructor */
110 ConstConstraint() = default;
111
112 /** Destructor */
113 ~ConstConstraint() = default;
114
115 /** Add constant constraint at the end of constConstrs vector */
116 void push_back(std::pair<std::pair<std::size_t, std::size_t>, RamDomain> constr) {
117 constConstrs.push_back(constr);
118 }
119
120 /** Verify if the query product satisfies constant constraint */
121 bool verify(const std::vector<tuple>& product) const {
122 return std::all_of(constConstrs.begin(), constConstrs.end(), [&product](auto constr) {
123 return product[constr.first.first][constr.first.second] == constr.second;
124 });
125 }
126
127 /** Get the constant constraint vector */
128 std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>>& getConstraints() {
129 return constConstrs;
130 }
131
132 const std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>>& getConstraints() const {
133 return constConstrs;
134 }
135
136private:
137 std::vector<std::pair<std::pair<std::size_t, std::size_t>, RamDomain>> constConstrs;
138};
139
/**
 * Utility function to split a string at a delimiter.
 *
 * @param s, the string to split
 * @param delim, the delimiter character; it is not included in the fields
 * @param times, maximum number of fields to cut off the front of the string;
 *        any negative value (the default) means "no limit"
 * @return the extracted fields in order. If the limit was reached before the
 *         end of the string, the whole unsplit rest of the string is appended
 *         as one final element. Empty fields between delimiters are kept; a
 *         delimiter at the very end of the string does not produce an extra
 *         empty field.
 */
inline std::vector<std::string> split(const std::string& s, char delim, int times = -1) {
    std::vector<std::string> v;
    const bool unlimited = times < 0;
    std::size_t pos = 0;

    while ((unlimited || times > 0) && pos < s.size()) {
        const std::size_t end = s.find(delim, pos);
        if (end == std::string::npos) {
            // no further delimiter: the rest of the string is the last field
            v.push_back(s.substr(pos));
            pos = s.size();
        } else {
            v.push_back(s.substr(pos, end - pos));
            pos = end + 1;  // skip the delimiter itself
        }
        if (!unlimited) {
            --times;
        }
    }

    // The limit was hit with input left over: keep the complete remainder.
    // (It must not be cut at an embedded newline.)
    if (pos < s.size()) {
        v.push_back(s.substr(pos));
    }

    return v;
}
159
160class ExplainProvenance {
161public:
162 ExplainProvenance(SouffleProgram& prog) : prog(prog), symTable(prog.getSymbolTable()) {}
163 virtual ~ExplainProvenance() = default;
164
165 virtual void setup() = 0;
166
167 virtual Own<TreeNode> explain(
168 std::string relName, std::vector<std::string> tuple, std::size_t depthLimit) = 0;
169
170 virtual Own<TreeNode> explainSubproof(std::string relName, RamDomain label, std::size_t depthLimit) = 0;
171
172 virtual std::vector<std::string> explainNegationGetVariables(
173 std::string relName, std::vector<std::string> args, std::size_t ruleNum) = 0;
174
175 virtual Own<TreeNode> explainNegation(std::string relName, std::size_t ruleNum,
176 const std::vector<std::string>& tuple, std::map<std::string, std::string>& bodyVariables) = 0;
177
178 virtual std::string getRule(std::string relName, std::size_t ruleNum) = 0;
179
180 virtual std::vector<std::string> getRules(const std::string& relName) = 0;
181
182 virtual std::string measureRelation(std::string relName) = 0;
183
184 virtual void printRulesJSON(std::ostream& os) = 0;
185
186 /**
187 * Process query with given arguments
188 * @param rels, vector of relation, argument pairs
189 * */
190 virtual void queryProcess(const std::vector<std::pair<std::string, std::vector<std::string>>>& rels) = 0;
191
192protected:
193 SouffleProgram& prog;
194 SymbolTable& symTable;
195
196 std::vector<RamDomain> argsToNums(
197 const std::string& relName, const std::vector<std::string>& args) const {
198 std::vector<RamDomain> nums;
199
200 auto rel = prog.getRelation(relName);
201 if (rel == nullptr) {
202 return nums;
203 }
204
205 for (std::size_t i = 0; i < args.size(); i++) {
206 nums.push_back(valueRead(rel->getAttrType(i)[0], args[i]));
207 }
208
209 return nums;
210 }
211
212 /**
213 * Decode arguments from their ram representation and return as strings.
214 **/
215 std::vector<std::string> decodeArguments(
216 const std::string& relName, const std::vector<RamDomain>& nums) const {
217 std::vector<std::string> args;
218
219 auto rel = prog.getRelation(relName);
220 if (rel == nullptr) {
221 return args;
222 }
223
224 for (std::size_t i = 0; i < nums.size(); i++) {
225 args.push_back(valueShow(rel->getAttrType(i)[0], nums[i]));
226 }
227
228 return args;
229 }
230
231 std::string valueShow(const char type, const RamDomain value) const {
232 switch (type) {
233 case 'i': return tfm::format("%d", ramBitCast<RamSigned>(value));
234 case 'u': return tfm::format("%d", ramBitCast<RamUnsigned>(value));
235 case 'f': return tfm::format("%f", ramBitCast<RamFloat>(value));
236 case 's': return tfm::format("\"%s\"", symTable.decode(value));
237 case 'r': return tfm::format("record #%d", value);
238 default: fatal("unhandled type attr code");
239 }
240 }
241
242 RamDomain valueRead(const char type, const std::string& value) const {
243 switch (type) {
244 case 'i': return ramBitCast(RamSignedFromString(value));
245 case 'u': return ramBitCast(RamUnsignedFromString(value));
246 case 'f': return ramBitCast(RamFloatFromString(value));
247 case 's':
248 assert(2 <= value.size() && value[0] == '"' && value.back() == '"');
249 return symTable.encode(value.substr(1, value.size() - 2));
250 case 'r': fatal("not implemented");
251 default: fatal("unhandled type attr code");
252 }
253 }
254};
255
256} // end of namespace souffle