| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file ReadStreamCSV.h
 | 
| 12 |  *
 | 
| 13 |  ***********************************************************************/
 | 
| 14 | 
 | 
| 15 | #pragma once
 | 
| 16 | 
 | 
| 17 | #include "souffle/RamTypes.h"
 | 
| 18 | #include "souffle/RecordTable.h"
 | 
| 19 | #include "souffle/SymbolTable.h"
 | 
| 20 | #include "souffle/io/ReadStream.h"
 | 
| 21 | #include "souffle/utility/ContainerUtil.h"
 | 
| 22 | #include "souffle/utility/FileUtil.h"
 | 
| 23 | #include "souffle/utility/StringUtil.h"
 | 
| 24 | 
 | 
| 25 | #ifdef USE_LIBZ
 | 
| 26 | #include "souffle/io/gzfstream.h"
 | 
| 27 | #else
 | 
| 28 | #include <fstream>
 | 
| 29 | #endif
 | 
| 30 | 
 | 
| 31 | #include <algorithm>
 | 
| 32 | #include <cassert>
 | 
| 33 | #include <cstddef>
 | 
| 34 | #include <cstdint>
 | 
| 35 | #include <iostream>
 | 
| 36 | #include <map>
 | 
| 37 | #include <memory>
 | 
| 38 | #include <sstream>
 | 
| 39 | #include <stdexcept>
 | 
| 40 | #include <string>
 | 
| 41 | #include <vector>
 | 
| 42 | 
 | 
