OILS / vendor / souffle / io / ReadStream.h View on Github | oilshell.org

394 lines, 230 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2021, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file ReadStream.h
12 *
13 ***********************************************************************/
14
15#pragma once
16
17#include "souffle/RamTypes.h"
18#include "souffle/RecordTable.h"
19#include "souffle/SymbolTable.h"
20#include "souffle/io/SerialisationStream.h"
21#include "souffle/utility/ContainerUtil.h"
22#include "souffle/utility/MiscUtil.h"
23#include "souffle/utility/StringUtil.h"
24#include "souffle/utility/json11.h"
25#include <cctype>
26#include <cstddef>
27#include <map>
28#include <memory>
29#include <ostream>
30#include <stdexcept>
31#include <string>
32#include <vector>
33
34namespace souffle {
35
36class ReadStream : public SerialisationStream<false> {
37protected:
38 ReadStream(
39 const std::map<std::string, std::string>& rwOperation, SymbolTable& symTab, RecordTable& recTab)
40 : SerialisationStream(symTab, recTab, rwOperation) {}
41
42public:
43 template <typename T>
44 void readAll(T& relation) {
45 while (const auto next = readNextTuple()) {
46 const RamDomain* ramDomain = next.get();
47 relation.insert(ramDomain);
48 }
49 }
50
51protected:
52 /**
53 * Read a record from a string.
54 *
55 * @param source - string containing a record
56 * @param recordTypeName - record type.
57 * @parem pos - start parsing from this position.
58 * @param consumed - if not nullptr: number of characters read.
59 *
60 */
61 RamDomain readRecord(const std::string& source, const std::string& recordTypeName, std::size_t pos = 0,
62 std::size_t* charactersRead = nullptr) {
63 const std::size_t initial_position = pos;
64
65 // Check if record type information are present
66 auto&& recordInfo = types["records"][recordTypeName];
67 if (recordInfo.is_null()) {
68 throw std::invalid_argument("Missing record type information: " + recordTypeName);
69 }
70
71 // Handle nil case
72 consumeWhiteSpace(source, pos);
73 if (source.substr(pos, 3) == "nil") {
74 if (charactersRead != nullptr) {
75 *charactersRead = 3;
76 }
77 return 0;
78 }
79
80 auto&& recordTypes = recordInfo["types"];
81 const std::size_t recordArity = recordInfo["arity"].long_value();
82
83 std::vector<RamDomain> recordValues(recordArity);
84
85 consumeChar(source, '[', pos);
86
87 for (std::size_t i = 0; i < recordArity; ++i) {
88 const std::string& recordType = recordTypes[i].string_value();
89 std::size_t consumed = 0;
90
91 if (i > 0) {
92 consumeChar(source, ',', pos);
93 }
94 consumeWhiteSpace(source, pos);
95 switch (recordType[0]) {
96 case 's': {
97 recordValues[i] = symbolTable.encode(readSymbol(source, ",]", pos, &consumed));
98 break;
99 }
100 case 'i': {
101 recordValues[i] = RamSignedFromString(source.substr(pos), &consumed);
102 break;
103 }
104 case 'u': {
105 recordValues[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
106 break;
107 }
108 case 'f': {
109 recordValues[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
110 break;
111 }
112 case 'r': {
113 recordValues[i] = readRecord(source, recordType, pos, &consumed);
114 break;
115 }
116 case '+': {
117 recordValues[i] = readADT(source, recordType, pos, &consumed);
118 break;
119 }
120 default: fatal("Invalid type attribute");
121 }
122 pos += consumed;
123 }
124 consumeChar(source, ']', pos);
125
126 if (charactersRead != nullptr) {
127 *charactersRead = pos - initial_position;
128 }
129
130 return recordTable.pack(recordValues.data(), recordValues.size());
131 }
132
133 RamDomain readADT(const std::string& source, const std::string& adtName, std::size_t pos = 0,
134 std::size_t* charactersRead = nullptr) {
135 const std::size_t initial_position = pos;
136
137 // Branch will are encoded as one of the:
138 // [branchIdx, [branchValues...]]
139 // [branchIdx, branchValue]
140 // branchIdx
141 RamDomain branchIdx = -1;
142
143 auto&& adtInfo = types["ADTs"][adtName];
144 const auto& branches = adtInfo["branches"];
145
146 if (adtInfo.is_null() || !branches.is_array()) {
147 throw std::invalid_argument("Missing ADT information: " + adtName);
148 }
149
150 // Consume initial character
151 consumeChar(source, '$', pos);
152 std::string constructor = readQualifiedName(source, pos);
153
154 json11::Json branchInfo = [&]() -> json11::Json {
155 for (auto branch : branches.array_items()) {
156 ++branchIdx;
157
158 if (branch["name"].string_value() == constructor) {
159 return branch;
160 }
161 }
162
163 throw std::invalid_argument("Missing branch information: " + constructor);
164 }();
165
166 assert(branchInfo["types"].is_array());
167 auto branchTypes = branchInfo["types"].array_items();
168
169 // Handle a branch without arguments.
170 if (branchTypes.empty()) {
171 if (charactersRead != nullptr) {
172 *charactersRead = pos - initial_position;
173 }
174
175 if (adtInfo["enum"].bool_value()) {
176 return branchIdx;
177 }
178
179 RamDomain emptyArgs = recordTable.pack(toVector<RamDomain>().data(), 0);
180 const RamDomain record[] = {branchIdx, emptyArgs};
181 return recordTable.pack(record, 2);
182 }
183
184 consumeChar(source, '(', pos);
185
186 std::vector<RamDomain> branchArgs(branchTypes.size());
187
188 for (std::size_t i = 0; i < branchTypes.size(); ++i) {
189 auto argType = branchTypes[i].string_value();
190 assert(!argType.empty());
191
192 std::size_t consumed = 0;
193
194 if (i > 0) {
195 consumeChar(source, ',', pos);
196 }
197 consumeWhiteSpace(source, pos);
198
199 switch (argType[0]) {
200 case 's': {
201 branchArgs[i] = symbolTable.encode(readSymbol(source, ",)", pos, &consumed));
202 break;
203 }
204 case 'i': {
205 branchArgs[i] = RamSignedFromString(source.substr(pos), &consumed);
206 break;
207 }
208 case 'u': {
209 branchArgs[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
210 break;
211 }
212 case 'f': {
213 branchArgs[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
214 break;
215 }
216 case 'r': {
217 branchArgs[i] = readRecord(source, argType, pos, &consumed);
218 break;
219 }
220 case '+': {
221 branchArgs[i] = readADT(source, argType, pos, &consumed);
222 break;
223 }
224 default: fatal("Invalid type attribute");
225 }
226 pos += consumed;
227 }
228
229 consumeChar(source, ')', pos);
230
231 if (charactersRead != nullptr) {
232 *charactersRead = pos - initial_position;
233 }
234
235 // Store branch either as [branch_id, [arguments]] or [branch_id, argument].
236 RamDomain branchValue = [&]() -> RamDomain {
237 if (branchArgs.size() != 1) {
238 return recordTable.pack(branchArgs.data(), branchArgs.size());
239 } else {
240 return branchArgs[0];
241 }
242 }();
243
244 RamDomain rec[2] = {branchIdx, branchValue};
245 return recordTable.pack(rec, 2);
246 }
247
248 /**
249 * Read the next alphanumeric + ('_', '?') sequence (corresponding to IDENT).
250 * Consume preceding whitespace.
251 * TODO (darth_tytus): use std::string_view?
252 */
253 std::string readQualifiedName(const std::string& source, std::size_t& pos) {
254 consumeWhiteSpace(source, pos);
255 if (pos >= source.length()) {
256 throw std::invalid_argument("Unexpected end of input");
257 }
258
259 const std::size_t bgn = pos;
260 while (pos < source.length()) {
261 unsigned char ch = static_cast<unsigned char>(source[pos]);
262 bool valid = std::isalnum(ch) || ch == '_' || ch == '?' || ch == '.';
263 if (!valid) break;
264 ++pos;
265 }
266
267 return source.substr(bgn, pos - bgn);
268 }
269
/**
 * Read the characters from `pos` up to, but not including, the first
 * occurrence of any character of `stopChars`.
 *
 * @param source - string to read from
 * @param stopChars - set of characters that terminate the read
 * @param pos - start reading from this position.
 * @param charactersRead - if not nullptr: receives the length of the returned
 *        string (the stop character itself is not counted).
 * @throws std::invalid_argument if no stop character occurs at or after `pos`.
 */
std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos,
        std::size_t* charactersRead) {
    const std::size_t endOfSymbol = source.find_first_of(stopChars, pos);

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    const std::size_t length = endOfSymbol - pos;
    // The count is optional, as it is for readRecord() and readADT().
    if (charactersRead != nullptr) {
        *charactersRead = length;
    }

    return source.substr(pos, length);
}
282
/**
 * Read a quoted symbol whose opening quote is the character at `pos`.
 *
 * The symbol ends at the next unescaped occurrence of that same quote
 * character. Inside the symbol a backslash escapes the character that follows
 * it: the backslash is dropped and the following character is kept verbatim.
 *
 * @param source - string to read from
 * @param pos - position of the opening quote.
 * @param charactersRead - if not nullptr: receives the number of characters
 *        consumed, both quotes included.
 * @return the symbol without its quotes and with escapes resolved.
 * @throws std::invalid_argument if the closing quote is missing.
 */
std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) {
    const std::size_t openingQuote = pos;
    const std::size_t length = source.length();
    const char quoteMark = source[openingQuote];
    const std::size_t startOfSymbol = openingQuote + 1;

    // Locate the closing quote, remembering whether any escape was seen.
    std::size_t endOfSymbol = std::string::npos;
    bool hasEscaped = false;
    for (std::size_t cursor = startOfSymbol; cursor < length; ++cursor) {
        const char c = source[cursor];
        if (c == quoteMark) {
            endOfSymbol = cursor;
            break;
        }
        if (c == '\\') {
            // Step over the escaped character so that it cannot end the symbol.
            hasEscaped = true;
            ++cursor;
        }
    }

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    // The count is optional, as it is for readRecord() and readADT().
    if (charactersRead != nullptr) {
        *charactersRead = (endOfSymbol + 1) - openingQuote;
    }

    const std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol;

    // fast handling of symbol without escape sequence
    if (!hasEscaped) {
        return source.substr(startOfSymbol, lengthOfSymbol);
    }

    // slow handling of symbol with escape sequence
    std::string symbol;
    symbol.reserve(lengthOfSymbol);
    for (std::size_t i = startOfSymbol; i < endOfSymbol; ++i) {
        if (source[i] == '\\') {
            // Drop the backslash and keep the character it escapes.
            ++i;
            if (i == endOfSymbol) {
                break;
            }
        }
        symbol.push_back(source[i]);
    }
    return symbol;
}
343
344 /**
345 * Read the next symbol.
346 * It is either a double-quoted symbol with backslash-escaped chars, or the
347 * longuest sequence that do not contains any of the given stopChars.
348 * */
349 std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos,
350 std::size_t* charactersRead) {
351 if (source[pos] == '"') {
352 return readQuotedSymbol(source, pos, charactersRead);
353 } else {
354 return readUntil(source, stopChars, pos, charactersRead);
355 }
356 }
357
358 /**
359 * Read past given character, consuming any preceding whitespace.
360 */
361 void consumeChar(const std::string& str, char c, std::size_t& pos) {
362 consumeWhiteSpace(str, pos);
363 if (pos >= str.length()) {
364 throw std::invalid_argument("Unexpected end of input");
365 }
366 if (str[pos] != c) {
367 std::stringstream error;
368 error << "Expected: \'" << c << "\', got: " << str[pos];
369 throw std::invalid_argument(error.str());
370 }
371 ++pos;
372 }
373
/**
 * Advance position in the string until first non-whitespace character.
 *
 * `pos` is left untouched if it is already at or past the end of `str`, and is
 * set to the length of `str` if only whitespace follows.
 */
void consumeWhiteSpace(const std::string& str, std::size_t& pos) {
    const std::size_t limit = str.length();
    for (; pos < limit; ++pos) {
        // The cast keeps std::isspace away from negative char values.
        if (!std::isspace(static_cast<unsigned char>(str[pos]))) {
            return;
        }
    }
}
382
/**
 * Produce the next tuple of the input; to be implemented by derived readers.
 *
 * readAll() calls this repeatedly, inserts every non-null result into the
 * relation and stops at the first result that converts to false.
 */
383 virtual Own<RamDomain[]> readNextTuple() = 0;
384};
385
386class ReadStreamFactory {
387public:
388 virtual Own<ReadStream> getReader(
389 const std::map<std::string, std::string>&, SymbolTable&, RecordTable&) = 0;
390 virtual const std::string& getName() const = 0;
391 virtual ~ReadStreamFactory() = default;
392};
393
394} /* namespace souffle */