| 1 | /*
 | 
| 2 |  * Souffle - A Datalog Compiler
 | 
| 3 |  * Copyright (c) 2021, The Souffle Developers. All rights reserved
 | 
| 4 |  * Licensed under the Universal Permissive License v 1.0 as shown at:
 | 
| 5 |  * - https://opensource.org/licenses/UPL
 | 
| 6 |  * - <souffle root>/licenses/SOUFFLE-UPL.txt
 | 
| 7 |  */
 | 
| 8 | 
 | 
| 9 | /************************************************************************
 | 
| 10 |  *
 | 
| 11 |  * @file Visitor.h
 | 
| 12 |  *
 | 
| 13 |  * Defines a generic visitor pattern for nodes
 | 
| 14 |  *
 | 
| 15 |  ***********************************************************************/
 | 
| 16 | 
 | 
| 17 | #pragma once
 | 
| 18 | 
 | 
| 19 | #include "souffle/utility/FunctionalUtil.h"
 | 
| 20 | #include "souffle/utility/Types.h"
 | 
| 21 | #include "souffle/utility/VisitorFwd.h"
 | 
| 22 | #include <cassert>
 | 
| 23 | #include <type_traits>
 | 
| 24 | #include <utility>
 | 
| 25 | 
 | 
| 26 | namespace souffle {
 | 
| 27 | 
 | 
/**
 * Extension point for visitors: yields the children of `node` by calling its
 * `getChildNodes()` member. Node types without such a member can provide an
 * overload of this function in their own namespace (the traversal helpers call it
 * unqualified).
 *
 * The static_assert enforces const-correctness of the child range: a const node must
 * produce const children and a mutable node must produce mutable children.
 */
template <class Node>
auto getChildNodes(Node&& node) -> decltype(node.getChildNodes()) {
    using NodeObject = std::remove_reference_t<Node>;
    using ChildRef = decltype(*node.getChildNodes().begin());
    // Children may be handed out as pointers or as references; strip either to the node type.
    using ChildObject = std::remove_pointer_t<std::remove_reference_t<ChildRef>>;
    static_assert(std::is_const_v<NodeObject> == std::is_const_v<ChildObject>);
    return node.getChildNodes();
}
 | 
| 40 | 
 | 
| 41 | namespace detail {
 | 
| 42 | 
 | 
| 43 | /** A tag type required for the is_visitor type trait to identify Visitors */
 | 
| 44 | struct visitor_tag {};
 | 
| 45 | 
 | 
| 46 | template <typename T>
 | 
| 47 | constexpr bool is_visitor = std::is_base_of_v<visitor_tag, remove_cvref_t<T>>;
 | 
| 48 | 
 | 
| 49 | template <typename T, typename U = void>
 | 
| 50 | struct is_visitable_impl : std::false_type {};
 | 
| 51 | 
 | 
| 52 | template <typename T>
 | 
| 53 | struct is_visitable_impl<T, std::void_t<decltype(getChildNodes(std::declval<T const&>()))>>
 | 
| 54 |         : is_range<decltype(getChildNodes(std::declval<T&>()))> {};
 | 
| 55 | 
 | 
| 56 | template <typename A, typename = void>
 | 
| 57 | struct is_visitable_pointer_t : std::false_type {};
 | 
| 58 | 
 | 
| 59 | template <typename A, typename = void>
 | 
| 60 | struct is_visitable_container_t : std::false_type {};
 | 
| 61 | 
 | 
| 62 | template <typename A>
 | 
| 63 | constexpr bool is_visitable = is_visitable_impl<remove_cvref_t<A>>::value;
 | 
| 64 | 
 | 
| 65 | template <typename A>
 | 
| 66 | constexpr bool is_visitable_pointer = is_visitable_pointer_t<A>::value;
 | 
| 67 | 
 | 
| 68 | template <typename A>
 | 
| 69 | constexpr bool is_visitable_element = is_visitable<A> || is_visitable_pointer<A>;
 | 
| 70 | 
 | 
| 71 | template <typename A>
 | 
| 72 | constexpr bool is_visitable_container = is_visitable_container_t<A>::value;
 | 
| 73 | 
 | 
| 74 | template <typename A>
 | 
| 75 | struct is_visitable_pointer_t<A, std::enable_if_t<souffle::is_pointer_like<A>>>
 | 
| 76 |         : std::bool_constant<is_visitable<decltype(*std::declval<A>())>> {};
 | 
| 77 | 
 | 
| 78 | template <typename A>
 | 
| 79 | struct is_visitable_container_t<A, std::void_t<decltype(std::begin(std::declval<A>()))>>
 | 
| 80 |         : std::bool_constant<is_visitable_element<decltype(*std::begin(std::declval<A>()))>> {};
 | 
| 81 | 
 | 
| 82 | template <typename A>
 | 
| 83 | constexpr bool is_visitable_dispatch_root = is_visitable_element<A> || is_visitable_container<A>;
 | 
| 84 | 
 | 
| 85 | // Actual guts of `souffle::Visitor`. Intended to be inherited by the partial specialisations
 | 
| 86 | // for a given node type. (e.g. AST, or RAM.)
 | 
| 87 | // This allows us to use partial template specialisation to pick the correct implementation for a node type.
 | 
| 88 | template <typename R, class NodeType, typename... Params>
 | 
| 89 | struct VisitorBase : visitor_tag {
 | 
| 90 |     using Root = NodeType;
 | 
| 91 |     using RootNoConstQual = std::remove_const_t<NodeType>;
 | 
| 92 |     static_assert(std::is_class_v<Root>);
 | 
| 93 |     static_assert(std::is_same_v<RootNoConstQual, visit_root_type_or_void<RootNoConstQual>>,
 | 
| 94 |             "`RootConstQualified_` must refer to a (const-qual) node root type."
 | 
| 95 |             "(Enables selecting correct visitor impl based on partial template specialisation)");
 | 
| 96 | 
 | 
| 97 |     virtual ~VisitorBase() = default;
 | 
| 98 | 
 | 
| 99 |     /** The main entry for the user allowing visitors to be utilized as functions */
 | 
| 100 |     R operator()(NodeType& node, Params const&... args) {
 | 
| 101 |         return dispatch(node, args...);
 | 
| 102 |     }
 | 
| 103 | 
 | 
| 104 |     /**
 | 
| 105 |      * The main entry for a visit process conducting the dispatching of
 | 
| 106 |      * a visit to the various sub-types of Nodes. Sub-classes may override
 | 
| 107 |      * this implementation to conduct pre-visit operations.
 | 
| 108 |      *
 | 
| 109 |      * @param node the node to be visited
 | 
| 110 |      * @param args a list of extra parameters to be forwarded
 | 
| 111 |      */
 | 
| 112 |     virtual R dispatch(NodeType& node, Params const&... args) = 0;
 | 
| 113 | 
 | 
| 114 |     /** The base case for all visitors -- if no more specific overload was defined */
 | 
| 115 |     virtual R visit_(type_identity<RootNoConstQual>, NodeType& /*node*/, Params const&... /*args*/) {
 | 
| 116 |         if constexpr (std::is_same_v<void, R>) {
 | 
| 117 |             return;
 | 
| 118 |         } else {
 | 
| 119 |             R res{};
 | 
| 120 |             return res;
 | 
| 121 |         }
 | 
| 122 |     }
 | 
| 123 | };
 | 
| 124 | 
 | 
| 125 | }  // namespace detail
 | 
| 126 | 
 | 
| 127 | /**
 | 
| 128 |  * The generic base type of all Visitors realizing the dispatching of
 | 
| 129 |  * visitor calls. Each visitor may define a return type R and a list of
 | 
| 130 |  * extra parameters to be passed along with the visited Nodes to the
 | 
| 131 |  * corresponding visitor function.
 | 
| 132 |  *
 | 
| 133 |  * You must define a partial specialisation of this class for your specific Node-type.
 | 
| 134 |  * This specialisation must derive from `detail::VisitorBase`.
 | 
| 135 |  *
 | 
| 136 |  * @tparam R the result type produced by a visit call
 | 
| 137 |  * @tparam Node the type of the node being visited (can be const qualified)
 | 
| 138 |  * @tparam Params extra parameters to be passed to the visit call
 | 
| 139 |  */
 | 
| 140 | template <typename R, class NodeType, typename... Params>
 | 
| 141 | struct Visitor {
 | 
| 142 |     static_assert(unhandled_dispatch_type<NodeType>,
 | 
| 143 |             "No partial specialisation for node type found. Likely causes:\n"
 | 
| 144 |             "\t* `NodeType` isn't a root type\n"
 | 
| 145 |             "\t* `SOUFFLE_VISITOR_DEFINE_PARTIAL_SPECIALISATION` wasn't applied for `NodeType`");
 | 
| 146 | };
 | 
| 147 | 
 | 
// Declares the two `souffle::Visitor` partial specialisations for the root node type
// `root_node` (mutable and const-qualified). Each one derives from the corresponding
// `base_visitor<R, root_node [const], Params...>` instantiation.
#define SOUFFLE_VISITOR_DEFINE_PARTIAL_SPECIALISATION(root_node, base_visitor)                    \
    template <typename R, typename... Params>                                                     \
    struct souffle::Visitor<R, root_node, Params...> : base_visitor<R, root_node, Params...> {};  \
    template <typename R, typename... Params>                                                     \
    struct souffle::Visitor<R, root_node const, Params...> : base_visitor<R, root_node const, Params...> {};

// For use inside a `dispatch` body where `node`, `args` and `visit_` are in scope: when `as`
// succeeds in casting `node` to `target`, return the result of the `visit_` overload for `target`.
#define SOUFFLE_VISITOR_FORWARD(target)                         \
    if (auto* n = as<target>(node)) {                           \
        return visit_(type_identity<target>(), *n, args...);    \
    }

// Declares a virtual `visit_` overload for `Derived` whose default behaviour is to delegate to
// the overload for `Base`, so a kind that is not overridden falls back to its parent's handler.
// Must be expanded where `R`, `NodeType` and `Params` are in scope (inside a visitor template).
#define SOUFFLE_VISITOR_LINK(Derived, Base)                                                           \
    virtual R visit_(type_identity<Derived>, copy_const<NodeType, Derived>& n, Params const&... args) { \
        return visit_(type_identity<Base>(), n, args...);                                             \
    }
 | 
| 161 | 
 | 
| 162 | namespace detail {
 | 
| 163 | 
 | 
| 164 | template <typename A>
 | 
| 165 | using copy_const_visit_root = copy_const<std::remove_reference_t<A>, visit_root_type<A>>;
 | 
| 166 | 
 | 
| 167 | template <typename F>
 | 
| 168 | using infer_visit_arg_type = std::remove_reference_t<typename lambda_traits<F>::template arg<0>>;
 | 
| 169 | 
 | 
| 170 | template <typename F>
 | 
| 171 | using infer_visit_node_type = copy_const_visit_root<infer_visit_arg_type<F>>;
 | 
| 172 | 
 | 
| 173 | template <typename F, typename A>
 | 
| 174 | using visitor_enable_if_arg0 = std::enable_if_t<std::is_same_v<A, remove_cvref_t<infer_visit_arg_type<F>>>>;
 | 
| 175 | 
 | 
| 176 | // useful for implicitly-as-casting continuations
 | 
| 177 | template <typename CastType = void, typename F, typename A>
 | 
| 178 | SOUFFLE_ALWAYS_INLINE auto matchApply([[maybe_unused]] F&& f, [[maybe_unused]] A& x) {
 | 
| 179 |     using FInfo = lambda_traits<F>;
 | 
| 180 |     using Arg = std::remove_reference_t<typename FInfo::template arg<0>>;
 | 
| 181 |     using Result = typename FInfo::result_type;
 | 
| 182 |     constexpr bool can_cast = std::is_base_of_v<A, Arg> || std::is_base_of_v<Arg, A> ||
 | 
| 183 |                               std::is_same_v<CastType, AllowCrossCast>;
 | 
| 184 | 
 | 
| 185 |     if constexpr (std::is_same_v<Result, void>) {
 | 
| 186 |         if constexpr (can_cast) {
 | 
| 187 |             if (auto y = as<Arg, CastType>(&x)) f(*y);
 | 
| 188 |         }
 | 
| 189 |     } else {
 | 
| 190 |         Result result = {};
 | 
| 191 |         if constexpr (can_cast) {
 | 
| 192 |             if (auto y = as<Arg, CastType>(&x)) result = f(*y);
 | 
| 193 |         }
 | 
| 194 |         return result;
 | 
| 195 |     }
 | 
| 196 | }
 | 
| 197 | 
 | 
| 198 | /**
 | 
| 199 |  * A specialized visitor wrapping a lambda function -- an auxiliary type required
 | 
| 200 |  * for visitor convenience functions.
 | 
| 201 |  */
 | 
| 202 | template <typename F, typename Node, typename CastType>
 | 
| 203 | struct LambdaVisitor : public Visitor<void, Node> {
 | 
| 204 |     F lambda;
 | 
| 205 |     LambdaVisitor(F lam) : lambda(std::move(lam)) {}
 | 
| 206 | 
 | 
| 207 |     void dispatch(Node& node) override {
 | 
| 208 |         matchApply<CastType>(lambda, node);
 | 
| 209 |     }
 | 
| 210 | };
 | 
| 211 | 
 | 
| 212 | /**
 | 
| 213 |  * A factory function for creating LambdaVisitor instances.
 | 
| 214 |  */
 | 
| 215 | template <typename CastType = void, typename F, typename Node = infer_visit_node_type<F>>
 | 
| 216 | auto makeLambdaVisitor(F&& f) {
 | 
| 217 |     return LambdaVisitor<std::decay_t<F>, copy_const_visit_root<Node>, CastType>(std::forward<F>(f));
 | 
| 218 | }
 | 
| 219 | 
 | 
| 220 | }  // namespace detail
 | 
| 221 | 
 | 
| 222 | /**
 | 
| 223 |  * A utility function visiting all nodes within a given container of root nodes
 | 
| 224 |  * recursively in a depth-first pre-order fashion applying the given function to each
 | 
| 225 |  * encountered node.
 | 
| 226 |  *
 | 
| 227 |  * @param xs the root(s) of the ASTs to be visited
 | 
| 228 |  * @param fun the function to be applied
 | 
| 229 |  * @param args a list of extra parameters to be forwarded to the visitor
 | 
| 230 |  */
 | 
| 231 | template <typename CC, typename Visitor, typename... Args,
 | 
| 232 |         typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>,
 | 
| 233 |         typename = std::enable_if_t<detail::is_visitor<Visitor>>>
 | 
| 234 | void visit(CC&& xs, Visitor&& visitor, Args&&... args) {
 | 
| 235 |     if constexpr (detail::is_visitable<CC>) {
 | 
| 236 |         visitor(xs, args...);
 | 
| 237 |         souffle::visit(getChildNodes(xs), std::forward<Visitor>(visitor), std::forward<Args>(args)...);
 | 
| 238 |     } else if constexpr (detail::is_visitable_pointer<CC>) {
 | 
| 239 |         static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
 | 
| 240 |                 "root isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
 | 
| 241 |                 "root means the tree will be destroyed after the call");
 | 
| 242 |         assert(xs);
 | 
| 243 |         souffle::visit(*xs, std::forward<Visitor>(visitor), std::forward<Args>(args)...);
 | 
| 244 |     } else if constexpr (detail::is_visitable_container<CC>) {
 | 
| 245 |         for (auto&& x : xs) {
 | 
| 246 |             // FIXME: Remove this once nodes are converted to references
 | 
| 247 |             if constexpr (is_pointer_like<decltype(x)>) {
 | 
| 248 |                 assert(x && "found null node during traversal");
 | 
| 249 |             }
 | 
| 250 | 
 | 
| 251 |             souffle::visit(x, visitor, args...);
 | 
| 252 |         }
 | 
| 253 |     } else
 | 
| 254 |         static_assert(unhandled_dispatch_type<CC>);
 | 
| 255 | }
 | 
