| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file ReadStreamSQLite.h
 | 
| 12 |  *
 | 
| 13 |  ***********************************************************************/
 | 
| 14 | 
 | 
| 15 | #pragma once
 | 
| 16 | 
 | 
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/ReadStream.h"
#include "souffle/utility/MiscUtil.h"
#include "souffle/utility/StringUtil.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <sqlite3.h>
 | 
| 32 | 
 | 
| 33 | namespace souffle {
 | 
| 34 | 
 | 
| 35 | class ReadStreamSQLite : public ReadStream {
 | 
| 36 | public:
 | 
| 37 |     ReadStreamSQLite(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
 | 
| 38 |             RecordTable& recordTable)
 | 
| 39 |             : ReadStream(rwOperation, symbolTable, recordTable), dbFilename(getFileName(rwOperation)),
 | 
| 40 |               relationName(rwOperation.at("name")) {
 | 
| 41 |         openDB();
 | 
| 42 |         checkTableExists();
 | 
| 43 |         prepareSelectStatement();
 | 
| 44 |     }
 | 
| 45 | 
 | 
| 46 |     ~ReadStreamSQLite() override {
 | 
| 47 |         sqlite3_finalize(selectStatement);
 | 
| 48 |         sqlite3_close(db);
 | 
| 49 |     }
 | 
| 50 | 
 | 
| 51 | protected:
 | 
| 52 |     /**
 | 
| 53 |      * Read and return the next tuple.
 | 
| 54 |      *
 | 
| 55 |      * Returns nullptr if no tuple was readable.
 | 
| 56 |      * @return
 | 
| 57 |      */
 | 
| 58 |     Own<RamDomain[]> readNextTuple() override {
 | 
| 59 |         if (sqlite3_step(selectStatement) != SQLITE_ROW) {
 | 
| 60 |             return nullptr;
 | 
| 61 |         }
 | 
| 62 | 
 | 
| 63 |         Own<RamDomain[]> tuple = mk<RamDomain[]>(arity + auxiliaryArity);
 | 
| 64 | 
 | 
| 65 |         uint32_t column;
 | 
| 66 |         for (column = 0; column < arity; column++) {
 | 
| 67 |             std::string element;
 | 
| 68 |             if (0 == sqlite3_column_bytes(selectStatement, column)) {
 | 
| 69 |                 element = "";
 | 
| 70 |             } else {
 | 
| 71 |                 element = reinterpret_cast<const char*>(sqlite3_column_text(selectStatement, column));
 | 
| 72 | 
 | 
| 73 |                 if (element.empty()) {
 | 
| 74 |                     element = "";
 | 
| 75 |                 }
 | 
| 76 |             }
 | 
| 77 | 
 | 
| 78 |             try {
 | 
| 79 |                 auto&& ty = typeAttributes.at(column);
 | 
| 80 |                 switch (ty[0]) {
 | 
| 81 |                     case 's': tuple[column] = symbolTable.encode(element); break;
 | 
| 82 |                     case 'f': tuple[column] = ramBitCast(RamFloatFromString(element)); break;
 | 
| 83 |                     case 'i':
 | 
| 84 |                     case 'u':
 | 
| 85 |                     case 'r': tuple[column] = RamSignedFromString(element); break;
 | 
| 86 |                     default: fatal("invalid type attribute: `%c`", ty[0]);
 | 
| 87 |                 }
 | 
| 88 |             } catch (...) {
 | 
| 89 |                 std::stringstream errorMessage;
 | 
| 90 |                 errorMessage << "Error converting number in column " << (column) + 1;
 | 
| 91 |                 throw std::invalid_argument(errorMessage.str());
 | 
| 92 |             }
 | 
| 93 |         }
 | 
| 94 | 
 | 
| 95 |         return tuple;
 | 
| 96 |     }
 | 
| 97 | 
 | 
| 98 |     void executeSQL(const std::string& sql) {
 | 
| 99 |         assert(db && "Database connection is closed");
 | 
| 100 | 
 | 
| 101 |         char* errorMessage = nullptr;
 | 
| 102 |         /* Execute SQL statement */
 | 
| 103 |         int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errorMessage);
 | 
| 104 |         if (rc != SQLITE_OK) {
 | 
| 105 |             std::stringstream error;
 | 
| 106 |             error << "SQLite error in sqlite3_exec: " << sqlite3_errmsg(db) << "\n";
 | 
| 107 |             error << "SQL error: " << errorMessage << "\n";
 | 
| 108 |             error << "SQL: " << sql << "\n";
 | 
| 109 |             sqlite3_free(errorMessage);
 | 
| 110 |             throw std::invalid_argument(error.str());
 | 
| 111 |         }
 | 
| 112 |     }
 | 
| 113 | 
 | 
| 114 |     void throwError(const std::string& message) {
 | 
| 115 |         std::stringstream error;
 | 
| 116 |         error << message << sqlite3_errmsg(db) << "\n";
 | 
| 117 |         throw std::invalid_argument(error.str());
 | 
| 118 |     }
 | 
| 119 | 
 | 
| 120 |     void prepareSelectStatement() {
 | 
| 121 |         std::stringstream selectSQL;
 | 
| 122 |         selectSQL << "SELECT * FROM '" << relationName << "'";
 | 
| 123 |         const char* tail = nullptr;
 | 
| 124 |         if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &selectStatement, &tail) != SQLITE_OK) {
 | 
| 125 |             throwError("SQLite error in sqlite3_prepare_v2: ");
 | 
| 126 |         }
 | 
