| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file WriteStreamSQLite.h
 | 
| 12 |  *
 | 
| 13 |  ***********************************************************************/
 | 
| 14 | 
 | 
| 15 | #pragma once
 | 
| 16 | 
 | 
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/WriteStream.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include <sqlite3.h>
 | 
| 32 | 
 | 
| 33 | namespace souffle {
 | 
| 34 | 
 | 
| 35 | class WriteStreamSQLite : public WriteStream {
 | 
| 36 | public:
 | 
| 37 |     WriteStreamSQLite(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
 | 
| 38 |             const RecordTable& recordTable)
 | 
| 39 |             : WriteStream(rwOperation, symbolTable, recordTable), dbFilename(getFileName(rwOperation)),
 | 
| 40 |               relationName(rwOperation.at("name")) {
 | 
| 41 |         openDB();
 | 
| 42 |         createTables();
 | 
| 43 |         prepareStatements();
 | 
| 44 |         executeSQL("BEGIN TRANSACTION", db);
 | 
| 45 |     }
 | 
| 46 | 
 | 
| 47 |     ~WriteStreamSQLite() override {
 | 
| 48 |         executeSQL("COMMIT", db);
 | 
| 49 |         sqlite3_finalize(insertStatement);
 | 
| 50 |         sqlite3_finalize(symbolInsertStatement);
 | 
| 51 |         sqlite3_finalize(symbolSelectStatement);
 | 
| 52 |         sqlite3_close(db);
 | 
| 53 |     }
 | 
| 54 | 
 | 
| 55 | protected:
 | 
| 56 |     void writeNullary() override {}
 | 
| 57 | 
 | 
| 58 |     void writeNextTuple(const RamDomain* tuple) override {
 | 
| 59 |         for (std::size_t i = 0; i < arity; i++) {
 | 
| 60 |             RamDomain value = 0;  // Silence warning
 | 
| 61 | 
 | 
| 62 |             switch (typeAttributes.at(i)[0]) {
 | 
| 63 |                 case 's': value = getSymbolTableID(tuple[i]); break;
 | 
| 64 |                 default: value = tuple[i]; break;
 | 
| 65 |             }
 | 
| 66 | 
 | 
| 67 | #if RAM_DOMAIN_SIZE == 64
 | 
| 68 |             if (sqlite3_bind_int64(insertStatement, static_cast<int>(i + 1),
 | 
| 69 |                         static_cast<sqlite3_int64>(value)) != SQLITE_OK) {
 | 
| 70 | #else
 | 
| 71 |             if (sqlite3_bind_int(insertStatement, static_cast<int>(i + 1), static_cast<int>(value)) !=
 | 
| 72 |                     SQLITE_OK) {
 | 
| 73 | #endif
 | 
| 74 |                 throwError("SQLite error in sqlite3_bind_text: ");
 | 
| 75 |             }
 | 
| 76 |         }
 | 
| 77 |         if (sqlite3_step(insertStatement) != SQLITE_DONE) {
 | 
| 78 |             throwError("SQLite error in sqlite3_step: ");
 | 
| 79 |         }
 | 
| 80 |         sqlite3_clear_bindings(insertStatement);
 | 
| 81 |         sqlite3_reset(insertStatement);
 | 
| 82 |     }
 | 
| 83 | 
 | 
| 84 | private:
 | 
| 85 |     void executeSQL(const std::string& sql, sqlite3* db) {
 | 
| 86 |         assert(db && "Database connection is closed");
 | 
| 87 | 
 | 
| 88 |         char* errorMessage = nullptr;
 | 
| 89 |         /* Execute SQL statement */
 | 
| 90 |         int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errorMessage);
 | 
| 91 |         if (rc != SQLITE_OK) {
 | 
| 92 |             std::stringstream error;
 | 
| 93 |             error << "SQLite error in sqlite3_exec: " << sqlite3_errmsg(db) << "\n";
 | 
| 94 |             error << "SQL error: " << errorMessage << "\n";
 | 
| 95 |             error << "SQL: " << sql << "\n";
 | 
| 96 |             sqlite3_free(errorMessage);
 | 
| 97 |             throw std::invalid_argument(error.str());
 | 
| 98 |         }
 | 
| 99 |     }
 | 
| 100 | 
 | 
| 101 |     void throwError(const std::string& message) {
 | 
| 102 |         std::stringstream error;
 | 
| 103 |         error << message << sqlite3_errmsg(db) << "\n";
 | 
| 104 |         throw std::invalid_argument(error.str());
 | 
| 105 |     }
 | 
| 106 | 
 | 
| 107 |     uint64_t getSymbolTableIDFromDB(std::size_t index) {
 | 
| 108 |         if (sqlite3_bind_text(symbolSelectStatement, 1, symbolTable.decode(index).c_str(), -1,
 | 
| 109 |                     SQLITE_TRANSIENT) != SQLITE_OK) {
 | 
| 110 |             throwError("SQLite error in sqlite3_bind_text: ");
 | 
| 111 |         }
 | 
| 112 |         if (sqlite3_step(symbolSelectStatement) != SQLITE_ROW) {
 | 
| 113 |             throwError("SQLite error in sqlite3_step: ");
 | 
| 114 |         }
 | 
| 115 |         uint64_t rowid = sqlite3_column_int64(symbolSelectStatement, 0);
 | 
| 116 |         sqlite3_clear_bindings(symbolSelectStatement);
 | 
| 117 |         sqlite3_reset(symbolSelectStatement);
 | 
| 118 |         return rowid;
 | 
| 119 |     }
 | 
| 120 |     uint64_t getSymbolTableID(std::size_t index) {
 | 
| 121 |         if (dbSymbolTable.count(index) != 0) {
 | 
| 122 |             return dbSymbolTable[index];
 | 
| 123 |         }
 | 
| 124 | 
 | 
| 125 |         if (sqlite3_bind_text(symbolInsertStatement, 1, symbolTable.decode(index).c_str(), -1,
 | 
| 126 |                     SQLITE_TRANSIENT) != SQLITE_OK) {
 | 
| 127 |             throwError("SQLite error in sqlite3_bind_text: ");
 | 
| 128 |         }
 | 
| 129 |         // Either the insert succeeds and we have a new row id or it already exists and a select is needed.
 | 
| 130 |         uint64_t rowid;
 | 
| 131 |         if (sqlite3_step(symbolInsertStatement) != SQLITE_DONE) {
 | 
| 132 |             // The symbol already exists so select it.
 | 
| 133 |             rowid = getSymbolTableIDFromDB(index);
 | 
| 134 |         } else {
 | 
| 135 |             rowid = sqlite3_last_insert_rowid(db);
 | 
| 136 |         }
 | 
| 137 |         sqlite3_clear_bindings(symbolInsertStatement);
 | 
| 138 |         sqlite3_reset(symbolInsertStatement);
 | 
| 139 | 
 | 
| 140 |         dbSymbolTable[index] = rowid;
 | 
| 141 |         return rowid;
 | 
| 142 |     }
 | 
| 143 | 
 | 
| 144 |     void openDB() {
 | 
| 145 |         sqlite3_config(SQLITE_CONFIG_URI, 1);
 | 
| 146 |         if (sqlite3_open(dbFilename.c_str(), &db) != SQLITE_OK) {
 | 
| 147 |             throwError("SQLite error in sqlite3_open: ");
 | 
| 148 |         }
 | 
| 149 |         sqlite3_extended_result_codes(db, 1);
 | 
| 150 |         executeSQL("PRAGMA synchronous = OFF", db);
 | 
| 151 |         executeSQL("PRAGMA journal_mode = MEMORY", db);
 | 
| 152 |     }
 | 
| 153 | 
 | 
| 154 |     void prepareStatements() {
 | 
| 155 |         prepareInsertStatement();
 | 
| 156 |         prepareSymbolInsertStatement();
 | 
| 157 |         prepareSymbolSelectStatement();
 | 
| 158 |     }
 | 
| 159 |     void prepareSymbolInsertStatement() {
 | 
| 160 |         std::stringstream insertSQL;
 | 
| 161 |         insertSQL << "INSERT INTO " << symbolTableName;
 | 
| 162 |         insertSQL << " VALUES(null,@V0);";
 | 
| 163 |         const char* tail = nullptr;
 | 
| 164 |         if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &symbolInsertStatement, &tail) != SQLITE_OK) {
 | 
| 165 |             throwError("SQLite error in sqlite3_prepare_v2: ");
 | 
| 166 |         }
 | 
| 167 |     }
 | 
