1 | /*
|
2 | * Souffle - A Datalog Compiler
|
3 | * Copyright (c) 2021, The Souffle Developers. All rights reserved
|
4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
5 | * - https://opensource.org/licenses/UPL
|
6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
7 | */
|
8 |
|
9 | /************************************************************************
|
10 | *
|
11 | * @file WriteStreamSQLite.h
|
12 | *
|
13 | ***********************************************************************/
|
14 |
|
15 | #pragma once
|
16 |
|
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/WriteStream.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include <sqlite3.h>
|
32 |
|
33 | namespace souffle {
|
34 |
|
35 | class WriteStreamSQLite : public WriteStream {
|
36 | public:
|
37 | WriteStreamSQLite(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
|
38 | const RecordTable& recordTable)
|
39 | : WriteStream(rwOperation, symbolTable, recordTable), dbFilename(getFileName(rwOperation)),
|
40 | relationName(rwOperation.at("name")) {
|
41 | openDB();
|
42 | createTables();
|
43 | prepareStatements();
|
44 | executeSQL("BEGIN TRANSACTION", db);
|
45 | }
|
46 |
|
47 | ~WriteStreamSQLite() override {
|
48 | executeSQL("COMMIT", db);
|
49 | sqlite3_finalize(insertStatement);
|
50 | sqlite3_finalize(symbolInsertStatement);
|
51 | sqlite3_finalize(symbolSelectStatement);
|
52 | sqlite3_close(db);
|
53 | }
|
54 |
|
55 | protected:
|
56 | void writeNullary() override {}
|
57 |
|
58 | void writeNextTuple(const RamDomain* tuple) override {
|
59 | for (std::size_t i = 0; i < arity; i++) {
|
60 | RamDomain value = 0; // Silence warning
|
61 |
|
62 | switch (typeAttributes.at(i)[0]) {
|
63 | case 's': value = getSymbolTableID(tuple[i]); break;
|
64 | default: value = tuple[i]; break;
|
65 | }
|
66 |
|
67 | #if RAM_DOMAIN_SIZE == 64
|
68 | if (sqlite3_bind_int64(insertStatement, static_cast<int>(i + 1),
|
69 | static_cast<sqlite3_int64>(value)) != SQLITE_OK) {
|
70 | #else
|
71 | if (sqlite3_bind_int(insertStatement, static_cast<int>(i + 1), static_cast<int>(value)) !=
|
72 | SQLITE_OK) {
|
73 | #endif
|
74 | throwError("SQLite error in sqlite3_bind_text: ");
|
75 | }
|
76 | }
|
77 | if (sqlite3_step(insertStatement) != SQLITE_DONE) {
|
78 | throwError("SQLite error in sqlite3_step: ");
|
79 | }
|
80 | sqlite3_clear_bindings(insertStatement);
|
81 | sqlite3_reset(insertStatement);
|
82 | }
|
83 |
|
84 | private:
|
85 | void executeSQL(const std::string& sql, sqlite3* db) {
|
86 | assert(db && "Database connection is closed");
|
87 |
|
88 | char* errorMessage = nullptr;
|
89 | /* Execute SQL statement */
|
90 | int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errorMessage);
|
91 | if (rc != SQLITE_OK) {
|
92 | std::stringstream error;
|
93 | error << "SQLite error in sqlite3_exec: " << sqlite3_errmsg(db) << "\n";
|
94 | error << "SQL error: " << errorMessage << "\n";
|
95 | error << "SQL: " << sql << "\n";
|
96 | sqlite3_free(errorMessage);
|
97 | throw std::invalid_argument(error.str());
|
98 | }
|
99 | }
|
100 |
|
101 | void throwError(const std::string& message) {
|
102 | std::stringstream error;
|
103 | error << message << sqlite3_errmsg(db) << "\n";
|
104 | throw std::invalid_argument(error.str());
|
105 | }
|
106 |
|
107 | uint64_t getSymbolTableIDFromDB(std::size_t index) {
|
108 | if (sqlite3_bind_text(symbolSelectStatement, 1, symbolTable.decode(index).c_str(), -1,
|
109 | SQLITE_TRANSIENT) != SQLITE_OK) {
|
110 | throwError("SQLite error in sqlite3_bind_text: ");
|
111 | }
|
112 | if (sqlite3_step(symbolSelectStatement) != SQLITE_ROW) {
|
113 | throwError("SQLite error in sqlite3_step: ");
|
114 | }
|
115 | uint64_t rowid = sqlite3_column_int64(symbolSelectStatement, 0);
|
116 | sqlite3_clear_bindings(symbolSelectStatement);
|
117 | sqlite3_reset(symbolSelectStatement);
|
118 | return rowid;
|
119 | }
|
120 | uint64_t getSymbolTableID(std::size_t index) {
|
121 | if (dbSymbolTable.count(index) != 0) {
|
122 | return dbSymbolTable[index];
|
123 | }
|
124 |
|
125 | if (sqlite3_bind_text(symbolInsertStatement, 1, symbolTable.decode(index).c_str(), -1,
|
126 | SQLITE_TRANSIENT) != SQLITE_OK) {
|
127 | throwError("SQLite error in sqlite3_bind_text: ");
|
128 | }
|
129 | // Either the insert succeeds and we have a new row id or it already exists and a select is needed.
|
130 | uint64_t rowid;
|
131 | if (sqlite3_step(symbolInsertStatement) != SQLITE_DONE) {
|
132 | // The symbol already exists so select it.
|
133 | rowid = getSymbolTableIDFromDB(index);
|
134 | } else {
|
135 | rowid = sqlite3_last_insert_rowid(db);
|
136 | }
|
137 | sqlite3_clear_bindings(symbolInsertStatement);
|
138 | sqlite3_reset(symbolInsertStatement);
|
139 |
|
140 | dbSymbolTable[index] = rowid;
|
141 | return rowid;
|
142 | }
|
143 |
|
144 | void openDB() {
|
145 | sqlite3_config(SQLITE_CONFIG_URI, 1);
|
146 | if (sqlite3_open(dbFilename.c_str(), &db) != SQLITE_OK) {
|
147 | throwError("SQLite error in sqlite3_open: ");
|
148 | }
|
149 | sqlite3_extended_result_codes(db, 1);
|
150 | executeSQL("PRAGMA synchronous = OFF", db);
|
151 | executeSQL("PRAGMA journal_mode = MEMORY", db);
|
152 | }
|
153 |
|
154 | void prepareStatements() {
|
155 | prepareInsertStatement();
|
156 | prepareSymbolInsertStatement();
|
157 | prepareSymbolSelectStatement();
|
158 | }
|
159 | void prepareSymbolInsertStatement() {
|
160 | std::stringstream insertSQL;
|
161 | insertSQL << "INSERT INTO " << symbolTableName;
|
162 | insertSQL << " VALUES(null,@V0);";
|
163 | const char* tail = nullptr;
|
164 | if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &symbolInsertStatement, &tail) != SQLITE_OK) {
|
165 | throwError("SQLite error in sqlite3_prepare_v2: ");
|
166 | }
|
167 | }
|
168 |
|
169 | void prepareSymbolSelectStatement() {
|
170 | std::stringstream selectSQL;
|
171 | selectSQL << "SELECT id FROM " << symbolTableName;
|
172 | selectSQL << " WHERE symbol = @V0;";
|
173 | const char* tail = nullptr;
|
174 | if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &symbolSelectStatement, &tail) != SQLITE_OK) {
|
175 | throwError("SQLite error in sqlite3_prepare_v2: ");
|
176 | }
|
177 | }
|
178 |
|
179 | void prepareInsertStatement() {
|
180 | std::stringstream insertSQL;
|
181 | insertSQL << "INSERT INTO '_" << relationName << "' VALUES ";
|
182 | insertSQL << "(@V0";
|
183 | for (unsigned int i = 1; i < arity; i++) {
|
184 | insertSQL << ",@V" << i;
|
185 | }
|
186 | insertSQL << ");";
|
187 | const char* tail = nullptr;
|
188 | if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &insertStatement, &tail) != SQLITE_OK) {
|
189 | throwError("SQLite error in sqlite3_prepare_v2: ");
|
190 | }
|
191 | }
|
192 |
|
193 | void createTables() {
|
194 | createRelationTable();
|
195 | createRelationView();
|
196 | createSymbolTable();
|
197 | }
|
198 |
|
199 | void createRelationTable() {
|
200 | std::stringstream createTableText;
|
201 | createTableText << "CREATE TABLE IF NOT EXISTS '_" << relationName << "' (";
|
202 | if (arity > 0) {
|
203 | createTableText << "'0' INTEGER";
|
204 | for (unsigned int i = 1; i < arity; i++) {
|
205 | createTableText << ",'" << std::to_string(i) << "' ";
|
206 | createTableText << "INTEGER";
|
207 | }
|
208 | }
|
209 | createTableText << ");";
|
210 | executeSQL(createTableText.str(), db);
|
211 | executeSQL("DELETE FROM '_" + relationName + "';", db);
|
212 | }
|
213 |
|
214 | void createRelationView() {
|
215 | // Create view with symbol strings resolved
|
216 |
|
217 | const auto columnNames = params["relation"]["params"].array_items();
|
218 |
|
219 | std::stringstream createViewText;
|
220 | createViewText << "CREATE VIEW IF NOT EXISTS '" << relationName << "' AS ";
|
221 | std::stringstream projectionClause;
|
222 | std::stringstream fromClause;
|
223 | fromClause << "'_" << relationName << "'";
|
224 | std::stringstream whereClause;
|
225 | bool firstWhere = true;
|
226 | for (unsigned int i = 0; i < arity; i++) {
|
227 | const std::string tableColumnName = std::to_string(i);
|
228 | const auto& viewColumnName =
|
229 | (columnNames[i].is_string() ? columnNames[i].string_value() : tableColumnName);
|
230 | if (i != 0) {
|
231 | projectionClause << ",";
|
232 | }
|
233 | if (typeAttributes.at(i)[0] == 's') {
|
234 | projectionClause << "'_symtab_" << tableColumnName << "'.symbol AS '" << viewColumnName
|
235 | << "'";
|
236 | fromClause << ",'" << symbolTableName << "' AS '_symtab_" << tableColumnName << "'";
|
237 | if (!firstWhere) {
|
238 | whereClause << " AND ";
|
239 | } else {
|
240 | firstWhere = false;
|
241 | }
|
242 | whereClause << "'_" << relationName << "'.'" << tableColumnName << "' = "
|
243 | << "'_symtab_" << tableColumnName << "'.id";
|
244 | } else {
|
245 | projectionClause << "'_" << relationName << "'.'" << tableColumnName << "' AS '"
|
246 | << viewColumnName << "'";
|
247 | }
|
248 | }
|
249 | createViewText << "SELECT " << projectionClause.str() << " FROM " << fromClause.str();
|
250 | if (!firstWhere) {
|
251 | createViewText << " WHERE " << whereClause.str();
|
252 | }
|
253 | createViewText << ";";
|
254 | executeSQL(createViewText.str(), db);
|
255 | }
|
256 | void createSymbolTable() {
|
257 | std::stringstream createTableText;
|
258 | createTableText << "CREATE TABLE IF NOT EXISTS '" << symbolTableName << "' ";
|
259 | createTableText << "(id INTEGER PRIMARY KEY, symbol TEXT UNIQUE);";
|
260 | executeSQL(createTableText.str(), db);
|
261 | }
|
262 |
|
263 | /**
|
264 | * Return given filename or construct from relation name.
|
265 | * Default name is [configured path]/[relation name].sqlite
|
266 | *
|
267 | * @param rwOperation map of IO configuration options
|
268 | * @return input filename
|
269 | */
|
270 | static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
|
271 | // legacy support for SQLite prior to 2020-03-18
|
272 | // convert dbname to filename
|
273 | auto name = getOr(rwOperation, "dbname", rwOperation.at("name") + ".sqlite");
|
274 | name = getOr(rwOperation, "filename", name);
|
275 |
|
276 | if (name.rfind("file:", 0) == 0 || name.rfind(":memory:", 0) == 0) {
|
277 | return name;
|
278 | }
|
279 |
|
280 | if (name.front() != '/') {
|
281 | name = getOr(rwOperation, "output-dir", ".") + "/" + name;
|
282 | }
|
283 | return name;
|
284 | }
|
285 |
|
286 | const std::string dbFilename;
|
287 | const std::string relationName;
|
288 | const std::string symbolTableName = "__SymbolTable";
|
289 |
|
290 | std::unordered_map<uint64_t, uint64_t> dbSymbolTable;
|
291 | sqlite3_stmt* insertStatement = nullptr;
|
292 | sqlite3_stmt* symbolInsertStatement = nullptr;
|
293 | sqlite3_stmt* symbolSelectStatement = nullptr;
|
294 | sqlite3* db = nullptr;
|
295 | };
|
296 |
|
297 | class WriteSQLiteFactory : public WriteStreamFactory {
|
298 | public:
|
299 | Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
|
300 | const SymbolTable& symbolTable, const RecordTable& recordTable) override {
|
301 | return mk<WriteStreamSQLite>(rwOperation, symbolTable, recordTable);
|
302 | }
|
303 |
|
304 | const std::string& getName() const override {
|
305 | static const std::string name = "sqlite";
|
306 | return name;
|
307 | }
|
308 | ~WriteSQLiteFactory() override = default;
|
309 | };
|
310 |
|
311 | } /* namespace souffle */
|