| 256 | 
 | 
| 257 | /**
 | 
| 258 |  * A utility function visiting all nodes within the given root
 | 
| 259 |  * recursively in a depth-first pre-order fashion, applying the given visitor to each
 | 
| 260 |  * encountered node.
 | 
| 261 |  *
 | 
| 262 |  * @param root the root of the structure to be visited
 | 
| 263 |  * @param visitor the visitor to be applied on each node
 | 
| 264 |  */
 | 
| 265 | template <typename CastType = void, class CC, typename F,
 | 
| 266 |         typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>,
 | 
| 267 |         typename = std::enable_if_t<!detail::is_visitor<F>>>
 | 
| 268 | void visit(CC&& root, F&& fun) {
 | 
| 269 |     using Arg = std::remove_reference_t<typename lambda_traits<F>::template arg<0>>;
 | 
| 270 |     using Node = detail::visit_root_type_or_void<Arg>;
 | 
| 271 |     static_assert(!std::is_same_v<Node, void>, "unable to infer node type from function's first argument");
 | 
| 272 | 
 | 
| 273 |     souffle::visit(std::forward<CC>(root),
 | 
| 274 |             detail::makeLambdaVisitor<CastType,
 | 
| 275 |                     std::conditional_t<std::is_rvalue_reference_v<F>, std::remove_reference_t<F>, F>,
 | 
| 276 |                     copy_const<Arg, Node>>(std::forward<F>(fun)));
 | 
| 277 | }
 | 