| 168 | 
 | 
| 169 |     void prepareSymbolSelectStatement() {
 | 
| 170 |         std::stringstream selectSQL;
 | 
| 171 |         selectSQL << "SELECT id FROM " << symbolTableName;
 | 
| 172 |         selectSQL << " WHERE symbol = @V0;";
 | 
| 173 |         const char* tail = nullptr;
 | 
| 174 |         if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &symbolSelectStatement, &tail) != SQLITE_OK) {
 | 
| 175 |             throwError("SQLite error in sqlite3_prepare_v2: ");
 | 
| 176 |         }
 | 
| 177 |     }
 | 
| 178 | 
 | 
| 179 |     void prepareInsertStatement() {
 | 
| 180 |         std::stringstream insertSQL;
 | 
| 181 |         insertSQL << "INSERT INTO '_" << relationName << "' VALUES ";
 | 
| 182 |         insertSQL << "(@V0";
 | 
| 183 |         for (unsigned int i = 1; i < arity; i++) {
 | 
| 184 |             insertSQL << ",@V" << i;
 | 
| 185 |         }
 | 
| 186 |         insertSQL << ");";
 | 
| 187 |         const char* tail = nullptr;
 | 
| 188 |         if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &insertStatement, &tail) != SQLITE_OK) {
 | 
| 189 |             throwError("SQLite error in sqlite3_prepare_v2: ");
 | 
| 190 |         }
 | 
