OILS / vendor / souffle / io / WriteStreamCSV.h View on Github | oilshell.org

309 lines, 195 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2021, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file WriteStreamCSV.h
12 *
13 ***********************************************************************/
14
15#pragma once
16
17#include "souffle/RamTypes.h"
18#include "souffle/RecordTable.h"
19#include "souffle/SymbolTable.h"
20#include "souffle/io/WriteStream.h"
21#include "souffle/utility/ContainerUtil.h"
22#include "souffle/utility/MiscUtil.h"
23#include "souffle/utility/ParallelUtil.h"
24#ifdef USE_LIBZ
25#include "souffle/io/gzfstream.h"
26#endif
27
28#include <cstddef>
29#include <fstream>
30#include <iomanip>
31#include <iostream>
32#include <map>
33#include <ostream>
34#include <string>
35#include <vector>
36
37namespace souffle {
38
39class WriteStreamCSV : public WriteStream {
40protected:
41 WriteStreamCSV(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
42 const RecordTable& recordTable)
43 : WriteStream(rwOperation, symbolTable, recordTable),
44 rfc4180(getOr(rwOperation, "rfc4180", "false") == std::string("true")),
45 delimiter(getOr(rwOperation, "delimiter", (rfc4180 ? "," : "\t"))) {
46 if (rfc4180 && delimiter.find('"') != std::string::npos) {
47 std::stringstream errorMessage;
48 errorMessage << "CSV delimiter cannot contain '\"' character when rfc4180 is enabled.";
49 throw std::invalid_argument(errorMessage.str());
50 }
51 };
52
53 const bool rfc4180;
54
55 const std::string delimiter;
56
57 void writeNextTupleCSV(std::ostream& destination, const RamDomain* tuple) {
58 writeNextTupleElement(destination, typeAttributes.at(0), tuple[0]);
59
60 for (std::size_t col = 1; col < arity; ++col) {
61 destination << delimiter;
62 writeNextTupleElement(destination, typeAttributes.at(col), tuple[col]);
63 }
64
65 destination << "\n";
66 }
67
68 virtual void outputSymbol(std::ostream& destination, const std::string& value) {
69 outputSymbol(destination, value, false);
70 }
71
72 void outputSymbol(std::ostream& destination, const std::string& value, bool fieldValue) {
73 if (rfc4180) {
74 if (!fieldValue) {
75 destination << '"';
76 }
77 destination << '"';
78
79 const std::size_t end = value.length();
80 for (std::size_t pos = 0; pos < end; ++pos) {
81 char ch = value[pos];
82 if (ch == '"') {
83 destination << '\\';
84 destination << '"';
85 }
86 destination << ch;
87 }
88
89 if (!fieldValue) {
90 destination << '"';
91 }
92 destination << '"';
93 } else {
94 destination << value;
95 }
96 }
97
98 void writeNextTupleElement(std::ostream& destination, const std::string& type, RamDomain value) {
99 switch (type[0]) {
100 case 's': outputSymbol(destination, symbolTable.decode(value), true); break;
101 case 'i': destination << value; break;
102 case 'u': destination << ramBitCast<RamUnsigned>(value); break;
103 case 'f': destination << ramBitCast<RamFloat>(value); break;
104 case 'r':
105 if (rfc4180) {
106 destination << '"';
107 }
108 outputRecord(destination, value, type);
109 if (rfc4180) {
110 destination << '"';
111 }
112 break;
113 case '+':
114 if (rfc4180) {
115 destination << '"';
116 }
117 outputADT(destination, value, type);
118 if (rfc4180) {
119 destination << '"';
120 }
121 break;
122 default: fatal("unsupported type attribute: `%c`", type[0]);
123 }
124 }
125};
126
127class WriteFileCSV : public WriteStreamCSV {
128public:
129 WriteFileCSV(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
130 const RecordTable& recordTable)
131 : WriteStreamCSV(rwOperation, symbolTable, recordTable),
132 file(getFileName(rwOperation), std::ios::out | std::ios::binary) {
133 if (getOr(rwOperation, "headers", "false") == "true") {
134 file << rwOperation.at("attributeNames") << std::endl;
135 }
136 file << std::setprecision(std::numeric_limits<RamFloat>::max_digits10);
137 }
138
139 ~WriteFileCSV() override = default;
140
141protected:
142 std::ofstream file;
143
144 void writeNullary() override {
145 file << "()\n";
146 }
147
148 void writeNextTuple(const RamDomain* tuple) override {
149 writeNextTupleCSV(file, tuple);
150 }
151
152 /**
153 * Return given filename or construct from relation name.
154 * Default name is [configured path]/[relation name].csv
155 *
156 * @param rwOperation map of IO configuration options
157 * @return input filename
158 */
159 static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
160 auto name = getOr(rwOperation, "filename", rwOperation.at("name") + ".csv");
161 if (name.front() != '/') {
162 name = getOr(rwOperation, "output-dir", ".") + "/" + name;
163 }
164 return name;
165 }
166};
167
168#ifdef USE_LIBZ
169class WriteGZipFileCSV : public WriteStreamCSV {
170public:
171 WriteGZipFileCSV(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
172 const RecordTable& recordTable)
173 : WriteStreamCSV(rwOperation, symbolTable, recordTable),
174 file(getFileName(rwOperation), std::ios::out | std::ios::binary) {
175 if (getOr(rwOperation, "headers", "false") == "true") {
176 file << rwOperation.at("attributeNames") << std::endl;
177 }
178 file << std::setprecision(std::numeric_limits<RamFloat>::max_digits10);
179 }
180
181 ~WriteGZipFileCSV() override = default;
182
183protected:
184 void writeNullary() override {
185 file << "()\n";
186 }
187
188 void writeNextTuple(const RamDomain* tuple) override {
189 writeNextTupleCSV(file, tuple);
190 }
191
192 /**
193 * Return given filename or construct from relation name.
194 * Default name is [configured path]/[relation name].csv
195 *
196 * @param rwOperation map of IO configuration options
197 * @return input filename
198 */
199 static std::string getFileName(const std::map<std::string, std::string>& rwOperation) {
200 auto name = getOr(rwOperation, "filename", rwOperation.at("name") + ".csv.gz");
201 if (name.front() != '/') {
202 name = getOr(rwOperation, "output-dir", ".") + "/" + name;
203 }
204 return name;
205 }
206
207 gzfstream::ogzfstream file;
208};
209#endif
210
211class WriteCoutCSV : public WriteStreamCSV {
212public:
213 WriteCoutCSV(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
214 const RecordTable& recordTable)
215 : WriteStreamCSV(rwOperation, symbolTable, recordTable) {
216 std::cout << "---------------\n" << rwOperation.at("name");
217 if (getOr(rwOperation, "headers", "false") == "true") {
218 std::cout << "\n" << rwOperation.at("attributeNames");
219 }
220 std::cout << "\n===============\n";
221 std::cout << std::setprecision(std::numeric_limits<RamFloat>::max_digits10);
222 }
223
224 ~WriteCoutCSV() override {
225 std::cout << "===============\n";
226 }
227
228protected:
229 void writeNullary() override {
230 std::cout << "()\n";
231 }
232
233 void writeNextTuple(const RamDomain* tuple) override {
234 writeNextTupleCSV(std::cout, tuple);
235 }
236};
237
238class WriteCoutPrintSize : public WriteStream {
239public:
240 WriteCoutPrintSize(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
241 const RecordTable& recordTable)
242 : WriteStream(rwOperation, symbolTable, recordTable), lease(souffle::getOutputLock().acquire()) {
243 std::cout << rwOperation.at("name") << "\t";
244 }
245
246 ~WriteCoutPrintSize() override = default;
247
248protected:
249 void writeNullary() override {
250 fatal("attempting to iterate over a print size operation");
251 }
252
253 void writeNextTuple(const RamDomain* /* tuple */) override {
254 fatal("attempting to iterate over a print size operation");
255 }
256
257 void writeSize(std::size_t size) override {
258 std::cout << size << "\n";
259 }
260
261 Lock::Lease lease;
262};
263
264class WriteFileCSVFactory : public WriteStreamFactory {
265public:
266 Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
267 const SymbolTable& symbolTable, const RecordTable& recordTable) override {
268#ifdef USE_LIBZ
269 if (contains(rwOperation, "compress")) {
270 return mk<WriteGZipFileCSV>(rwOperation, symbolTable, recordTable);
271 }
272#endif
273 return mk<WriteFileCSV>(rwOperation, symbolTable, recordTable);
274 }
275 const std::string& getName() const override {
276 static const std::string name = "file";
277 return name;
278 }
279 ~WriteFileCSVFactory() override = default;
280};
281
282class WriteCoutCSVFactory : public WriteStreamFactory {
283public:
284 Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
285 const SymbolTable& symbolTable, const RecordTable& recordTable) override {
286 return mk<WriteCoutCSV>(rwOperation, symbolTable, recordTable);
287 }
288
289 const std::string& getName() const override {
290 static const std::string name = "stdout";
291 return name;
292 }
293 ~WriteCoutCSVFactory() override = default;
294};
295
296class WriteCoutPrintSizeFactory : public WriteStreamFactory {
297public:
298 Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
299 const SymbolTable& symbolTable, const RecordTable& recordTable) override {
300 return mk<WriteCoutPrintSize>(rwOperation, symbolTable, recordTable);
301 }
302 const std::string& getName() const override {
303 static const std::string name = "stdoutprintsize";
304 return name;
305 }
306 ~WriteCoutPrintSizeFactory() override = default;
307};
308
309} /* namespace souffle */