| 278 | 
 | 
| 279 | /**
 | 
| 280 |  * Traverses tree in pre-order, stopping immediately if the predicate matches and returns true.
 | 
| 281 |  * Returns true IFF the predicate matched and returned true at least once.
 | 
| 282 |  *
 | 
| 283 |  * @param xs the root(s) of the ASTs to be visited
 | 
| 284 |  * @param f the predicate to be applied
 | 
| 285 |  */
 | 
| 286 | template <typename CC, typename F /* exists B <: Node. B& -> bool */,
 | 
| 287 |         typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
 | 
| 288 | bool visitExists(CC&& xs, F&& f) {
 | 
| 289 |     if constexpr (detail::is_visitable_node<CC>) {
 | 
| 290 |         if (detail::matchApply(f, xs)) return true;
 | 
| 291 | 
 | 
| 292 |         return visitExists(getChildNodes(xs), std::forward<F>(f));
 | 
| 293 |     } else if constexpr (detail::is_visitable_pointer<CC>) {
 | 
| 294 |         static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
 | 
| 295 |                 "arg isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
 | 
| 296 |                 "root means the tree will be destroyed after the call");
 | 
| 297 |         assert(xs);
 | 
| 298 |         return visitExists(*xs, f);
 | 
| 299 |     } else if constexpr (detail::is_visitable_container<CC>) {
 | 
| 300 |         for (auto&& x : xs)
 | 
| 301 |             if (visitExists(x, f)) return true;
 | 
| 302 | 
 | 
| 303 |         return false;
 | 
| 304 |     } else
 | 