| 191 |     }
 | 
| 192 | 
 | 
| 193 |     void createTables() {
 | 
| 194 |         createRelationTable();
 | 
| 195 |         createRelationView();
 | 
| 196 |         createSymbolTable();
 | 
| 197 |     }
 | 
| 198 | 
 | 
| 199 |     void createRelationTable() {
 | 
| 200 |         std::stringstream createTableText;
 | 
| 201 |         createTableText << "CREATE TABLE IF NOT EXISTS '_" << relationName << "' (";
 | 
| 202 |         if (arity > 0) {
 | 
| 203 |             createTableText << "'0' INTEGER";
 | 
| 204 |             for (unsigned int i = 1; i < arity; i++) {
 | 
| 205 |                 createTableText << ",'" << std::to_string(i) << "' ";
 | 
| 206 |                 createTableText << "INTEGER";
 | 
| 207 |             }
 | 
| 208 |         }
 | 
| 209 |         createTableText << ");";
 | 
| 210 |         executeSQL(createTableText.str(), db);
 | 
| 211 |         executeSQL("DELETE FROM '_" + relationName + "';", db);
 | 
| 212 |     }
 | 
| 213 | 
 | 
| 214 |     void createRelationView() {
 | 
| 215 |         // Create view with symbol strings resolved
 | 
| 216 | 
 | 
| 217 |         const auto columnNames = params["relation"]["params"].array_items();
 | 
| 218 | 
 | 
| 219 |         std::stringstream createViewText;
 | 
| 220 |         createViewText << "CREATE VIEW IF NOT EXISTS '" << relationName << "' AS ";
 | 
| 221 |         std::stringstream projectionClause;
 | 
| 222 |         std::stringstream fromClause;
 | 
| 223 |         fromClause << "'_" << relationName << "'";
 | 
| 224 |         std::stringstream whereClause;
 | 
| 225 |         bool firstWhere = true;
 | 
| 226 |         for (unsigned int i = 0; i < arity; i++) {
 | 
| 227 |             const std::string tableColumnName = std::to_string(i);
 | 
| 228 |             const auto& viewColumnName =
 | 
| 229 |                     (columnNames[i].is_string() ? columnNames[i].string_value() : tableColumnName);
 | 
| 230 |             if (i != 0) {
 | 
| 231 |                 projectionClause << ",";
 | 
| 232 |             }
 | 
| 233 |             if (typeAttributes.at(i)[0] == 's') {
 | 
| 234 |                 projectionClause << "'_symtab_" << tableColumnName << "'.symbol AS '" << viewColumnName
 | 
| 235 |                                  << "'";
 | 
| 236 |                 fromClause << ",'" << symbolTableName << "' AS '_symtab_" << tableColumnName << "'";
 | 
| 237 |                 if (!firstWhere) {
 | 
| 238 |                     whereClause << " AND ";
 | 
| 239 |                 } else {
 | 
| 240 |                     firstWhere = false;
 | 
| 241 |                 }
 | 
| 242 |                 whereClause << "'_" << relationName << "'.'" << tableColumnName << "' = "
 | 
| 243 |                             << "'_symtab_" << tableColumnName << "'.id";
 | 
| 244 |             } else {
 | 
| 245 |                 projectionClause << "'_" << relationName << "'.'" << tableColumnName << "' AS '"
 | 
| 246 |                                  << viewColumnName << "'";
 | 
| 247 |             }
 | 
| 248 |         }
 | 
| 249 |         createViewText << "SELECT " << projectionClause.str() << " FROM " << fromClause.str();
 | 
| 250 |         if (!firstWhere) {
 | 
| 251 |             createViewText << " WHERE " << whereClause.str();
 | 
| 252 |         }
 | 
| 253 |         createViewText << ";";
 | 
| 254 |         executeSQL(createViewText.str(), db);
 | 
| 255 |     }
 | 
