1 | /*
|
2 | * Souffle - A Datalog Compiler
|
3 | * Copyright (c) 2021, The Souffle Developers. All rights reserved
|
4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
5 | * - https://opensource.org/licenses/UPL
|
6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
7 | */
|
8 |
|
9 | /************************************************************************
|
10 | *
|
11 | * @file ReadStream.h
|
12 | *
|
13 | ***********************************************************************/
|
14 |
|
15 | #pragma once
|
16 |
|
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/SerialisationStream.h"
#include "souffle/utility/ContainerUtil.h"
#include "souffle/utility/MiscUtil.h"
#include "souffle/utility/StringUtil.h"
#include "souffle/utility/json11.h"
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
|
33 |
|
34 | namespace souffle {
|
35 |
|
36 | class ReadStream : public SerialisationStream<false> {
|
37 | protected:
|
38 | ReadStream(
|
39 | const std::map<std::string, std::string>& rwOperation, SymbolTable& symTab, RecordTable& recTab)
|
40 | : SerialisationStream(symTab, recTab, rwOperation) {}
|
41 |
|
42 | public:
|
43 | template <typename T>
|
44 | void readAll(T& relation) {
|
45 | while (const auto next = readNextTuple()) {
|
46 | const RamDomain* ramDomain = next.get();
|
47 | relation.insert(ramDomain);
|
48 | }
|
49 | }
|
50 |
|
51 | protected:
|
52 | /**
|
53 | * Read a record from a string.
|
54 | *
|
55 | * @param source - string containing a record
|
56 | * @param recordTypeName - record type.
|
57 | * @parem pos - start parsing from this position.
|
58 | * @param consumed - if not nullptr: number of characters read.
|
59 | *
|
60 | */
|
61 | RamDomain readRecord(const std::string& source, const std::string& recordTypeName, std::size_t pos = 0,
|
62 | std::size_t* charactersRead = nullptr) {
|
63 | const std::size_t initial_position = pos;
|
64 |
|
65 | // Check if record type information are present
|
66 | auto&& recordInfo = types["records"][recordTypeName];
|
67 | if (recordInfo.is_null()) {
|
68 | throw std::invalid_argument("Missing record type information: " + recordTypeName);
|
69 | }
|
70 |
|
71 | // Handle nil case
|
72 | consumeWhiteSpace(source, pos);
|
73 | if (source.substr(pos, 3) == "nil") {
|
74 | if (charactersRead != nullptr) {
|
75 | *charactersRead = 3;
|
76 | }
|
77 | return 0;
|
78 | }
|
79 |
|
80 | auto&& recordTypes = recordInfo["types"];
|
81 | const std::size_t recordArity = recordInfo["arity"].long_value();
|
82 |
|
83 | std::vector<RamDomain> recordValues(recordArity);
|
84 |
|
85 | consumeChar(source, '[', pos);
|
86 |
|
87 | for (std::size_t i = 0; i < recordArity; ++i) {
|
88 | const std::string& recordType = recordTypes[i].string_value();
|
89 | std::size_t consumed = 0;
|
90 |
|
91 | if (i > 0) {
|
92 | consumeChar(source, ',', pos);
|
93 | }
|
94 | consumeWhiteSpace(source, pos);
|
95 | switch (recordType[0]) {
|
96 | case 's': {
|
97 | recordValues[i] = symbolTable.encode(readSymbol(source, ",]", pos, &consumed));
|
98 | break;
|
99 | }
|
100 | case 'i': {
|
101 | recordValues[i] = RamSignedFromString(source.substr(pos), &consumed);
|
102 | break;
|
103 | }
|
104 | case 'u': {
|
105 | recordValues[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
|
106 | break;
|
107 | }
|
108 | case 'f': {
|
109 | recordValues[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
|
110 | break;
|
111 | }
|
112 | case 'r': {
|
113 | recordValues[i] = readRecord(source, recordType, pos, &consumed);
|
114 | break;
|
115 | }
|
116 | case '+': {
|
117 | recordValues[i] = readADT(source, recordType, pos, &consumed);
|
118 | break;
|
119 | }
|
120 | default: fatal("Invalid type attribute");
|
121 | }
|
122 | pos += consumed;
|
123 | }
|
124 | consumeChar(source, ']', pos);
|
125 |
|
126 | if (charactersRead != nullptr) {
|
127 | *charactersRead = pos - initial_position;
|
128 | }
|
129 |
|
130 | return recordTable.pack(recordValues.data(), recordValues.size());
|
131 | }
|
132 |
|
133 | RamDomain readADT(const std::string& source, const std::string& adtName, std::size_t pos = 0,
|
134 | std::size_t* charactersRead = nullptr) {
|
135 | const std::size_t initial_position = pos;
|
136 |
|
137 | // Branch will are encoded as one of the:
|
138 | // [branchIdx, [branchValues...]]
|
139 | // [branchIdx, branchValue]
|
140 | // branchIdx
|
141 | RamDomain branchIdx = -1;
|
142 |
|
143 | auto&& adtInfo = types["ADTs"][adtName];
|
144 | const auto& branches = adtInfo["branches"];
|
145 |
|
146 | if (adtInfo.is_null() || !branches.is_array()) {
|
147 | throw std::invalid_argument("Missing ADT information: " + adtName);
|
148 | }
|
149 |
|
150 | // Consume initial character
|
151 | consumeChar(source, '$', pos);
|
152 | std::string constructor = readQualifiedName(source, pos);
|
153 |
|
154 | json11::Json branchInfo = [&]() -> json11::Json {
|
155 | for (auto branch : branches.array_items()) {
|
156 | ++branchIdx;
|
157 |
|
158 | if (branch["name"].string_value() == constructor) {
|
159 | return branch;
|
160 | }
|
161 | }
|
162 |
|
163 | throw std::invalid_argument("Missing branch information: " + constructor);
|
164 | }();
|
165 |
|
166 | assert(branchInfo["types"].is_array());
|
167 | auto branchTypes = branchInfo["types"].array_items();
|
168 |
|
169 | // Handle a branch without arguments.
|
170 | if (branchTypes.empty()) {
|
171 | if (charactersRead != nullptr) {
|
172 | *charactersRead = pos - initial_position;
|
173 | }
|
174 |
|
175 | if (adtInfo["enum"].bool_value()) {
|
176 | return branchIdx;
|
177 | }
|
178 |
|
179 | RamDomain emptyArgs = recordTable.pack(toVector<RamDomain>().data(), 0);
|
180 | const RamDomain record[] = {branchIdx, emptyArgs};
|
181 | return recordTable.pack(record, 2);
|
182 | }
|
183 |
|
184 | consumeChar(source, '(', pos);
|
185 |
|
186 | std::vector<RamDomain> branchArgs(branchTypes.size());
|
187 |
|
188 | for (std::size_t i = 0; i < branchTypes.size(); ++i) {
|
189 | auto argType = branchTypes[i].string_value();
|
190 | assert(!argType.empty());
|
191 |
|
192 | std::size_t consumed = 0;
|
193 |
|
194 | if (i > 0) {
|
195 | consumeChar(source, ',', pos);
|
196 | }
|
197 | consumeWhiteSpace(source, pos);
|
198 |
|
199 | switch (argType[0]) {
|
200 | case 's': {
|
201 | branchArgs[i] = symbolTable.encode(readSymbol(source, ",)", pos, &consumed));
|
202 | break;
|
203 | }
|
204 | case 'i': {
|
205 | branchArgs[i] = RamSignedFromString(source.substr(pos), &consumed);
|
206 | break;
|
207 | }
|
208 | case 'u': {
|
209 | branchArgs[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
|
210 | break;
|
211 | }
|
212 | case 'f': {
|
213 | branchArgs[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
|
214 | break;
|
215 | }
|
216 | case 'r': {
|
217 | branchArgs[i] = readRecord(source, argType, pos, &consumed);
|
218 | break;
|
219 | }
|
220 | case '+': {
|
221 | branchArgs[i] = readADT(source, argType, pos, &consumed);
|
222 | break;
|
223 | }
|
224 | default: fatal("Invalid type attribute");
|
225 | }
|
226 | pos += consumed;
|
227 | }
|
228 |
|
229 | consumeChar(source, ')', pos);
|
230 |
|
231 | if (charactersRead != nullptr) {
|
232 | *charactersRead = pos - initial_position;
|
233 | }
|
234 |
|
235 | // Store branch either as [branch_id, [arguments]] or [branch_id, argument].
|
236 | RamDomain branchValue = [&]() -> RamDomain {
|
237 | if (branchArgs.size() != 1) {
|
238 | return recordTable.pack(branchArgs.data(), branchArgs.size());
|
239 | } else {
|
240 | return branchArgs[0];
|
241 | }
|
242 | }();
|
243 |
|
244 | RamDomain rec[2] = {branchIdx, branchValue};
|
245 | return recordTable.pack(rec, 2);
|
246 | }
|
247 |
|
248 | /**
|
249 | * Read the next alphanumeric + ('_', '?') sequence (corresponding to IDENT).
|
250 | * Consume preceding whitespace.
|
251 | * TODO (darth_tytus): use std::string_view?
|
252 | */
|
253 | std::string readQualifiedName(const std::string& source, std::size_t& pos) {
|
254 | consumeWhiteSpace(source, pos);
|
255 | if (pos >= source.length()) {
|
256 | throw std::invalid_argument("Unexpected end of input");
|
257 | }
|
258 |
|
259 | const std::size_t bgn = pos;
|
260 | while (pos < source.length()) {
|
261 | unsigned char ch = static_cast<unsigned char>(source[pos]);
|
262 | bool valid = std::isalnum(ch) || ch == '_' || ch == '?' || ch == '.';
|
263 | if (!valid) break;
|
264 | ++pos;
|
265 | }
|
266 |
|
267 | return source.substr(bgn, pos - bgn);
|
268 | }
|
269 |
|
270 | std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos,
|
271 | std::size_t* charactersRead) {
|
272 | std::size_t endOfSymbol = source.find_first_of(stopChars, pos);
|
273 |
|
274 | if (endOfSymbol == std::string::npos) {
|
275 | throw std::invalid_argument("Unexpected end of input");
|
276 | }
|
277 |
|
278 | *charactersRead = endOfSymbol - pos;
|
279 |
|
280 | return source.substr(pos, *charactersRead);
|
281 | }
|
282 |
|
283 | std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) {
|
284 | const std::size_t start = pos;
|
285 | const std::size_t end = source.length();
|
286 |
|
287 | const char quoteMark = source[pos];
|
288 | ++pos;
|
289 |
|
290 | const std::size_t startOfSymbol = pos;
|
291 | std::size_t endOfSymbol = std::string::npos;
|
292 | bool hasEscaped = false;
|
293 |
|
294 | bool escaped = false;
|
295 | while (pos < end) {
|
296 | if (escaped) {
|
297 | hasEscaped = true;
|
298 | escaped = false;
|
299 | ++pos;
|
300 | continue;
|
301 | }
|
302 |
|
303 | const char c = source[pos];
|
304 | if (c == quoteMark) {
|
305 | endOfSymbol = pos;
|
306 | ++pos;
|
307 | break;
|
308 | }
|
309 | if (c == '\\') {
|
310 | escaped = true;
|
311 | }
|
312 | ++pos;
|
313 | }
|
314 |
|
315 | if (endOfSymbol == std::string::npos) {
|
316 | throw std::invalid_argument("Unexpected end of input");
|
317 | }
|
318 |
|
319 | *charactersRead = pos - start;
|
320 |
|
321 | std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol;
|
322 |
|
323 | // fast handling of symbol without escape sequence
|
324 | if (!hasEscaped) {
|
325 | return source.substr(startOfSymbol, lengthOfSymbol);
|
326 | } else {
|
327 | // slow handling of symbol with escape sequence
|
328 | std::string symbol;
|
329 | symbol.reserve(lengthOfSymbol);
|
330 | bool escaped = false;
|
331 | for (std::size_t pos = startOfSymbol; pos < endOfSymbol; ++pos) {
|
332 | char ch = source[pos];
|
333 | if (escaped || ch != '\\') {
|
334 | symbol.push_back(ch);
|
335 | escaped = false;
|
336 | } else {
|
337 | escaped = true;
|
338 | }
|
339 | }
|
340 | return symbol;
|
341 | }
|
342 | }
|
343 |
|
344 | /**
|
345 | * Read the next symbol.
|
346 | * It is either a double-quoted symbol with backslash-escaped chars, or the
|
347 | * longuest sequence that do not contains any of the given stopChars.
|
348 | * */
|
349 | std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos,
|
350 | std::size_t* charactersRead) {
|
351 | if (source[pos] == '"') {
|
352 | return readQuotedSymbol(source, pos, charactersRead);
|
353 | } else {
|
354 | return readUntil(source, stopChars, pos, charactersRead);
|
355 | }
|
356 | }
|
357 |
|
358 | /**
|
359 | * Read past given character, consuming any preceding whitespace.
|
360 | */
|
361 | void consumeChar(const std::string& str, char c, std::size_t& pos) {
|
362 | consumeWhiteSpace(str, pos);
|
363 | if (pos >= str.length()) {
|
364 | throw std::invalid_argument("Unexpected end of input");
|
365 | }
|
366 | if (str[pos] != c) {
|
367 | std::stringstream error;
|
368 | error << "Expected: \'" << c << "\', got: " << str[pos];
|
369 | throw std::invalid_argument(error.str());
|
370 | }
|
371 | ++pos;
|
372 | }
|
373 |
|
374 | /**
|
375 | * Advance position in the string until first non-whitespace character.
|
376 | */
|
377 | void consumeWhiteSpace(const std::string& str, std::size_t& pos) {
|
378 | while (pos < str.length() && std::isspace(static_cast<unsigned char>(str[pos]))) {
|
379 | ++pos;
|
380 | }
|
381 | }
|
382 |
|
383 | virtual Own<RamDomain[]> readNextTuple() = 0;
|
384 | };
|
385 |
|
386 | class ReadStreamFactory {
|
387 | public:
|
388 | virtual Own<ReadStream> getReader(
|
389 | const std::map<std::string, std::string>&, SymbolTable&, RecordTable&) = 0;
|
390 | virtual const std::string& getName() const = 0;
|
391 | virtual ~ReadStreamFactory() = default;
|
392 | };
|
393 |
|
394 | } /* namespace souffle */
|