| 1 | /* | 
| 2 | * Souffle - A Datalog Compiler | 
| 3 | * Copyright (c) 2021, The Souffle Developers. All rights reserved | 
| 4 | * Licensed under the Universal Permissive License v 1.0 as shown at: | 
| 5 | * - https://opensource.org/licenses/UPL | 
| 6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt | 
| 7 | */ | 
| 8 |  | 
| 9 | /************************************************************************ | 
| 10 | * | 
| 11 | * @file ReadStreamCSV.h | 
| 12 | * | 
| 13 | ***********************************************************************/ | 
| 14 |  | 
| 15 | #pragma once | 
| 16 |  | 
| 17 | #include "souffle/RamTypes.h" | 
| 18 | #include "souffle/RecordTable.h" | 
| 19 | #include "souffle/SymbolTable.h" | 
| 20 | #include "souffle/io/ReadStream.h" | 
| 21 | #include "souffle/utility/ContainerUtil.h" | 
| 22 | #include "souffle/utility/FileUtil.h" | 
| 23 | #include "souffle/utility/StringUtil.h" | 
| 24 |  | 
| 25 | #ifdef USE_LIBZ | 
| 26 | #include "souffle/io/gzfstream.h" | 
| 27 | #else | 
| 28 | #include <fstream> | 
| 29 | #endif | 
| 30 |  | 
| 31 | #include <algorithm> | 
| 32 | #include <cassert> | 
| 33 | #include <cstddef> | 
| 34 | #include <cstdint> | 
| 35 | #include <iostream> | 
| 36 | #include <map> | 
| 37 | #include <memory> | 
| 38 | #include <sstream> | 
| 39 | #include <stdexcept> | 
| 40 | #include <string> | 
| 41 | #include <vector> | 
| 42 |  | 
| 43 | namespace souffle { | 
| 44 |  | 
| 45 | class ReadStreamCSV : public ReadStream { | 
| 46 | public: | 
| 47 | ReadStreamCSV(std::istream& file, const std::map<std::string, std::string>& rwOperation, | 
| 48 | SymbolTable& symbolTable, RecordTable& recordTable) | 
| 49 | : ReadStream(rwOperation, symbolTable, recordTable), | 
| 50 | rfc4180(getOr(rwOperation, "rfc4180", "false") == std::string("true")), | 
| 51 | delimiter(getOr(rwOperation, "delimiter", (rfc4180 ? "," : "\t"))), file(file), lineNumber(0), | 
| 52 | inputMap(getInputColumnMap(rwOperation, static_cast<unsigned int>(arity))) { | 
| 53 | if (rfc4180 && delimiter.find('"') != std::string::npos) { | 
| 54 | std::stringstream errorMessage; | 
| 55 | errorMessage << "CSV delimiter cannot contain '\"' character when rfc4180 is enabled."; | 
| 56 | throw std::invalid_argument(errorMessage.str()); | 
| 57 | } | 
| 58 |  | 
| 59 | while (inputMap.size() < arity) { | 
| 60 | int size = static_cast<int>(inputMap.size()); | 
| 61 | inputMap[size] = size; | 
| 62 | } | 
| 63 | } | 
| 64 |  | 
| 65 | protected: | 
| 66 | bool readNextLine(std::string& line, bool& isCRLF) { | 
| 67 | if (!getline(file, line)) { | 
| 68 | return false; | 
| 69 | } | 
| 70 | // Handle Windows line endings on non-Windows systems | 
| 71 | isCRLF = !line.empty() && line.back() == '\r'; | 
| 72 | if (isCRLF) { | 
| 73 | line.pop_back(); | 
| 74 | } | 
| 75 | ++lineNumber; | 
| 76 | return true; | 
| 77 | } | 
| 78 |  | 
| 79 | /** | 
| 80 | * Read and return the next tuple. | 
| 81 | * | 
| 82 | * Returns nullptr if no tuple was readable. | 
| 83 | * @return | 
| 84 | */ | 
| 85 | Own<RamDomain[]> readNextTuple() override { | 
| 86 | if (file.eof()) { | 
| 87 | return nullptr; | 
| 88 | } | 
| 89 | std::string line; | 
| 90 | Own<RamDomain[]> tuple = mk<RamDomain[]>(typeAttributes.size()); | 
| 91 | bool wasCRLF = false; | 
| 92 | if (!readNextLine(line, wasCRLF)) { | 
| 93 | return nullptr; | 
| 94 | } | 
| 95 |  | 
| 96 | std::size_t start = 0; | 
| 97 | std::size_t columnsFilled = 0; | 
| 98 | for (uint32_t column = 0; columnsFilled < arity; column++) { | 
| 99 | std::size_t charactersRead = 0; | 
| 100 | std::string element = nextElement(line, start, wasCRLF); | 
| 101 | if (inputMap.count(column) == 0) { | 
| 102 | continue; | 
| 103 | } | 
| 104 | ++columnsFilled; | 
| 105 |  | 
| 106 | try { | 
| 107 | auto&& ty = typeAttributes.at(inputMap[column]); | 
| 108 | switch (ty[0]) { | 
| 109 | case 's': { | 
| 110 | tuple[inputMap[column]] = symbolTable.encode(element); | 
| 111 | charactersRead = element.size(); | 
| 112 | break; | 
| 113 | } | 
| 114 | case 'r': { | 
| 115 | tuple[inputMap[column]] = readRecord(element, ty, 0, &charactersRead); | 
| 116 | break; | 
| 117 | } | 
| 118 | case '+': { | 
| 119 | tuple[inputMap[column]] = readADT(element, ty, 0, &charactersRead); | 
| 120 | break; | 
| 121 | } | 
| 122 | case 'i': { | 
| 123 | tuple[inputMap[column]] = RamSignedFromString(element, &charactersRead); | 
| 124 | break; | 
| 125 | } | 
| 126 | case 'u': { | 
| 127 | tuple[inputMap[column]] = ramBitCast(readRamUnsigned(element, charactersRead)); | 
| 128 | break; | 
| 129 | } | 
| 130 | case 'f': { | 
| 131 | tuple[inputMap[column]] = ramBitCast(RamFloatFromString(element, &charactersRead)); | 
| 132 | break; | 
| 133 | } | 
| 134 | default: fatal("invalid type attribute: `%c`", ty[0]); | 
| 135 | } | 
| 136 | // Check if everything was read. | 
| 137 | if (charactersRead != element.size()) { | 
| 138 | throw std::invalid_argument( | 
| 139 | "Expected: " + delimiter + " or \\n. Got: " + element[charactersRead]); | 
| 140 | } | 
| 141 | } catch (...) { | 
| 142 | std::stringstream errorMessage; | 
| 143 | errorMessage << "Error converting <" + element + "> in column " << column + 1 << " in line " | 
| 144 | << lineNumber << "; "; | 
| 145 | throw std::invalid_argument(errorMessage.str()); | 
| 146 | } | 
| 147 | } | 
| 148 |  | 
| 149 | return tuple; | 
| 150 | } | 
| 151 |  | 
| 152 | /** | 
| 153 | * Read an unsigned element. Possible bases are 2, 10, 16 | 
| 154 | * Base is indicated by the first two chars. | 
| 155 | */ | 
| 156 | RamUnsigned readRamUnsigned(const std::string& element, std::size_t& charactersRead) { | 
| 157 | // Sanity check | 
| 158 | assert(element.size() > 0); | 
| 159 |  | 
| 160 | RamSigned value = 0; | 
| 161 |  | 
| 162 | // Check prefix and parse the input. | 
| 163 | if (isPrefix("0b", element)) { | 
| 164 | value = RamUnsignedFromString(element, &charactersRead, 2); | 
| 165 | } else if (isPrefix("0x", element)) { | 
| 166 | value = RamUnsignedFromString(element, &charactersRead, 16); | 
| 167 | } else { | 
| 168 | value = RamUnsignedFromString(element, &charactersRead); | 
| 169 | } | 
| 170 | return value; | 
| 171 | } | 
| 172 |  | 
| 173 | std::string nextElement(std::string& line, std::size_t& start, bool& wasCRLF) { | 
| 174 | std::string element; | 
| 175 |  | 
| 176 | if (rfc4180) { | 
| 177 | if (line[start] == '"') { | 
| 178 | // quoted field | 
| 179 | std::size_t end = line.length(); | 
| 180 | std::size_t pos = start + 1; | 
| 181 | bool foundEndQuote = false; | 
| 182 | while (!foundEndQuote) { | 
| 183 | if (pos == end) { | 
| 184 | bool newWasCRLF = false; | 
| 185 | if (!readNextLine(line, newWasCRLF)) { | 
| 186 | break; | 
| 187 | } | 
| 188 | // account for \r\n or \n that we had previously | 
| 189 | // read and thrown out. | 
| 190 | // since we're in a quote, we should restore | 
| 191 | // what the user provided | 
| 192 | if (wasCRLF) { | 
| 193 | element.push_back('\r'); | 
| 194 | } | 
| 195 | element.push_back('\n'); | 
| 196 |  | 
| 197 | // remember if we just read a CRLF sequence | 
| 198 | wasCRLF = newWasCRLF; | 
| 199 |  | 
| 200 | // start over | 
| 201 | pos = 0; | 
| 202 | end = line.length(); | 
| 203 | } | 
| 204 | if (pos == end) { | 
| 205 | // this means we've got a blank line and we need to read | 
| 206 | // more | 
| 207 | continue; | 
| 208 | } | 
| 209 |  | 
| 210 | char c = line[pos++]; | 
| 211 | if (c == '"' && (pos < end) && line[pos] == '"') { | 
| 212 | // two double-quote => one double-quote | 
| 213 | element.push_back('"'); | 
| 214 | ++pos; | 
| 215 | } else if (c == '"') { | 
| 216 | foundEndQuote = true; | 
| 217 | } else { | 
| 218 | element.push_back(c); | 
| 219 | } | 
| 220 | } | 
| 221 |  | 
| 222 | if (!foundEndQuote) { | 
| 223 | // missing closing quote | 
| 224 | std::stringstream errorMessage; | 
| 225 | errorMessage << "Unbalanced field quote in line " << lineNumber << "; "; | 
| 226 | throw std::invalid_argument(errorMessage.str()); | 
| 227 | } | 
| 228 |  | 
| 229 | // field must be immediately followed by delimiter or end of line | 
| 230 | if (pos != line.length()) { | 
| 231 | std::size_t nextDelimiter = line.find(delimiter, pos); | 
| 232 | if (nextDelimiter != pos) { | 
| 233 | std::stringstream errorMessage; | 
| 234 | errorMessage << "Separator expected immediately after quoted field in line " | 
| 235 | << lineNumber << "; "; | 
| 236 | throw std::invalid_argument(errorMessage.str()); | 
| 237 | } | 
| 238 | } | 
| 239 |  | 
| 240 | start = pos + delimiter.size(); | 
| 241 | return element; | 
| 242 | } else { | 
| 243 | // non-quoted field, span until next delimiter or end of line | 
| 244 | const std::size_t end = std::min(line.find(delimiter, start), line.length()); | 
| 245 | element = line.substr(start, end - start); | 
| 246 | start = end + delimiter.size(); | 
| 247 |  | 
| 248 | return element; | 
| 249 | } | 
| 250 | } | 
| 251 |  | 
| 252 | std::size_t end = start; | 
| 253 | // Handle record/tuple delimiter coincidence. | 
| 254 | if (delimiter.find(',') != std::string::npos) { | 
| 255 | int record_parens = 0; | 
| 256 | std::size_t next_delimiter = line.find(delimiter, start); | 
| 257 |  | 
| 258 | // Find first delimiter after the record. | 
| 259 | while (end < std::min(next_delimiter, line.length()) || record_parens != 0) { | 
| 260 | // Track the number of parenthesis. | 
| 261 | if (line[end] == '[') { | 
| 262 | ++record_parens; | 
| 263 | } else if (line[end] == ']') { | 
| 264 | --record_parens; | 
| 265 | } | 
| 266 |  | 
| 267 | // Check for unbalanced parenthesis. | 
| 268 | if (record_parens < 0) { | 
| 269 | break; | 
| 270 | }; | 
| 271 |  | 
| 272 | ++end; | 
| 273 |  | 
| 274 | // Find a next delimiter if the old one is invalid. | 
| 275 | // But only if inside the unbalance parenthesis. | 
| 276 | if (end == next_delimiter && record_parens != 0) { | 
| 277 | next_delimiter = line.find(delimiter, end); | 
| 278 | } | 
| 279 | } | 
| 280 |  | 
| 281 | // Handle the end-of-the-line case where parenthesis are unbalanced. | 
| 282 | if (record_parens != 0) { | 
| 283 | std::stringstream errorMessage; | 
| 284 | errorMessage << "Unbalanced record parenthesis in line " << lineNumber << "; "; | 
| 285 | throw std::invalid_argument(errorMessage.str()); | 
| 286 | } | 
| 287 | } else { | 
| 288 | end = std::min(line.find(delimiter, start), line.length()); | 
| 289 | } | 
| 290 |  | 
| 291 | // Check for missing value. | 
| 292 | if (start > end) { | 
| 293 | std::stringstream errorMessage; | 
| 294 | errorMessage << "Values missing in line " << lineNumber << "; "; | 
| 295 | throw std::invalid_argument(errorMessage.str()); | 
| 296 | } | 
| 297 |  | 
| 298 | element = line.substr(start, end - start); | 
| 299 | start = end + delimiter.size(); | 
| 300 |  | 
| 301 | return element; | 
| 302 | } | 
| 303 |  | 
| 304 | std::map<int, int> getInputColumnMap( | 
| 305 | const std::map<std::string, std::string>& rwOperation, const unsigned arity_) const { | 
| 306 | std::string columnString = getOr(rwOperation, "columns", ""); | 
| 307 | std::map<int, int> inputColumnMap; | 
| 308 |  | 
| 309 | if (!columnString.empty()) { | 
| 310 | std::istringstream iss(columnString); | 
| 311 | std::string mapping; | 
| 312 | int index = 0; | 
| 313 | while (std::getline(iss, mapping, ':')) { | 
| 314 | inputColumnMap[stoi(mapping)] = index++; | 
| 315 | } | 
| 316 | if (inputColumnMap.size() < arity_) { | 
| 317 | throw std::invalid_argument("Invalid column set was given: <" + columnString + ">"); | 
| 318 | } | 
| 319 | } else { | 
| 320 | while (inputColumnMap.size() < arity_) { | 
| 321 | int size = static_cast<int>(inputColumnMap.size()); | 
| 322 | inputColumnMap[size] = size; | 
| 323 | } | 
| 324 | } | 
| 325 | return inputColumnMap; | 
| 326 | } | 
| 327 |  | 
| 328 | const bool rfc4180; | 
| 329 | const std::string delimiter; | 
| 330 | std::istream& file; | 
| 331 | std::size_t lineNumber; | 
| 332 | std::map<int, int> inputMap; | 
| 333 | }; | 
| 334 |  | 
| 335 | class ReadFileCSV : public ReadStreamCSV { | 
| 336 | public: | 
| 337 | ReadFileCSV(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable, | 
| 338 | RecordTable& recordTable) | 
| 339 | : ReadStreamCSV(fileHandle, rwOperation, symbolTable, recordTable), | 
| 340 | baseName(souffle::baseName(getFileName(rwOperation))), | 
| 341 | fileHandle(getFileName(rwOperation), std::ios::in | std::ios::binary) { | 
| 342 | if (!fileHandle.is_open()) { | 
| 343 | // suppress error message in case file cannot be open when flag -w is set | 
| 344 | if (getOr(rwOperation, "no-warn", "false") != "true") { | 
| 345 | throw std::invalid_argument("Cannot open fact file " + baseName + "\n"); | 
| 346 | } | 
| 347 | } | 
| 348 | // Strip headers if we're using them | 
| 349 | if (getOr(rwOperation, "headers", "false") == "true") { | 
| 350 | std::string line; | 
| 351 | getline(file, line); | 
| 352 | } | 
| 353 | } | 
| 354 |  | 
| 355 | /** | 
| 356 | * Read and return the next tuple. | 
| 357 | * | 
| 358 | * Returns nullptr if no tuple was readable. | 
| 359 | * @return | 
| 360 | */ | 
| 361 | Own<RamDomain[]> readNextTuple() override { | 
| 362 | try { | 
| 363 | return ReadStreamCSV::readNextTuple(); | 
| 364 | } catch (std::exception& e) { | 
| 365 | std::stringstream errorMessage; | 
| 366 | errorMessage << e.what(); | 
| 367 | errorMessage << "cannot parse fact file " << baseName << "!\n"; | 
| 368 | throw std::invalid_argument(errorMessage.str()); | 
| 369 | } | 
| 370 | } | 
| 371 |  | 
| 372 | ~ReadFileCSV() override = default; | 
| 373 |  | 
| 374 | protected: | 
| 375 | /** | 
| 376 | * Return given filename or construct from relation name. | 
| 377 | * Default name is [configured path]/[relation name].facts | 
| 378 | * | 
| 379 | * @param rwOperation map of IO configuration options | 
| 380 | * @return input filename | 
| 381 | */ | 
| 382 | static std::string getFileName(const std::map<std::string, std::string>& rwOperation) { | 
| 383 | auto name = getOr(rwOperation, "filename", rwOperation.at("name") + ".facts"); | 
| 384 | if (!isAbsolute(name)) { | 
| 385 | name = getOr(rwOperation, "fact-dir", ".") + pathSeparator + name; | 
| 386 | } | 
| 387 | return name; | 
| 388 | } | 
| 389 |  | 
| 390 | std::string baseName; | 
| 391 | #ifdef USE_LIBZ | 
| 392 | gzfstream::igzfstream fileHandle; | 
| 393 | #else | 
| 394 | std::ifstream fileHandle; | 
| 395 | #endif | 
| 396 | }; | 
| 397 |  | 
| 398 | class ReadCinCSVFactory : public ReadStreamFactory { | 
| 399 | public: | 
| 400 | Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable, | 
| 401 | RecordTable& recordTable) override { | 
| 402 | return mk<ReadStreamCSV>(std::cin, rwOperation, symbolTable, recordTable); | 
| 403 | } | 
| 404 |  | 
| 405 | const std::string& getName() const override { | 
| 406 | static const std::string name = "stdin"; | 
| 407 | return name; | 
| 408 | } | 
| 409 | ~ReadCinCSVFactory() override = default; | 
| 410 | }; | 
| 411 |  | 
| 412 | class ReadFileCSVFactory : public ReadStreamFactory { | 
| 413 | public: | 
| 414 | Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable, | 
| 415 | RecordTable& recordTable) override { | 
| 416 | return mk<ReadFileCSV>(rwOperation, symbolTable, recordTable); | 
| 417 | } | 
| 418 |  | 
| 419 | const std::string& getName() const override { | 
| 420 | static const std::string name = "file"; | 
| 421 | return name; | 
| 422 | } | 
| 423 |  | 
| 424 | ~ReadFileCSVFactory() override = default; | 
| 425 | }; | 
| 426 |  | 
| 427 | } /* namespace souffle */ |