| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file ReadStream.h
 | 
| 12 |  *
 | 
| 13 |  ***********************************************************************/
 | 
| 14 | 
 | 
| 15 | #pragma once
 | 
| 16 | 
 | 
| 17 | #include "souffle/RamTypes.h"
 | 
| 18 | #include "souffle/RecordTable.h"
 | 
| 19 | #include "souffle/SymbolTable.h"
 | 
| 20 | #include "souffle/io/SerialisationStream.h"
 | 
| 21 | #include "souffle/utility/ContainerUtil.h"
 | 
| 22 | #include "souffle/utility/MiscUtil.h"
 | 
| 23 | #include "souffle/utility/StringUtil.h"
 | 
| 24 | #include "souffle/utility/json11.h"
 | 
| 25 | #include <cctype>
 | 
| 26 | #include <cstddef>
 | 
| 27 | #include <map>
 | 
| 28 | #include <memory>
 | 
| 29 | #include <ostream>
 | 
| 30 | #include <stdexcept>
 | 
| 31 | #include <string>
 | 
| 32 | #include <vector>
 | 
| 33 | 
 | 
| 34 | namespace souffle {
 | 
| 35 | 
 | 
| 36 | class ReadStream : public SerialisationStream<false> {
 | 
| 37 | protected:
 | 
| 38 |     ReadStream(
 | 
| 39 |             const std::map<std::string, std::string>& rwOperation, SymbolTable& symTab, RecordTable& recTab)
 | 
| 40 |             : SerialisationStream(symTab, recTab, rwOperation) {}
 | 
| 41 | 
 | 
| 42 | public:
 | 
| 43 |     template <typename T>
 | 
| 44 |     void readAll(T& relation) {
 | 
| 45 |         while (const auto next = readNextTuple()) {
 | 
| 46 |             const RamDomain* ramDomain = next.get();
 | 
| 47 |             relation.insert(ramDomain);
 | 
| 48 |         }
 | 
| 49 |     }
 | 
| 50 | 
 | 
| 51 | protected:
 | 
| 52 |     /**
 | 
| 53 |      * Read a record from a string.
 | 
| 54 |      *
 | 
| 55 |      * @param source - string containing a record
 | 
| 56 |      * @param recordTypeName - record type.
 | 
| 57 |      * @parem pos - start parsing from this position.
 | 
| 58 |      * @param consumed - if not nullptr: number of characters read.
 | 
| 59 |      *
 | 
| 60 |      */
 | 
| 61 |     RamDomain readRecord(const std::string& source, const std::string& recordTypeName, std::size_t pos = 0,
 | 
| 62 |             std::size_t* charactersRead = nullptr) {
 | 
| 63 |         const std::size_t initial_position = pos;
 | 
| 64 | 
 | 
| 65 |         // Check if record type information are present
 | 
| 66 |         auto&& recordInfo = types["records"][recordTypeName];
 | 
| 67 |         if (recordInfo.is_null()) {
 | 
| 68 |             throw std::invalid_argument("Missing record type information: " + recordTypeName);
 | 
| 69 |         }
 | 
| 70 | 
 | 
| 71 |         // Handle nil case
 | 
| 72 |         consumeWhiteSpace(source, pos);
 | 
| 73 |         if (source.substr(pos, 3) == "nil") {
 | 
| 74 |             if (charactersRead != nullptr) {
 | 
| 75 |                 *charactersRead = 3;
 | 
| 76 |             }
 | 
| 77 |             return 0;
 | 
| 78 |         }
 | 
| 79 | 
 | 
| 80 |         auto&& recordTypes = recordInfo["types"];
 | 
| 81 |         const std::size_t recordArity = recordInfo["arity"].long_value();
 | 
| 82 | 
 | 
| 83 |         std::vector<RamDomain> recordValues(recordArity);
 | 
| 84 | 
 | 
| 85 |         consumeChar(source, '[', pos);
 | 
| 86 | 
 | 
| 87 |         for (std::size_t i = 0; i < recordArity; ++i) {
 | 
| 88 |             const std::string& recordType = recordTypes[i].string_value();
 | 
| 89 |             std::size_t consumed = 0;
 | 
| 90 | 
 | 
| 91 |             if (i > 0) {
 | 
| 92 |                 consumeChar(source, ',', pos);
 | 
| 93 |             }
 | 
| 94 |             consumeWhiteSpace(source, pos);
 | 
| 95 |             switch (recordType[0]) {
 | 
| 96 |                 case 's': {
 | 
| 97 |                     recordValues[i] = symbolTable.encode(readSymbol(source, ",]", pos, &consumed));
 | 
| 98 |                     break;
 | 
| 99 |                 }
 | 
| 100 |                 case 'i': {
 | 
| 101 |                     recordValues[i] = RamSignedFromString(source.substr(pos), &consumed);
 | 
| 102 |                     break;
 | 
| 103 |                 }
 | 
| 104 |                 case 'u': {
 | 
| 105 |                     recordValues[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
 | 
| 106 |                     break;
 | 
| 107 |                 }
 | 
| 108 |                 case 'f': {
 | 
| 109 |                     recordValues[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
 | 
| 110 |                     break;
 | 
| 111 |                 }
 | 
| 112 |                 case 'r': {
 | 
| 113 |                     recordValues[i] = readRecord(source, recordType, pos, &consumed);
 | 
| 114 |                     break;
 | 
| 115 |                 }
 | 
| 116 |                 case '+': {
 | 
| 117 |                     recordValues[i] = readADT(source, recordType, pos, &consumed);
 | 
| 118 |                     break;
 | 
| 119 |                 }
 | 
| 120 |                 default: fatal("Invalid type attribute");
 | 
| 121 |             }
 | 
| 122 |             pos += consumed;
 | 
| 123 |         }
 | 
| 124 |         consumeChar(source, ']', pos);
 | 
| 125 | 
 | 
| 126 |         if (charactersRead != nullptr) {
 | 
| 127 |             *charactersRead = pos - initial_position;
 | 
| 128 |         }
 | 
| 129 | 
 | 
| 130 |         return recordTable.pack(recordValues.data(), recordValues.size());
 | 
| 131 |     }
 | 
| 132 | 
 | 
| 133 |     RamDomain readADT(const std::string& source, const std::string& adtName, std::size_t pos = 0,
 | 
| 134 |             std::size_t* charactersRead = nullptr) {
 | 
| 135 |         const std::size_t initial_position = pos;
 | 
| 136 | 
 | 
| 137 |         // Branch will are encoded as one of the:
 | 
| 138 |         // [branchIdx, [branchValues...]]
 | 
| 139 |         // [branchIdx, branchValue]
 | 
| 140 |         // branchIdx
 | 
| 141 |         RamDomain branchIdx = -1;
 | 
| 142 | 
 | 
| 143 |         auto&& adtInfo = types["ADTs"][adtName];
 | 
| 144 |         const auto& branches = adtInfo["branches"];
 | 
| 145 | 
 | 
| 146 |         if (adtInfo.is_null() || !branches.is_array()) {
 | 
| 147 |             throw std::invalid_argument("Missing ADT information: " + adtName);
 | 
| 148 |         }
 | 
| 149 | 
 | 
| 150 |         // Consume initial character
 | 
| 151 |         consumeChar(source, '$', pos);
 | 
| 152 |         std::string constructor = readQualifiedName(source, pos);
 | 
| 153 | 
 | 
| 154 |         json11::Json branchInfo = [&]() -> json11::Json {
 | 
| 155 |             for (auto branch : branches.array_items()) {
 | 
| 156 |                 ++branchIdx;
 | 
| 157 | 
 | 
| 158 |                 if (branch["name"].string_value() == constructor) {
 | 
| 159 |                     return branch;
 | 
| 160 |                 }
 | 
| 161 |             }
 | 
| 162 | 
 | 
| 163 |             throw std::invalid_argument("Missing branch information: " + constructor);
 | 
| 164 |         }();
 | 
| 165 | 
 | 
| 166 |         assert(branchInfo["types"].is_array());
 | 
| 167 |         auto branchTypes = branchInfo["types"].array_items();
 | 
| 168 | 
 | 
| 169 |         // Handle a branch without arguments.
 | 
| 170 |         if (branchTypes.empty()) {
 | 
| 171 |             if (charactersRead != nullptr) {
 | 
| 172 |                 *charactersRead = pos - initial_position;
 | 
| 173 |             }
 | 
| 174 | 
 | 
| 175 |             if (adtInfo["enum"].bool_value()) {
 | 
| 176 |                 return branchIdx;
 | 
| 177 |             }
 | 
| 178 | 
 | 
| 179 |             RamDomain emptyArgs = recordTable.pack(toVector<RamDomain>().data(), 0);
 | 
| 180 |             const RamDomain record[] = {branchIdx, emptyArgs};
 | 
| 181 |             return recordTable.pack(record, 2);
 | 
| 182 |         }
 | 
| 183 | 
 | 
| 184 |         consumeChar(source, '(', pos);
 | 
| 185 | 
 | 
| 186 |         std::vector<RamDomain> branchArgs(branchTypes.size());
 | 
| 187 | 
 | 
| 188 |         for (std::size_t i = 0; i < branchTypes.size(); ++i) {
 | 
| 189 |             auto argType = branchTypes[i].string_value();
 | 
| 190 |             assert(!argType.empty());
 | 
| 191 | 
 | 
| 192 |             std::size_t consumed = 0;
 | 
| 193 | 
 | 
| 194 |             if (i > 0) {
 | 
| 195 |                 consumeChar(source, ',', pos);
 | 
| 196 |             }
 | 
| 197 |             consumeWhiteSpace(source, pos);
 | 
| 198 | 
 | 
| 199 |             switch (argType[0]) {
 | 
| 200 |                 case 's': {
 | 
| 201 |                     branchArgs[i] = symbolTable.encode(readSymbol(source, ",)", pos, &consumed));
 | 
| 202 |                     break;
 | 
| 203 |                 }
 | 
| 204 |                 case 'i': {
 | 
| 205 |                     branchArgs[i] = RamSignedFromString(source.substr(pos), &consumed);
 | 
| 206 |                     break;
 | 
| 207 |                 }
 | 
| 208 |                 case 'u': {
 | 
| 209 |                     branchArgs[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
 | 
| 210 |                     break;
 | 
| 211 |                 }
 | 
| 212 |                 case 'f': {
 | 
| 213 |                     branchArgs[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
 | 
| 214 |                     break;
 | 
| 215 |                 }
 | 
| 216 |                 case 'r': {
 | 
| 217 |                     branchArgs[i] = readRecord(source, argType, pos, &consumed);
 | 
| 218 |                     break;
 | 
| 219 |                 }
 | 
| 220 |                 case '+': {
 | 
| 221 |                     branchArgs[i] = readADT(source, argType, pos, &consumed);
 | 
| 222 |                     break;
 | 
| 223 |                 }
 | 
| 224 |                 default: fatal("Invalid type attribute");
 | 
| 225 |             }
 | 
| 226 |             pos += consumed;
 | 
| 227 |         }
 | 
| 228 | 
 | 
| 229 |         consumeChar(source, ')', pos);
 | 
| 230 | 
 | 
| 231 |         if (charactersRead != nullptr) {
 | 
| 232 |             *charactersRead = pos - initial_position;
 | 
| 233 |         }
 | 
| 234 | 
 | 
| 235 |         // Store branch either as [branch_id, [arguments]] or [branch_id, argument].
 | 
| 236 |         RamDomain branchValue = [&]() -> RamDomain {
 | 
| 237 |             if (branchArgs.size() != 1) {
 | 
| 238 |                 return recordTable.pack(branchArgs.data(), branchArgs.size());
 | 
| 239 |             } else {
 | 
| 240 |                 return branchArgs[0];
 | 
| 241 |             }
 | 
| 242 |         }();
 | 
| 243 | 
 | 
| 244 |         RamDomain rec[2] = {branchIdx, branchValue};
 | 
| 245 |         return recordTable.pack(rec, 2);
 | 
| 246 |     }
 | 
| 247 | 
 | 
| 248 |     /**
 | 
| 249 |      * Read the next alphanumeric + ('_', '?') sequence (corresponding to IDENT).
 | 
| 250 |      * Consume preceding whitespace.
 | 
| 251 |      * TODO (darth_tytus): use std::string_view?
 | 
| 252 |      */
 | 
| 253 |     std::string readQualifiedName(const std::string& source, std::size_t& pos) {
 | 
| 254 |         consumeWhiteSpace(source, pos);
 | 
| 255 |         if (pos >= source.length()) {
 | 
| 256 |             throw std::invalid_argument("Unexpected end of input");
 | 
| 257 |         }
 | 
| 258 | 
 | 
| 259 |         const std::size_t bgn = pos;
 | 
| 260 |         while (pos < source.length()) {
 | 
| 261 |             unsigned char ch = static_cast<unsigned char>(source[pos]);
 | 
| 262 |             bool valid = std::isalnum(ch) || ch == '_' || ch == '?' || ch == '.';
 | 
| 263 |             if (!valid) break;
 | 
| 264 |             ++pos;
 | 
| 265 |         }
 | 
| 266 | 
 | 
| 267 |         return source.substr(bgn, pos - bgn);
 | 
| 268 |     }
 | 
| 269 | 
 | 
/**
 * Read the characters from pos up to (not including) the first occurrence
 * of any of the given stopChars.
 *
 * @param source - string to read from
 * @param stopChars - set of characters that terminate the read
 * @param pos - start reading from this position.
 * @param charactersRead - if not nullptr: receives the number of characters
 *        read (the stop character is not counted).
 * @throws std::invalid_argument if no stop character is found from pos on.
 */
std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos,
        std::size_t* charactersRead) {
    const std::size_t endOfSymbol = source.find_first_of(stopChars, pos);

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    const std::size_t lengthOfSymbol = endOfSymbol - pos;

    // The out-parameter is optional, as in readRecord / readADT.
    if (charactersRead != nullptr) {
        *charactersRead = lengthOfSymbol;
    }

    return source.substr(pos, lengthOfSymbol);
}
 | 
| 282 | 
 | 
/**
 * Read a quoted symbol.
 *
 * The character at pos is taken as the quote mark; the symbol extends to
 * the next unescaped occurrence of that character. A backslash escapes the
 * character that follows it: the backslash is dropped and the following
 * character is kept verbatim.
 *
 * @param source - string to read from
 * @param pos - position of the opening quote mark.
 * @param charactersRead - if not nullptr: receives the number of characters
 *        read, both quote marks included.
 * @return the symbol without the quote marks and with escapes resolved.
 * @throws std::invalid_argument if pos is not inside source or the closing
 *         quote mark is missing.
 */
std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) {
    const std::size_t length = source.length();
    if (pos >= length) {
        // Nothing to read; also keeps source[pos] below within bounds.
        throw std::invalid_argument("Unexpected end of input");
    }

    const std::size_t openingQuote = pos;
    const char quoteMark = source[openingQuote];
    const std::size_t startOfSymbol = openingQuote + 1;

    // First pass: find the closing quote and note whether escapes occur.
    std::size_t endOfSymbol = std::string::npos;
    bool hasEscaped = false;
    for (std::size_t cursor = startOfSymbol; cursor < length; ++cursor) {
        const char c = source[cursor];
        if (c == quoteMark) {
            endOfSymbol = cursor;
            break;
        }
        if (c == '\\') {
            // The character after a backslash never terminates the symbol.
            hasEscaped = true;
            ++cursor;
        }
    }

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    // The out-parameter is optional, as in readRecord / readADT.
    if (charactersRead != nullptr) {
        *charactersRead = (endOfSymbol + 1) - openingQuote;
    }

    const std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol;

    // fast handling of symbol without escape sequence
    if (!hasEscaped) {
        return source.substr(startOfSymbol, lengthOfSymbol);
    }

    // slow handling of symbol with escape sequence
    std::string symbol;
    symbol.reserve(lengthOfSymbol);
    bool afterBackslash = false;
    for (std::size_t i = startOfSymbol; i < endOfSymbol; ++i) {
        const char ch = source[i];
        if (afterBackslash || ch != '\\') {
            symbol.push_back(ch);
            afterBackslash = false;
        } else {
            afterBackslash = true;
        }
    }
    return symbol;
}
 | 
| 343 | 
 | 
| 344 |     /**
 | 
| 345 |      * Read the next symbol.
 | 
| 346 |      * It is either a double-quoted symbol with backslash-escaped chars, or the
 | 
| 347 |      * longuest sequence that do not contains any of the given stopChars.
 | 
| 348 |      * */
 | 
| 349 |     std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos,
 | 
| 350 |             std::size_t* charactersRead) {
 | 
| 351 |         if (source[pos] == '"') {
 | 
| 352 |             return readQuotedSymbol(source, pos, charactersRead);
 | 
| 353 |         } else {
 | 
| 354 |             return readUntil(source, stopChars, pos, charactersRead);
 | 
| 355 |         }
 | 
| 356 |     }
 | 
| 357 | 
 | 
| 358 |     /**
 | 
| 359 |      * Read past given character, consuming any preceding whitespace.
 | 
| 360 |      */
 | 
| 361 |     void consumeChar(const std::string& str, char c, std::size_t& pos) {
 | 
| 362 |         consumeWhiteSpace(str, pos);
 | 
| 363 |         if (pos >= str.length()) {
 | 
| 364 |             throw std::invalid_argument("Unexpected end of input");
 | 
| 365 |         }
 | 
| 366 |         if (str[pos] != c) {
 | 
| 367 |             std::stringstream error;
 | 
| 368 |             error << "Expected: \'" << c << "\', got: " << str[pos];
 | 
| 369 |             throw std::invalid_argument(error.str());
 | 
| 370 |         }
 | 
| 371 |         ++pos;
 | 
| 372 |     }
 | 
| 373 | 
 | 
/**
 * Advance position in the string until first non-whitespace character.
 * The position is left unchanged if it is already at or past the end of
 * the string.
 */
void consumeWhiteSpace(const std::string& str, std::size_t& pos) {
    const std::size_t length = str.length();
    for (; pos < length; ++pos) {
        if (!std::isspace(static_cast<unsigned char>(str[pos]))) {
            return;
        }
    }
}
 | 
| 382 | 
 | 
/**
 * Produce the next tuple; to be implemented by derived classes.
 *
 * readAll() calls this repeatedly, stops as soon as the returned handle
 * tests false, and otherwise passes the pointer obtained from get() to the
 * relation's insert().
 */
| 383 |     virtual Own<RamDomain[]> readNextTuple() = 0;
 | 
| 384 | };
 | 
| 385 | 
 | 
| 386 | class ReadStreamFactory {
 | 
| 387 | public:
 | 
| 388 |     virtual Own<ReadStream> getReader(
 | 
| 389 |             const std::map<std::string, std::string>&, SymbolTable&, RecordTable&) = 0;
 | 
| 390 |     virtual const std::string& getName() const = 0;
 | 
| 391 |     virtual ~ReadStreamFactory() = default;
 | 
| 392 | };
 | 
| 393 | 
 | 
| 394 | } /* namespace souffle */
 |