OILS / vendor / souffle / io / ReadStreamJSON.h View on Github | oilshell.org

382 lines, 249 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2020, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file ReadStreamJSON.h
12 *
13 ***********************************************************************/
14
15#pragma once
16
17#include "souffle/RamTypes.h"
18#include "souffle/RecordTable.h"
19#include "souffle/SymbolTable.h"
20#include "souffle/io/ReadStream.h"
21#include "souffle/utility/ContainerUtil.h"
22#include "souffle/utility/FileUtil.h"
23#include "souffle/utility/StringUtil.h"
24
25#include <algorithm>
26#include <cassert>
27#include <cstddef>
28#include <cstdint>
29#include <fstream>
30#include <iostream>
31#include <map>
32#include <memory>
33#include <queue>
34#include <sstream>
35#include <stdexcept>
36#include <string>
37#include <tuple>
38#include <vector>
39
40namespace souffle {
41
/**
 * Builds an error message by streaming every argument, in order, into a
 * single string and throws it as a std::runtime_error.
 *
 * @param t values to concatenate; each must be streamable via operator<<
 * @throws std::runtime_error always
 */
template <typename... T>
[[noreturn]] static void throwError(T const&... t) {
    std::ostringstream message;
    // comma fold: stream the arguments one after another, left to right
    ((message << t), ...);
    throw std::runtime_error(message.str());
}
48
49class ReadStreamJSON : public ReadStream {
50public:
51 ReadStreamJSON(std::istream& file, const std::map<std::string, std::string>& rwOperation,
52 SymbolTable& symbolTable, RecordTable& recordTable)
53 : ReadStream(rwOperation, symbolTable, recordTable), file(file), pos(0), isInitialized(false) {
54 std::string err;
55 params = Json::parse(rwOperation.at("params"), err);
56 if (err.length() > 0) {
57 throwError("cannot get internal params: ", err);
58 }
59 }
60
61protected:
62 std::istream& file;
63 std::size_t pos;
64 Json jsonSource;
65 Json params;
66 bool isInitialized;
67 bool useObjects;
68 std::map<const std::string, const std::size_t> paramIndex;
69
70 Own<RamDomain[]> readNextTuple() override {
71 // for some reasons we cannot initalized our json objects in constructor
72 // otherwise it will segfault, so we initialize in the first call
73 if (!isInitialized) {
74 isInitialized = true;
75 std::string error = "";
76 std::string source(std::istreambuf_iterator<char>(file), {});
77
78 jsonSource = Json::parse(source, error);
79 // it should be wrapped by an extra array
80 if (error.length() > 0 || !jsonSource.is_array()) {
81 throwError("cannot deserialize json because ", error, ":\n", source);
82 }
83
84 if (jsonSource.array_items().empty()) {
85 // No tuples defined
86 return nullptr;
87 }
88
89 // we only check the first one, since there are extra checks
90 // in readNextTupleObject/readNextTupleList
91 if (jsonSource[0].is_array()) {
92 useObjects = false;
93 } else if (jsonSource[0].is_object()) {
94 useObjects = true;
95 std::size_t index_pos = 0;
96 for (auto param : params["relation"]["params"].array_items()) {
97 paramIndex.insert(std::make_pair(param.string_value(), index_pos));
98 index_pos++;
99 }
100 } else {
101 throwError("the input is neither list nor object format");
102 }
103 }
104
105 if (useObjects) {
106 return readNextTupleObject();
107 } else {
108 return readNextTupleList();
109 }
110 }
111
112 Own<RamDomain[]> readNextTupleList() {
113 if (pos >= jsonSource.array_items().size()) {
114 return nullptr;
115 }
116
117 Own<RamDomain[]> tuple = mk<RamDomain[]>(typeAttributes.size());
118 const Json& jsonObj = jsonSource[pos];
119 assert(jsonObj.is_array() && "the input is not json array");
120 pos++;
121 for (std::size_t i = 0; i < typeAttributes.size(); ++i) {
122 try {
123 auto&& ty = typeAttributes.at(i);
124 switch (ty[0]) {
125 case 's': {
126 tuple[i] = symbolTable.encode(jsonObj[i].string_value());
127 break;
128 }
129 case 'r': {
130 tuple[i] = readNextElementList(jsonObj[i], ty);
131 break;
132 }
133 case 'i': {
134 tuple[i] = jsonObj[i].int_value();
135 break;
136 }
137 case 'u': {
138 tuple[i] = jsonObj[i].int_value();
139 break;
140 }
141 case 'f': {
142 tuple[i] = static_cast<RamDomain>(jsonObj[i].number_value());
143 break;
144 }
145 default: throwError("invalid type attribute: '", ty[0], "'");
146 }
147 } catch (...) {
148 std::stringstream errorMessage;
149 if (jsonObj.is_array() && i < jsonObj.array_items().size()) {
150 errorMessage << "Error converting: " << jsonObj[i].dump();
151 } else {
152 errorMessage << "Invalid index: " << i;
153 }
154 throw std::invalid_argument(errorMessage.str());
155 }
156 }
157
158 return tuple;
159 }
160
161 RamDomain readNextElementList(const Json& source, const std::string& recordTypeName) {
162 auto&& recordInfo = types["records"][recordTypeName];
163
164 if (recordInfo.is_null()) {
165 throw std::invalid_argument("Missing record type information: " + recordTypeName);
166 }
167
168 // Handle null case
169 if (source.is_null()) {
170 return 0;
171 }
172
173 assert(source.is_array() && "the input is not json array");
174 auto&& recordTypes = recordInfo["types"];
175 const std::size_t recordArity = recordInfo["arity"].long_value();
176 std::vector<RamDomain> recordValues(recordArity);
177 for (std::size_t i = 0; i < recordArity; ++i) {
178 const std::string& recordType = recordTypes[i].string_value();
179 switch (recordType[0]) {
180 case 's': {
181 recordValues[i] = symbolTable.encode(source[i].string_value());
182 break;
183 }
184 case 'r': {
185 recordValues[i] = readNextElementList(source[i], recordType);
186 break;
187 }
188 case 'i': {
189 recordValues[i] = source[i].int_value();
190 break;
191 }
192 case 'u': {
193 recordValues[i] = source[i].int_value();
194 break;
195 }
196 case 'f': {
197 recordValues[i] = static_cast<RamDomain>(source[i].number_value());
198 break;
199 }
200 default: throwError("invalid type attribute");
201 }
202 }
203
204 return recordTable.pack(recordValues.data(), recordValues.size());
205 }
206
207 Own<RamDomain[]> readNextTupleObject() {
208 if (pos >= jsonSource.array_items().size()) {
209 return nullptr;
210 }
211
212 Own<RamDomain[]> tuple = mk<RamDomain[]>(typeAttributes.size());
213 const Json& jsonObj = jsonSource[pos];
214 assert(jsonObj.is_object() && "the input is not json object");
215 pos++;
216 for (auto p : jsonObj.object_items()) {
217 try {
218 // get the corresponding position by parameter name
219 if (paramIndex.find(p.first) == paramIndex.end()) {
220 throwError("invalid parameter: ", p.first);
221 }
222 std::size_t i = paramIndex.at(p.first);
223 auto&& ty = typeAttributes.at(i);
224 switch (ty[0]) {
225 case 's': {
226 tuple[i] = symbolTable.encode(p.second.string_value());
227 break;
228 }
229 case 'r': {
230 tuple[i] = readNextElementObject(p.second, ty);
231 break;
232 }
233 case 'i': {
234 tuple[i] = p.second.int_value();
235 break;
236 }
237 case 'u': {
238 tuple[i] = p.second.int_value();
239 break;
240 }
241 case 'f': {
242 tuple[i] = static_cast<RamDomain>(p.second.number_value());
243 break;
244 }
245 default: throwError("invalid type attribute: '", ty[0], "'");
246 }
247 } catch (...) {
248 std::stringstream errorMessage;
249 errorMessage << "Error converting: " << p.second.dump();
250 throw std::invalid_argument(errorMessage.str());
251 }
252 }
253
254 return tuple;
255 }
256
257 RamDomain readNextElementObject(const Json& source, const std::string& recordTypeName) {
258 auto&& recordInfo = types["records"][recordTypeName];
259 const std::string recordName = recordTypeName.substr(2);
260 std::map<const std::string, const std::size_t> recordIndex;
261
262 std::size_t index_pos = 0;
263 for (auto param : params["records"][recordName]["params"].array_items()) {
264 recordIndex.insert(std::make_pair(param.string_value(), index_pos));
265 index_pos++;
266 }
267
268 if (recordInfo.is_null()) {
269 throw std::invalid_argument("Missing record type information: " + recordTypeName);
270 }
271
272 // Handle null case
273 if (source.is_null()) {
274 return 0;
275 }
276
277 assert(source.is_object() && "the input is not json object");
278 auto&& recordTypes = recordInfo["types"];
279 const std::size_t recordArity = recordInfo["arity"].long_value();
280 std::vector<RamDomain> recordValues(recordArity);
281 recordValues.reserve(recordIndex.size());
282 for (auto readParam : source.object_items()) {
283 // get the corresponding position by parameter name
284 if (recordIndex.find(readParam.first) == recordIndex.end()) {
285 throwError("invalid parameter: ", readParam.first);
286 }
287 std::size_t i = recordIndex.at(readParam.first);
288 auto&& type = recordTypes[i].string_value();
289 switch (type[0]) {
290 case 's': {
291 recordValues[i] = symbolTable.encode(readParam.second.string_value());
292 break;
293 }
294 case 'r': {
295 recordValues[i] = readNextElementObject(readParam.second, type);
296 break;
297 }
298 case 'i': {
299 recordValues[i] = readParam.second.int_value();
300 break;
301 }
302 case 'u': {
303 recordValues[i] = readParam.second.int_value();
304 break;
305 }
306 case 'f': {
307 recordValues[i] = static_cast<RamDomain>(readParam.second.number_value());
308 break;
309 }
310 default: throwError("invalid type attribute: '", type[0], "'");
311 }
312 }
313
314 return recordTable.pack(recordValues.data(), recordValues.size());
315 }
316};
317
318class ReadFileJSON : public ReadStreamJSON {
319public:
320 ReadFileJSON(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
321 RecordTable& recordTable)
322 // FIXME: This is bordering on UB - we're passing an unconstructed
323 // object (fileHandle) to the base class
324 : ReadStreamJSON(fileHandle, rwOperation, symbolTable, recordTable),
325 baseName(souffle::baseName(getFileName(rwOperation))),
326 fileHandle(getFileName(rwOperation), std::ios::in | std::ios::binary) {
327 if (!fileHandle.is_open()) {
328 throw std::invalid_argument("Cannot open json file " + baseName + "\n");
329 }
330 }
331
332 ~ReadFileJSON() override = default;
333
334protected:
335 /**
336 * Return given filename or construct from relation name.
337 * Default name is [configured path]/[relation name].json
338 *
339 * @param rwOperation map of IO configuration options
340 * @return input filename
341 */
342 static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
343 auto name = getOr(rwOperation, "filename", rwOperation.at("name") + ".json");
344 if (name.front() != '/') {
345 name = getOr(rwOperation, "fact-dir", ".") + "/" + name;
346 }
347 return name;
348 }
349
350 std::string baseName;
351 std::ifstream fileHandle;
352};
353
354class ReadCinJSONFactory : public ReadStreamFactory {
355public:
356 Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
357 RecordTable& recordTable) override {
358 return mk<ReadStreamJSON>(std::cin, rwOperation, symbolTable, recordTable);
359 }
360
361 const std::string& getName() const override {
362 static const std::string name = "json";
363 return name;
364 }
365 ~ReadCinJSONFactory() override = default;
366};
367
368class ReadFileJSONFactory : public ReadStreamFactory {
369public:
370 Own<ReadStream> getReader(const std::map<std::string, std::string>& rwOperation, SymbolTable& symbolTable,
371 RecordTable& recordTable) override {
372 return mk<ReadFileJSON>(rwOperation, symbolTable, recordTable);
373 }
374
375 const std::string& getName() const override {
376 static const std::string name = "jsonfile";
377 return name;
378 }
379
380 ~ReadFileJSONFactory() override = default;
381};
382} // namespace souffle