| 127 |     }
 | 
| 128 | 
 | 
| 129 |     void openDB() {
 | 
| 130 |         sqlite3_config(SQLITE_CONFIG_URI, 1);
 | 
| 131 |         if (sqlite3_open(dbFilename.c_str(), &db) != SQLITE_OK) {
 | 
| 132 |             throwError("SQLite error in sqlite3_open: ");
 | 
| 133 |         }
 | 
| 134 |         sqlite3_extended_result_codes(db, 1);
 | 
| 135 |         executeSQL("PRAGMA synchronous = OFF");
 | 
| 136 |         executeSQL("PRAGMA journal_mode = MEMORY");
 | 
| 137 |     }
 | 
| 138 | 
 | 
| 139 |     void checkTableExists() {
 | 
| 140 |         sqlite3_stmt* tableStatement;
 | 
| 141 |         std::stringstream selectSQL;
 | 
| 142 |         selectSQL << "SELECT count(*) FROM sqlite_master WHERE type IN ('table', 'view') AND ";
 | 
| 143 |         selectSQL << " name = '" << relationName << "';";
 | 
| 144 |         const char* tail = nullptr;
 | 
| 145 | 
 | 
| 146 |         if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &tableStatement, &tail) != SQLITE_OK) {
 | 
| 147 |             throwError("SQLite error in sqlite3_prepare_v2: ");
 | 
| 148 |         }
 | 
| 149 | 
 | 
| 150 |         if (sqlite3_step(tableStatement) == SQLITE_ROW) {
 | 
| 151 |             int count = sqlite3_column_int(tableStatement, 0);
 | 
| 152 |             if (count > 0) {
 | 
| 153 |                 sqlite3_finalize(tableStatement);
 | 
| 154 |                 return;
 | 
| 155 |             }
 | 
| 156 |         }
 | 
| 157 |         sqlite3_finalize(tableStatement);
 | 
| 158 |         throw std::invalid_argument(
 | 
| 159 |                 "Required table or view does not exist in " + dbFilename + " for relation " + relationName);
 | 
| 160 |     }
 | 
| 161 | 
 | 
| 162 |     /**
 | 
| 163 |      * Return given filename or construct from relation name.
 | 
| 164 |      * Default name is [configured path]/[relation name].sqlite
 | 
| 165 |      *
 | 
| 166 |      * @param rwOperation map of IO configuration options
 | 
| 167 |      * @return input filename
 | 
| 168 |      */
 | 
| 169 |     static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
 | 
| 170 |         // legacy support for SQLite prior to 2020-03-18
 | 
| 171 |         // convert dbname to filename
 | 
| 172 |         auto name = getOr(rwOperation, "dbname", rwOperation.at("name") + ".sqlite");
 | 
| 173 |         name = getOr(rwOperation, "filename", name);
 | 
| 174 | 
 | 
| 175 |         if (name.rfind("file:", 0) == 0 || name.rfind(":memory:", 0) == 0) {
 | 
| 176 |             return name;
 | 
| 177 |         }
 | 
| 178 | 
 | 
| 179 |         if (name.front() != '/') {
 | 
| 180 |             name = getOr(rwOperation, "fact-dir", ".") + "/" + name;
 | 
| 181 |         }
 | 
| 182 |         return name;
 | 
| 183 |     }
 | 
| 184 | 
 | 
| 185 |     const std::string dbFilename;
 | 
| 186 |     const std::string relationName;
 | 
| 187 |     sqlite3_stmt* selectStatement = nullptr;
 | 
| 188 |     sqlite3* db = nullptr;
 | 
| 189 | };
 | 
| 190 | 
 | 
| 191 | class ReadSQLiteFactory : public ReadStreamFactory {
 | 
| 192 | public:
 | 
| 193 |     Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
 | 
| 194 |             RecordTable& recordTable) override {
 | 
| 195 |         return mk<ReadStreamSQLite>(rwOperation, symbolTable, recordTable);
 | 
| 196 |     }
 | 
| 197 | 
 | 
| 198 |     const std::string& getName() const override {
 | 
| 199 |         static const std::string name = "sqlite";
 | 
| 200 |         return name;
 | 
| 201 |     }
 | 
| 202 |     ~ReadSQLiteFactory() override = default;
 | 
| 203 | };
 | 
| 204 | 
 | 
| 205 | } /* namespace souffle */
 |