OILS / vendor / souffle / io / WriteStreamSQLite.h View on Github | oilshell.org

311 lines, 209 significant
/*
 * Souffle - A Datalog Compiler
 * Copyright (c) 2021, The Souffle Developers. All rights reserved
 * Licensed under the Universal Permissive License v 1.0 as shown at:
 * - https://opensource.org/licenses/UPL
 * - <souffle root>/licenses/SOUFFLE-UPL.txt
 */

/************************************************************************
 *
 * @file WriteStreamSQLite.h
 *
 ***********************************************************************/
14
15#pragma once
16
#include "souffle/RamTypes.h"
#include "souffle/RecordTable.h"
#include "souffle/SymbolTable.h"
#include "souffle/io/WriteStream.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include <sqlite3.h>
32
33namespace souffle {
34
35class WriteStreamSQLite : public WriteStream {
36public:
37 WriteStreamSQLite(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
38 const RecordTable& recordTable)
39 : WriteStream(rwOperation, symbolTable, recordTable), dbFilename(getFileName(rwOperation)),
40 relationName(rwOperation.at("name")) {
41 openDB();
42 createTables();
43 prepareStatements();
44 executeSQL("BEGIN TRANSACTION", db);
45 }
46
47 ~WriteStreamSQLite() override {
48 executeSQL("COMMIT", db);
49 sqlite3_finalize(insertStatement);
50 sqlite3_finalize(symbolInsertStatement);
51 sqlite3_finalize(symbolSelectStatement);
52 sqlite3_close(db);
53 }
54
55protected:
56 void writeNullary() override {}
57
58 void writeNextTuple(const RamDomain* tuple) override {
59 for (std::size_t i = 0; i < arity; i++) {
60 RamDomain value = 0; // Silence warning
61
62 switch (typeAttributes.at(i)[0]) {
63 case 's': value = getSymbolTableID(tuple[i]); break;
64 default: value = tuple[i]; break;
65 }
66
67#if RAM_DOMAIN_SIZE == 64
68 if (sqlite3_bind_int64(insertStatement, static_cast<int>(i + 1),
69 static_cast<sqlite3_int64>(value)) != SQLITE_OK) {
70#else
71 if (sqlite3_bind_int(insertStatement, static_cast<int>(i + 1), static_cast<int>(value)) !=
72 SQLITE_OK) {
73#endif
74 throwError("SQLite error in sqlite3_bind_text: ");
75 }
76 }
77 if (sqlite3_step(insertStatement) != SQLITE_DONE) {
78 throwError("SQLite error in sqlite3_step: ");
79 }
80 sqlite3_clear_bindings(insertStatement);
81 sqlite3_reset(insertStatement);
82 }
83
84private:
85 void executeSQL(const std::string& sql, sqlite3* db) {
86 assert(db && "Database connection is closed");
87
88 char* errorMessage = nullptr;
89 /* Execute SQL statement */
90 int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &errorMessage);
91 if (rc != SQLITE_OK) {
92 std::stringstream error;
93 error << "SQLite error in sqlite3_exec: " << sqlite3_errmsg(db) << "\n";
94 error << "SQL error: " << errorMessage << "\n";
95 error << "SQL: " << sql << "\n";
96 sqlite3_free(errorMessage);
97 throw std::invalid_argument(error.str());
98 }
99 }
100
101 void throwError(const std::string& message) {
102 std::stringstream error;
103 error << message << sqlite3_errmsg(db) << "\n";
104 throw std::invalid_argument(error.str());
105 }
106
107 uint64_t getSymbolTableIDFromDB(std::size_t index) {
108 if (sqlite3_bind_text(symbolSelectStatement, 1, symbolTable.decode(index).c_str(), -1,
109 SQLITE_TRANSIENT) != SQLITE_OK) {
110 throwError("SQLite error in sqlite3_bind_text: ");
111 }
112 if (sqlite3_step(symbolSelectStatement) != SQLITE_ROW) {
113 throwError("SQLite error in sqlite3_step: ");
114 }
115 uint64_t rowid = sqlite3_column_int64(symbolSelectStatement, 0);
116 sqlite3_clear_bindings(symbolSelectStatement);
117 sqlite3_reset(symbolSelectStatement);
118 return rowid;
119 }
120 uint64_t getSymbolTableID(std::size_t index) {
121 if (dbSymbolTable.count(index) != 0) {
122 return dbSymbolTable[index];
123 }
124
125 if (sqlite3_bind_text(symbolInsertStatement, 1, symbolTable.decode(index).c_str(), -1,
126 SQLITE_TRANSIENT) != SQLITE_OK) {
127 throwError("SQLite error in sqlite3_bind_text: ");
128 }
129 // Either the insert succeeds and we have a new row id or it already exists and a select is needed.
130 uint64_t rowid;
131 if (sqlite3_step(symbolInsertStatement) != SQLITE_DONE) {
132 // The symbol already exists so select it.
133 rowid = getSymbolTableIDFromDB(index);
134 } else {
135 rowid = sqlite3_last_insert_rowid(db);
136 }
137 sqlite3_clear_bindings(symbolInsertStatement);
138 sqlite3_reset(symbolInsertStatement);
139
140 dbSymbolTable[index] = rowid;
141 return rowid;
142 }
143
144 void openDB() {
145 sqlite3_config(SQLITE_CONFIG_URI, 1);
146 if (sqlite3_open(dbFilename.c_str(), &db) != SQLITE_OK) {
147 throwError("SQLite error in sqlite3_open: ");
148 }
149 sqlite3_extended_result_codes(db, 1);
150 executeSQL("PRAGMA synchronous = OFF", db);
151 executeSQL("PRAGMA journal_mode = MEMORY", db);
152 }
153
154 void prepareStatements() {
155 prepareInsertStatement();
156 prepareSymbolInsertStatement();
157 prepareSymbolSelectStatement();
158 }
159 void prepareSymbolInsertStatement() {
160 std::stringstream insertSQL;
161 insertSQL << "INSERT INTO " << symbolTableName;
162 insertSQL << " VALUES(null,@V0);";
163 const char* tail = nullptr;
164 if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &symbolInsertStatement, &tail) != SQLITE_OK) {
165 throwError("SQLite error in sqlite3_prepare_v2: ");
166 }
167 }
168
169 void prepareSymbolSelectStatement() {
170 std::stringstream selectSQL;
171 selectSQL << "SELECT id FROM " << symbolTableName;
172 selectSQL << " WHERE symbol = @V0;";
173 const char* tail = nullptr;
174 if (sqlite3_prepare_v2(db, selectSQL.str().c_str(), -1, &symbolSelectStatement, &tail) != SQLITE_OK) {
175 throwError("SQLite error in sqlite3_prepare_v2: ");
176 }
177 }
178
179 void prepareInsertStatement() {
180 std::stringstream insertSQL;
181 insertSQL << "INSERT INTO '_" << relationName << "' VALUES ";
182 insertSQL << "(@V0";
183 for (unsigned int i = 1; i < arity; i++) {
184 insertSQL << ",@V" << i;
185 }
186 insertSQL << ");";
187 const char* tail = nullptr;
188 if (sqlite3_prepare_v2(db, insertSQL.str().c_str(), -1, &insertStatement, &tail) != SQLITE_OK) {
189 throwError("SQLite error in sqlite3_prepare_v2: ");
190 }
191 }
192
193 void createTables() {
194 createRelationTable();
195 createRelationView();
196 createSymbolTable();
197 }
198
199 void createRelationTable() {
200 std::stringstream createTableText;
201 createTableText << "CREATE TABLE IF NOT EXISTS '_" << relationName << "' (";
202 if (arity > 0) {
203 createTableText << "'0' INTEGER";
204 for (unsigned int i = 1; i < arity; i++) {
205 createTableText << ",'" << std::to_string(i) << "' ";
206 createTableText << "INTEGER";
207 }
208 }
209 createTableText << ");";
210 executeSQL(createTableText.str(), db);
211 executeSQL("DELETE FROM '_" + relationName + "';", db);
212 }
213
214 void createRelationView() {
215 // Create view with symbol strings resolved
216
217 const auto columnNames = params["relation"]["params"].array_items();
218
219 std::stringstream createViewText;
220 createViewText << "CREATE VIEW IF NOT EXISTS '" << relationName << "' AS ";
221 std::stringstream projectionClause;
222 std::stringstream fromClause;
223 fromClause << "'_" << relationName << "'";
224 std::stringstream whereClause;
225 bool firstWhere = true;
226 for (unsigned int i = 0; i < arity; i++) {
227 const std::string tableColumnName = std::to_string(i);
228 const auto& viewColumnName =
229 (columnNames[i].is_string() ? columnNames[i].string_value() : tableColumnName);
230 if (i != 0) {
231 projectionClause << ",";
232 }
233 if (typeAttributes.at(i)[0] == 's') {
234 projectionClause << "'_symtab_" << tableColumnName << "'.symbol AS '" << viewColumnName
235 << "'";
236 fromClause << ",'" << symbolTableName << "' AS '_symtab_" << tableColumnName << "'";
237 if (!firstWhere) {
238 whereClause << " AND ";
239 } else {
240 firstWhere = false;
241 }
242 whereClause << "'_" << relationName << "'.'" << tableColumnName << "' = "
243 << "'_symtab_" << tableColumnName << "'.id";
244 } else {
245 projectionClause << "'_" << relationName << "'.'" << tableColumnName << "' AS '"
246 << viewColumnName << "'";
247 }
248 }
249 createViewText << "SELECT " << projectionClause.str() << " FROM " << fromClause.str();
250 if (!firstWhere) {
251 createViewText << " WHERE " << whereClause.str();
252 }
253 createViewText << ";";
254 executeSQL(createViewText.str(), db);
255 }
256 void createSymbolTable() {
257 std::stringstream createTableText;
258 createTableText << "CREATE TABLE IF NOT EXISTS '" << symbolTableName << "' ";
259 createTableText << "(id INTEGER PRIMARY KEY, symbol TEXT UNIQUE);";
260 executeSQL(createTableText.str(), db);
261 }
262
263 /**
264 * Return given filename or construct from relation name.
265 * Default name is [configured path]/[relation name].sqlite
266 *
267 * @param rwOperation map of IO configuration options
268 * @return input filename
269 */
270 static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
271 // legacy support for SQLite prior to 2020-03-18
272 // convert dbname to filename
273 auto name = getOr(rwOperation, "dbname", rwOperation.at("name") + ".sqlite");
274 name = getOr(rwOperation, "filename", name);
275
276 if (name.rfind("file:", 0) == 0 || name.rfind(":memory:", 0) == 0) {
277 return name;
278 }
279
280 if (name.front() != '/') {
281 name = getOr(rwOperation, "output-dir", ".") + "/" + name;
282 }
283 return name;
284 }
285
286 const std::string dbFilename;
287 const std::string relationName;
288 const std::string symbolTableName = "__SymbolTable";
289
290 std::unordered_map<uint64_t, uint64_t> dbSymbolTable;
291 sqlite3_stmt* insertStatement = nullptr;
292 sqlite3_stmt* symbolInsertStatement = nullptr;
293 sqlite3_stmt* symbolSelectStatement = nullptr;
294 sqlite3* db = nullptr;
295};
296
297class WriteSQLiteFactory : public WriteStreamFactory {
298public:
299 Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
300 const SymbolTable& symbolTable, const RecordTable& recordTable) override {
301 return mk<WriteStreamSQLite>(rwOperation, symbolTable, recordTable);
302 }
303
304 const std::string& getName() const override {
305 static const std::string name = "sqlite";
306 return name;
307 }
308 ~WriteSQLiteFactory() override = default;
309};
310
311} /* namespace souffle */