| 305 |         static_assert(unhandled_dispatch_type<CC>);
 | 
| 306 | }
 | 
| 307 | 
 | 
| 308 | /**
 | 
| 309 |  * Traverses tree in pre-order, stopping immediately if the predicate matches and returns false.
 | 
| 310 |  * Returns true IFF the predicate never matched and returned false.
 | 
| 311 |  *
 | 
| 312 |  * @param xs the root(s) of the ASTs to be visited
 | 
| 313 |  * @param f the predicate to be applied
 | 
| 314 |  */
 | 
| 315 | template <typename CC, typename F /* exists B <: Node. B& -> bool */,
 | 
| 316 |         typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
 | 
| 317 | bool visitForAll(CC&& xs, F&& pred) {
 | 
| 318 |     using Arg = remove_cvref_t<typename lambda_traits<F>::template arg<0>>;
 | 
| 319 |     return !visitExists(std::forward<CC>(xs), [&](const Arg& x) { return !pred(x); });
 | 
| 320 | }
 | 
| 321 | 
 | 
| 322 | /**
 | 
| 323 |  * Traverse tree in pre-order. Skips sub-trees when predicate matches and returns true.
 | 
| 324 |  * Returns # of frontiers found. (i.e. # of times predicate returned true.)
 | 
| 325 |  *
 | 
| 326 |  * @param xs the root(s) of the ASTs to be visited
 | 
| 327 |  * @param f the `is-frontier` predicate
 | 
| 328 |  */
 | 