| 256 |     void createSymbolTable() {
 | 
| 257 |         std::stringstream createTableText;
 | 
| 258 |         createTableText << "CREATE TABLE IF NOT EXISTS '" << symbolTableName << "' ";
 | 
| 259 |         createTableText << "(id INTEGER PRIMARY KEY, symbol TEXT UNIQUE);";
 | 
| 260 |         executeSQL(createTableText.str(), db);
 | 
| 261 |     }
 | 
| 262 | 
 | 
| 263 |     /**
 | 
| 264 |      * Return given filename or construct from relation name.
 | 
| 265 |      * Default name is [configured path]/[relation name].sqlite
 | 
| 266 |      *
 | 
| 267 |      * @param rwOperation map of IO configuration options
 | 
| 268 |      * @return input filename
 | 
| 269 |      */
 | 
| 270 |     static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
 | 
| 271 |         // legacy support for SQLite prior to 2020-03-18
 | 
| 272 |         // convert dbname to filename
 | 
| 273 |         auto name = getOr(rwOperation, "dbname", rwOperation.at("name") + ".sqlite");
 | 
| 274 |         name = getOr(rwOperation, "filename", name);
 | 
| 275 | 
 | 
| 276 |         if (name.rfind("file:", 0) == 0 || name.rfind(":memory:", 0) == 0) {
 | 
| 277 |             return name;
 | 
| 278 |         }
 | 
| 279 | 
 | 
| 280 |         if (name.front() != '/') {
 | 
| 281 |             name = getOr(rwOperation, "output-dir", ".") + "/" + name;
 | 
| 282 |         }
 | 
| 283 |         return name;
 | 
| 284 |     }
 | 
| 285 | 
 | 
| 286 |     const std::string dbFilename;
 | 
| 287 |     const std::string relationName;
 | 
| 288 |     const std::string symbolTableName = "__SymbolTable";
 | 
| 289 | 
 | 
| 290 |     std::unordered_map<uint64_t, uint64_t> dbSymbolTable;
 | 
| 291 |     sqlite3_stmt* insertStatement = nullptr;
 | 
| 292 |     sqlite3_stmt* symbolInsertStatement = nullptr;
 | 
| 293 |     sqlite3_stmt* symbolSelectStatement = nullptr;
 | 
| 294 |     sqlite3* db = nullptr;
 | 
| 295 | };
 | 
| 296 | 
 | 
| 297 | class WriteSQLiteFactory : public WriteStreamFactory {
 | 
| 298 | public:
 | 
| 299 |     Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
 | 
| 300 |             const SymbolTable& symbolTable, const RecordTable& recordTable) override {
 | 
| 301 |         return mk<WriteStreamSQLite>(rwOperation, symbolTable, recordTable);
 | 
| 302 |     }
 | 
| 303 | 
 | 
| 304 |     const std::string& getName() const override {
 | 
| 305 |         static const std::string name = "sqlite";
 | 
| 306 |         return name;
 | 
| 307 |     }
 | 
| 308 |     ~WriteSQLiteFactory() override = default;
 | 
| 309 | };
 | 
| 310 | 
 | 
| 311 | } /* namespace souffle */
 |