| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file WriteStream.h
 | 
| 12 |  *
 | 
| 13 |  ***********************************************************************/
 | 
| 14 | 
 | 
| 15 | #pragma once
 | 
| 16 | 
 | 
| 17 | #include "souffle/RamTypes.h"
 | 
| 18 | #include "souffle/RecordTable.h"
 | 
| 19 | #include "souffle/SymbolTable.h"
 | 
| 20 | #include "souffle/io/SerialisationStream.h"
 | 
| 21 | #include "souffle/utility/MiscUtil.h"
 | 
| 22 | #include "souffle/utility/json11.h"
 | 
| 23 | #include <cassert>
 | 
| 24 | #include <cstddef>
 | 
| 25 | #include <iomanip>
 | 
| 26 | #include <map>
 | 
| 27 | #include <memory>
 | 
| 28 | #include <ostream>
 | 
| 29 | #include <string>
 | 
| 30 | 
 | 
| 31 | namespace souffle {
 | 
| 32 | 
 | 
| 33 | using json11::Json;
 | 
| 34 | 
 | 
| 35 | class WriteStream : public SerialisationStream<true> {
 | 
| 36 | public:
 | 
| 37 |     WriteStream(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
 | 
| 38 |             const RecordTable& recordTable)
 | 
| 39 |             : SerialisationStream(symbolTable, recordTable, rwOperation),
 | 
| 40 |               summary(rwOperation.at("IO") == "stdoutprintsize") {}
 | 
| 41 | 
 | 
| 42 |     template <typename T>
 | 
| 43 |     void writeAll(const T& relation) {
 | 
| 44 |         if (summary) {
 | 
| 45 |             return writeSize(relation.size());
 | 
| 46 |         }
 | 
| 47 |         if (arity == 0) {
 | 
| 48 |             if (relation.begin() != relation.end()) {
 | 
| 49 |                 writeNullary();
 | 
| 50 |             }
 | 
| 51 |             return;
 | 
| 52 |         }
 | 
| 53 |         for (const auto& current : relation) {
 | 
| 54 |             writeNext(current);
 | 
| 55 |         }
 | 
| 56 |     }
 | 
| 57 | 
 | 
| 58 |     template <typename T>
 | 
| 59 |     void writeSize(const T& relation) {
 | 
| 60 |         writeSize(relation.size());
 | 
| 61 |     }
 | 
| 62 | 
 | 
| 63 | protected:
 | 
| 64 |     const bool summary;
 | 
| 65 | 
 | 
| 66 |     virtual void writeNullary() = 0;
 | 
| 67 |     virtual void writeNextTuple(const RamDomain* tuple) = 0;
 | 
| 68 |     virtual void writeSize(std::size_t) {
 | 
| 69 |         fatal("attempting to print size of a write operation");
 | 
| 70 |     }
 | 
| 71 | 
 | 
| 72 |     template <typename Tuple>
 | 
| 73 |     void writeNext(const Tuple tuple) {
 | 
| 74 |         using tcb::make_span;
 | 
| 75 |         writeNextTuple(make_span(tuple).data());
 | 
| 76 |     }
 | 
| 77 | 
 | 
| 78 |     virtual void outputSymbol(std::ostream& destination, const std::string& value) {
 | 
| 79 |         destination << value;
 | 
| 80 |     }
 | 
| 81 | 
 | 
| 82 |     void outputRecord(std::ostream& destination, const RamDomain value, const std::string& name) {
 | 
| 83 |         auto&& recordInfo = types["records"][name];
 | 
| 84 | 
 | 
| 85 |         // Check if record type information are present
 | 
| 86 |         assert(!recordInfo.is_null() && "Missing record type information");
 | 
| 87 | 
 | 
| 88 |         // Check for nil
 | 
| 89 |         if (value == 0) {
 | 
| 90 |             destination << "nil";
 | 
| 91 |             return;
 | 
| 92 |         }
 | 
| 93 | 
 | 
| 94 |         auto&& recordTypes = recordInfo["types"];
 | 
| 95 |         const std::size_t recordArity = recordInfo["arity"].long_value();
 | 
| 96 | 
 | 
| 97 |         const RamDomain* tuplePtr = recordTable.unpack(value, recordArity);
 | 
| 98 | 
 | 
| 99 |         destination << "[";
 | 
| 100 | 
 | 
| 101 |         // print record's elements
 | 
| 102 |         for (std::size_t i = 0; i < recordArity; ++i) {
 | 
| 103 |             if (i > 0) {
 | 
| 104 |                 destination << ", ";
 | 
| 105 |             }
 | 
| 106 | 
 | 
| 107 |             const std::string& recordType = recordTypes[i].string_value();
 | 
| 108 |             const RamDomain recordValue = tuplePtr[i];
 | 
| 109 | 
 | 
| 110 |             switch (recordType[0]) {
 | 
| 111 |                 case 'i': destination << recordValue; break;
 | 
| 112 |                 case 'f': destination << ramBitCast<RamFloat>(recordValue); break;
 | 
| 113 |                 case 'u': destination << ramBitCast<RamUnsigned>(recordValue); break;
 | 
| 114 |                 case 's': outputSymbol(destination, symbolTable.decode(recordValue)); break;
 | 
| 115 |                 case 'r': outputRecord(destination, recordValue, recordType); break;
 | 
| 116 |                 case '+': outputADT(destination, recordValue, recordType); break;
 | 
| 117 |                 default: fatal("Unsupported type attribute: `%c`", recordType[0]);
 | 
| 118 |             }
 | 
| 119 |         }
 | 
| 120 |         destination << "]";
 | 
| 121 |     }
 | 
| 122 | 
 | 
| 123 |     void outputADT(std::ostream& destination, const RamDomain value, const std::string& name) {
 | 
| 124 |         auto&& adtInfo = types["ADTs"][name];
 | 
| 125 | 
 | 
| 126 |         assert(!adtInfo.is_null() && "Missing adt type information");
 | 
| 127 |         assert(adtInfo["arity"].long_value() > 0);
 | 
| 128 | 
 | 
| 129 |         // adt is encoded in one of three possible ways:
 | 
| 130 |         // [branchID, [branch_args]] when |branch_args| != 1
 | 
| 131 |         // [branchID, arg] when a branch takes a single argument.
 | 
| 132 |         // branchID when ADT is an enumeration.
 | 
| 133 |         bool isEnum = adtInfo["enum"].bool_value();
 | 
| 134 | 
 | 
| 135 |         RamDomain branchId = value;
 | 
| 136 |         const RamDomain* branchArgs = nullptr;
 | 
| 137 |         json11::Json branchInfo;
 | 
| 138 |         json11::Json::array branchTypes;
 | 
| 139 | 
 | 
| 140 |         if (!isEnum) {
 | 
| 141 |             const RamDomain* tuplePtr = recordTable.unpack(value, 2);
 | 
| 142 | 
 | 
| 143 |             branchId = tuplePtr[0];
 | 
| 144 |             branchInfo = adtInfo["branches"][branchId];
 | 
| 145 |             branchTypes = branchInfo["types"].array_items();
 | 
| 146 | 
 | 
| 147 |             // Prepare branch's arguments for output.
 | 
| 148 |             branchArgs = [&]() -> const RamDomain* {
 | 
| 149 |                 if (branchTypes.size() > 1) {
 | 
| 150 |                     return recordTable.unpack(tuplePtr[1], branchTypes.size());
 | 
| 151 |                 } else {
 | 
| 152 |                     return &tuplePtr[1];
 | 
| 153 |                 }
 | 
| 154 |             }();
 | 
| 155 |         } else {
 | 
| 156 |             branchInfo = adtInfo["branches"][branchId];
 | 
| 157 |             branchTypes = branchInfo["types"].array_items();
 | 
| 158 |         }
 | 
| 159 | 
 | 
| 160 |         destination << "$" << branchInfo["name"].string_value();
 | 
| 161 | 
 | 
| 162 |         if (branchTypes.size() > 0) {
 | 
| 163 |             destination << "(";
 | 
| 164 |         }
 | 
| 165 | 
 | 
| 166 |         // Print arguments
 | 
| 167 |         for (std::size_t i = 0; i < branchTypes.size(); ++i) {
 | 
| 168 |             if (i > 0) {
 | 
| 169 |                 destination << ", ";
 | 
| 170 |             }
 | 
| 171 | 
 | 
| 172 |             auto argType = branchTypes[i].string_value();
 | 
| 173 |             switch (argType[0]) {
 | 
| 174 |                 case 'i': destination << branchArgs[i]; break;
 | 
| 175 |                 case 'f': destination << ramBitCast<RamFloat>(branchArgs[i]); break;
 | 
| 176 |                 case 'u': destination << ramBitCast<RamUnsigned>(branchArgs[i]); break;
 | 
| 177 |                 case 's': outputSymbol(destination, symbolTable.decode(branchArgs[i])); break;
 | 
| 178 |                 case 'r': outputRecord(destination, branchArgs[i], argType); break;
 | 
| 179 |                 case '+': outputADT(destination, branchArgs[i], argType); break;
 | 
| 180 |                 default: fatal("Unsupported type attribute: `%c`", argType[0]);
 | 
| 181 |             }
 | 
| 182 |         }
 | 
| 183 | 
 | 
| 184 |         if (branchTypes.size() > 0) {
 | 
| 185 |             destination << ")";
 | 
| 186 |         }
 | 
| 187 |     }
 | 
| 188 | };
 | 
| 189 | 
 | 
| 190 | class WriteStreamFactory {
 | 
| 191 | public:
 | 
| 192 |     virtual Own<WriteStream> getWriter(const std::map<std::string, std::string>& rwOperation,
 | 
| 193 |             const SymbolTable& symbolTable, const RecordTable& recordTable) = 0;
 | 
| 194 | 
 | 
| 195 |     virtual const std::string& getName() const = 0;
 | 
| 196 |     virtual ~WriteStreamFactory() = default;
 | 
| 197 | };
 | 
| 198 | 
 | 
| 199 | template <>
 | 
| 200 | inline void WriteStream::writeNext(const RamDomain* tuple) {
 | 
| 201 |     writeNextTuple(tuple);
 | 
| 202 | }
 | 
| 203 | 
 | 
| 204 | } /* namespace souffle */
 |