| 329 | template <typename CC, typename F /* exists B <: Node. B& -> bool */,
 | 
| 330 |         typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
 | 
| 331 | size_t visitFrontier(CC&& xs, F&& f) {
 | 
| 332 |     if constexpr (detail::is_visitable_node<CC>) {
 | 
| 333 |         if (detail::matchApply(f, xs)) return 1;
 | 
| 334 | 
 | 
| 335 |         return visitFrontier(getChildNodes(xs), f);
 | 
| 336 |     } else if constexpr (detail::is_visitable_pointer<CC>) {
 | 
| 337 |         static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
 | 
| 338 |                 "arg isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
 | 
| 339 |                 "root means the tree will be destroyed after the call");
 | 
| 340 |         assert(xs);
 | 
| 341 |         return visitFrontier(*xs, f);
 | 
| 342 |     } else if constexpr (detail::is_visitable_container<CC>) {
 | 
| 343 |         size_t count = 0;
 | 
| 344 |         for (auto&& x : xs)
 | 
| 345 |             count += visitFrontier(x, f);
 | 
| 346 | 
 | 
| 347 |         return count;
 | 
| 348 |     } else
 | 
| 349 |         static_assert(unhandled_dispatch_type<CC>);
 | 
| 350 | }
 | 
| 351 | 
 | 
| 352 | }  // namespace souffle
 |