| 43 | namespace souffle {
 | 
| 44 | 
 | 
| 45 | class ReadStreamCSV : public ReadStream {
 | 
| 46 | public:
 | 
| 47 |     ReadStreamCSV(std::istream& file, const std::map<std::string, std::string>& rwOperation,
 | 
| 48 |             SymbolTable& symbolTable, RecordTable& recordTable)
 | 
| 49 |             : ReadStream(rwOperation, symbolTable, recordTable),
 | 
| 50 |               rfc4180(getOr(rwOperation, "rfc4180", "false") == std::string("true")),
 | 
| 51 |               delimiter(getOr(rwOperation, "delimiter", (rfc4180 ? "," : "\t"))), file(file), lineNumber(0),
 | 
| 52 |               inputMap(getInputColumnMap(rwOperation, static_cast<unsigned int>(arity))) {
 | 
| 53 |         if (rfc4180 && delimiter.find('"') != std::string::npos) {
 | 
| 54 |             std::stringstream errorMessage;
 | 
| 55 |             errorMessage << "CSV delimiter cannot contain '\"' character when rfc4180 is enabled.";
 | 
| 56 |             throw std::invalid_argument(errorMessage.str());
 | 
| 57 |         }
 | 
| 58 | 
 | 
| 59 |         while (inputMap.size() < arity) {
 | 
| 60 |             int size = static_cast<int>(inputMap.size());
 | 
| 61 |             inputMap[size] = size;
 | 
| 62 |         }
 | 
| 63 |     }
 | 
| 64 | 
 | 
| 65 | protected:
 | 
| 66 |     bool readNextLine(std::string& line, bool& isCRLF) {
 | 
| 67 |         if (!getline(file, line)) {
 | 
| 68 |             return false;
 | 
| 69 |         }
 | 
| 70 |         // Handle Windows line endings on non-Windows systems
 | 
| 71 |         isCRLF = !line.empty() && line.back() == '\r';
 | 
| 72 |         if (isCRLF) {
 | 
| 73 |             line.pop_back();
 | 
| 74 |         }
 | 
| 75 |         ++lineNumber;
 | 
| 76 |         return true;
 | 
| 77 |     }
 | 
| 78 | 
 | 
| 79 |     /**
 | 
| 80 |      * Read and return the next tuple.
 | 
| 81 |      *
 | 
| 82 |      * Returns nullptr if no tuple was readable.
 | 
| 83 |      * @return
 | 
| 84 |      */
 | 
| 85 |     Own<RamDomain[]> readNextTuple() override {
 | 
| 86 |         if (file.eof()) {
 | 
| 87 |             return nullptr;
 | 
| 88 |         }
 | 
| 89 |         std::string line;
 | 
| 90 |         Own<RamDomain[]> tuple = mk<RamDomain[]>(typeAttributes.size());
 | 
| 91 |         bool wasCRLF = false;
 | 
| 92 |         if (!readNextLine(line, wasCRLF)) {
 | 
| 93 |             return nullptr;
 | 
| 94 |         }
 | 
| 95 | 
 | 
| 96 |         std::size_t start = 0;
 | 
| 97 |         std::size_t columnsFilled = 0;
 | 
| 98 |         for (uint32_t column = 0; columnsFilled < arity; column++) {
 | 
| 99 |             std::size_t charactersRead = 0;
 | 
| 100 |             std::string element = nextElement(line, start, wasCRLF);
 | 
| 101 |             if (inputMap.count(column) == 0) {
 | 
| 102 |                 continue;
 | 
| 103 |             }
 | 
| 104 |             ++columnsFilled;
 | 
| 105 | 
 | 
| 106 |             try {
 | 
| 107 |                 auto&& ty = typeAttributes.at(inputMap[column]);
 | 
| 108 |                 switch (ty[0]) {
 | 
| 109 |                     case 's': {
 | 
| 110 |                         tuple[inputMap[column]] = symbolTable.encode(element);
 | 
| 111 |                         charactersRead = element.size();
 | 
| 112 |                         break;
 | 
| 113 |                     }
 | 
| 114 |                     case 'r': {
 | 
| 115 |                         tuple[inputMap[column]] = readRecord(element, ty, 0, &charactersRead);
 | 
| 116 |                         break;
 | 
| 117 |                     }
 | 
| 118 |                     case '+': {
 | 
| 119 |                         tuple[inputMap[column]] = readADT(element, ty, 0, &charactersRead);
 | 
| 120 |                         break;
 | 
| 121 |                     }
 | 
| 122 |                     case 'i': {
 | 
| 123 |                         tuple[inputMap[column]] = RamSignedFromString(element, &charactersRead);
 | 
| 124 |                         break;
 | 
| 125 |                     }
 | 
| 126 |                     case 'u': {
 | 
| 127 |                         tuple[inputMap[column]] = ramBitCast(readRamUnsigned(element, charactersRead));
 | 
| 128 |                         break;
 | 
| 129 |                     }
 | 
| 130 |                     case 'f': {
 | 
| 131 |                         tuple[inputMap[column]] = ramBitCast(RamFloatFromString(element, &charactersRead));
 | 
| 132 |                         break;
 | 
| 133 |                     }
 | 
| 134 |                     default: fatal("invalid type attribute: `%c`", ty[0]);
 | 
| 135 |                 }
 | 
| 136 |                 // Check if everything was read.
 | 
| 137 |                 if (charactersRead != element.size()) {
 | 
| 138 |                     throw std::invalid_argument(
 | 
| 139 |                             "Expected: " + delimiter + " or \\n. Got: " + element[charactersRead]);
 | 
| 140 |                 }
 | 
| 141 |             } catch (...) {
 | 
| 142 |                 std::stringstream errorMessage;
 | 
| 143 |                 errorMessage << "Error converting <" + element + "> in column " << column + 1 << " in line "
 | 
| 144 |                              << lineNumber << "; ";
 | 
| 145 |                 throw std::invalid_argument(errorMessage.str());
 | 
| 146 |             }
 | 
| 147 |         }
 | 
| 148 | 
 | 
| 149 |         return tuple;
 | 
| 150 |     }
 | 
| 151 | 
 | 
| 152 |     /**
 | 
| 153 |      * Read an unsigned element. Possible bases are 2, 10, 16
 | 
| 154 |      * Base is indicated by the first two chars.
 | 
| 155 |      */
 | 
| 156 |     RamUnsigned readRamUnsigned(const std::string& element, std::size_t& charactersRead) {
 | 
| 157 |         // Sanity check
 | 
| 158 |         assert(element.size() > 0);
 | 
| 159 | 
 | 
| 160 |         RamSigned value = 0;
 | 
| 161 | 
 | 
| 162 |         // Check prefix and parse the input.
 | 
| 163 |         if (isPrefix("0b", element)) {
 | 
| 164 |             value = RamUnsignedFromString(element, &charactersRead, 2);
 | 
| 165 |         } else if (isPrefix("0x", element)) {
 | 
| 166 |             value = RamUnsignedFromString(element, &charactersRead, 16);
 | 
| 167 |         } else {
 | 
| 168 |             value = RamUnsignedFromString(element, &charactersRead);
 | 
| 169 |         }
 | 
| 170 |         return value;
 | 
| 171 |     }
 | 
| 172 | 
 | 
| 173 |     std::string nextElement(std::string& line, std::size_t& start, bool& wasCRLF) {
 | 
| 174 |         std::string element;
 | 
| 175 | 
 | 
| 176 |         if (rfc4180) {
 | 
| 177 |             if (line[start] == '"') {
 | 
| 178 |                 // quoted field
 | 
| 179 |                 std::size_t end = line.length();
 | 
| 180 |                 std::size_t pos = start + 1;
 | 
| 181 |                 bool foundEndQuote = false;
 | 
| 182 |                 while (!foundEndQuote) {
 | 
| 183 |                     if (pos == end) {
 | 
| 184 |                         bool newWasCRLF = false;
 | 
| 185 |                         if (!readNextLine(line, newWasCRLF)) {
 | 
| 186 |                             break;
 | 
| 187 |                         }
 | 
| 188 |                         // account for \r\n or \n that we had previously
 | 
| 189 |                         // read and thrown out.
 | 
| 190 |                         // since we're in a quote, we should restore
 | 
| 191 |                         // what the user provided
 | 
| 192 |                         if (wasCRLF) {
 | 
| 193 |                             element.push_back('\r');
 | 
| 194 |                         }
 | 
| 195 |                         element.push_back('\n');
 | 
| 196 | 
 | 
| 197 |                         // remember if we just read a CRLF sequence
 | 
| 198 |                         wasCRLF = newWasCRLF;
 | 
| 199 | 
 | 
| 200 |                         // start over
 | 
| 201 |                         pos = 0;
 | 
| 202 |                         end = line.length();
 | 
| 203 |                     }
 | 
| 204 |                     if (pos == end) {
 | 
| 205 |                         // this means we've got a blank line and we need to read
 | 
| 206 |                         // more
 | 
| 207 |                         continue;
 | 
| 208 |                     }
 | 
| 209 | 
 | 
| 210 |                     char c = line[pos++];
 | 
| 211 |                     if (c == '"' && (pos < end) && line[pos] == '"') {
 | 
| 212 |                         // two double-quote => one double-quote
 | 
| 213 |                         element.push_back('"');
 | 
| 214 |                         ++pos;
 | 
| 215 |                     } else if (c == '"') {
 | 
| 216 |                         foundEndQuote = true;
 | 
| 217 |                     } else {
 | 
| 218 |                         element.push_back(c);
 | 
| 219 |                     }
 | 
| 220 |                 }
 | 
| 221 | 
 | 
| 222 |                 if (!foundEndQuote) {
 | 
| 223 |                     // missing closing quote
 | 
| 224 |                     std::stringstream errorMessage;
 | 
| 225 |                     errorMessage << "Unbalanced field quote in line " << lineNumber << "; ";
 | 
| 226 |                     throw std::invalid_argument(errorMessage.str());
 | 
| 227 |                 }
 | 
| 228 | 
 | 
| 229 |                 // field must be immediately followed by delimiter or end of line
 | 
| 230 |                 if (pos != line.length()) {
 | 
| 231 |                     std::size_t nextDelimiter = line.find(delimiter, pos);
 | 
| 232 |                     if (nextDelimiter != pos) {
 | 
| 233 |                         std::stringstream errorMessage;
 | 
| 234 |                         errorMessage << "Separator expected immediately after quoted field in line "
 | 
| 235 |                                      << lineNumber << "; ";
 | 
| 236 |                         throw std::invalid_argument(errorMessage.str());
 | 
| 237 |                     }
 | 
| 238 |                 }
 | 
| 239 | 
 | 
| 240 |                 start = pos + delimiter.size();
 | 
| 241 |                 return element;
 | 
| 242 |             } else {
 | 
| 243 |                 // non-quoted field, span until next delimiter or end of line
 | 
| 244 |                 const std::size_t end = std::min(line.find(delimiter, start), line.length());
 | 
| 245 |                 element = line.substr(start, end - start);
 | 
| 246 |                 start = end + delimiter.size();
 | 
| 247 | 
 | 
| 248 |                 return element;
 | 
| 249 |             }
 | 
| 250 |         }
 | 
| 251 | 
 | 
| 252 |         std::size_t end = start;
 | 
| 253 |         // Handle record/tuple delimiter coincidence.
 | 
| 254 |         if (delimiter.find(',') != std::string::npos) {
 | 
| 255 |             int record_parens = 0;
 | 
| 256 |             std::size_t next_delimiter = line.find(delimiter, start);
 | 
| 257 | 
 | 
| 258 |             // Find first delimiter after the record.
 | 
| 259 |             while (end < std::min(next_delimiter, line.length()) || record_parens != 0) {
 | 
| 260 |                 // Track the number of parenthesis.
 | 
| 261 |                 if (line[end] == '[') {
 | 
| 262 |                     ++record_parens;
 | 
| 263 |                 } else if (line[end] == ']') {
 | 
| 264 |                     --record_parens;
 | 
| 265 |                 }
 | 
| 266 | 
 | 
| 267 |                 // Check for unbalanced parenthesis.
 | 
| 268 |                 if (record_parens < 0) {
 | 
| 269 |                     break;
 | 
| 270 |                 };
 | 
| 271 | 
 | 
| 272 |                 ++end;
 | 
| 273 | 
 | 
| 274 |                 // Find a next delimiter if the old one is invalid.
 | 
| 275 |                 // But only if inside the unbalance parenthesis.
 | 
| 276 |                 if (end == next_delimiter && record_parens != 0) {
 | 
| 277 |                     next_delimiter = line.find(delimiter, end);
 | 
| 278 |                 }
 | 
| 279 |             }
 | 
| 280 | 
 | 
| 281 |             // Handle the end-of-the-line case where parenthesis are unbalanced.
 | 
| 282 |             if (record_parens != 0) {
 | 
| 283 |                 std::stringstream errorMessage;
 | 
| 284 |                 errorMessage << "Unbalanced record parenthesis in line " << lineNumber << "; ";
 | 
| 285 |                 throw std::invalid_argument(errorMessage.str());
 | 
| 286 |             }
 | 
| 287 |         } else {
 | 
| 288 |             end = std::min(line.find(delimiter, start), line.length());
 | 
| 289 |         }
 | 
| 290 | 
 | 
| 291 |         // Check for missing value.
 | 
| 292 |         if (start > end) {
 | 
| 293 |             std::stringstream errorMessage;
 | 
| 294 |             errorMessage << "Values missing in line " << lineNumber << "; ";
 | 
| 295 |             throw std::invalid_argument(errorMessage.str());
 | 
| 296 |         }
 | 
| 297 | 
 | 
| 298 |         element = line.substr(start, end - start);
 | 
| 299 |         start = end + delimiter.size();
 | 
| 300 | 
 | 
| 301 |         return element;
 | 
| 302 |     }
 | 
| 303 | 
 | 
| 304 |     std::map<int, int> getInputColumnMap(
 | 
| 305 |             const std::map<std::string, std::string>& rwOperation, const unsigned arity_) const {
 | 
| 306 |         std::string columnString = getOr(rwOperation, "columns", "");
 | 
| 307 |         std::map<int, int> inputColumnMap;
 | 
| 308 | 
 | 
| 309 |         if (!columnString.empty()) {
 | 
| 310 |             std::istringstream iss(columnString);
 | 
| 311 |             std::string mapping;
 | 
| 312 |             int index = 0;
 | 
| 313 |             while (std::getline(iss, mapping, ':')) {
 | 
| 314 |                 inputColumnMap[stoi(mapping)] = index++;
 | 
| 315 |             }
 | 
| 316 |             if (inputColumnMap.size() < arity_) {
 | 
| 317 |                 throw std::invalid_argument("Invalid column set was given: <" + columnString + ">");
 | 
| 318 |             }
 | 
| 319 |         } else {
 | 
| 320 |             while (inputColumnMap.size() < arity_) {
 | 
| 321 |                 int size = static_cast<int>(inputColumnMap.size());
 | 
| 322 |                 inputColumnMap[size] = size;
 | 
| 323 |             }
 | 
| 324 |         }
 | 
| 325 |         return inputColumnMap;
 | 
| 326 |     }
 | 
| 327 | 
 | 
| 328 |     const bool rfc4180;
 | 
| 329 |     const std::string delimiter;
 | 
| 330 |     std::istream& file;
 | 
| 331 |     std::size_t lineNumber;
 | 
| 332 |     std::map<int, int> inputMap;
 | 
| 333 | };
 | 
| 334 | 
 | 
| 335 | class ReadFileCSV : public ReadStreamCSV {
 | 
| 336 | public:
 | 
| 337 |     ReadFileCSV(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
 | 
| 338 |             RecordTable& recordTable)
 | 
| 339 |             : ReadStreamCSV(fileHandle, rwOperation, symbolTable, recordTable),
 | 
| 340 |               baseName(souffle::baseName(getFileName(rwOperation))),
 | 
| 341 |               fileHandle(getFileName(rwOperation), std::ios::in | std::ios::binary) {
 | 
| 342 |         if (!fileHandle.is_open()) {
 | 
| 343 |             // suppress error message in case file cannot be open when flag -w is set
 | 
| 344 |             if (getOr(rwOperation, "no-warn", "false") != "true") {
 | 
| 345 |                 throw std::invalid_argument("Cannot open fact file " + baseName + "\n");
 | 
| 346 |             }
 | 
| 347 |         }
 | 
| 348 |         // Strip headers if we're using them
 | 
| 349 |         if (getOr(rwOperation, "headers", "false") == "true") {
 | 
| 350 |             std::string line;
 | 
| 351 |             getline(file, line);
 | 
| 352 |         }
 | 
| 353 |     }
 | 
| 354 | 
 | 
| 355 |     /**
 | 
| 356 |      * Read and return the next tuple.
 | 
| 357 |      *
 | 
| 358 |      * Returns nullptr if no tuple was readable.
 | 
| 359 |      * @return
 | 
| 360 |      */
 | 
| 361 |     Own<RamDomain[]> readNextTuple() override {
 | 
| 362 |         try {
 | 
| 363 |             return ReadStreamCSV::readNextTuple();
 | 
| 364 |         } catch (std::exception& e) {
 | 
| 365 |             std::stringstream errorMessage;
 | 
| 366 |             errorMessage << e.what();
 | 
| 367 |             errorMessage << "cannot parse fact file " << baseName << "!\n";
 | 
| 368 |             throw std::invalid_argument(errorMessage.str());
 | 
| 369 |         }
 | 
| 370 |     }
 | 
| 371 | 
 | 
| 372 |     ~ReadFileCSV() override = default;
 | 
| 373 | 
 | 
| 374 | protected:
 | 
| 375 |     /**
 | 
| 376 |      * Return given filename or construct from relation name.
 | 
| 377 |      * Default name is [configured path]/[relation name].facts
 | 
| 378 |      *
 | 
| 379 |      * @param rwOperation map of IO configuration options
 | 
| 380 |      * @return input filename
 | 
| 381 |      */
 | 
| 382 |     static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
 | 
| 383 |         auto name = getOr(rwOperation, "filename", rwOperation.at("name") + ".facts");
 | 
| 384 |         if (!isAbsolute(name)) {
 | 
| 385 |             name = getOr(rwOperation, "fact-dir", ".") + pathSeparator + name;
 | 
| 386 |         }
 | 
| 387 |         return name;
 | 
| 388 |     }
 | 
| 389 | 
 | 
| 390 |     std::string baseName;
 | 
| 391 | #ifdef USE_LIBZ
 | 
| 392 |     gzfstream::igzfstream fileHandle;
 | 
| 393 | #else
 | 
| 394 |     std::ifstream fileHandle;
 | 
| 395 | #endif
 | 
| 396 | };
 | 
| 397 | 
 | 
| 398 | class ReadCinCSVFactory : public ReadStreamFactory {
 | 
| 399 | public:
 | 
| 400 |     Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
 | 
| 401 |             RecordTable& recordTable) override {
 | 
| 402 |         return mk<ReadStreamCSV>(std::cin, rwOperation, symbolTable, recordTable);
 | 
| 403 |     }
 | 
| 404 | 
 | 
| 405 |     const std::string& getName() const override {
 | 
| 406 |         static const std::string name = "stdin";
 | 
| 407 |         return name;
 | 
| 408 |     }
 | 
| 409 |     ~ReadCinCSVFactory() override = default;
 | 
| 410 | };
 | 
| 411 | 
 | 
| 412 | class ReadFileCSVFactory : public ReadStreamFactory {
 | 
| 413 | public:
 | 
| 414 |     Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
 | 
| 415 |             RecordTable& recordTable) override {
 | 
| 416 |         return mk<ReadFileCSV>(rwOperation, symbolTable, recordTable);
 | 
| 417 |     }
 | 
| 418 | 
 | 
| 419 |     const std::string& getName() const override {
 | 
| 420 |         static const std::string name = "file";
 | 
| 421 |         return name;
 | 
| 422 |     }
 | 
| 423 | 
 | 
| 424 |     ~ReadFileCSVFactory() override = default;
 | 
| 425 | };
 | 
| 426 | 
 | 
| 427 | } /* namespace souffle */
 |