OILS / vendor / souffle / utility / Visitor.h View on Github | oilshell.org

352 lines, 182 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2021, The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file Visitor.h
12 *
13 * Defines a generic visitor pattern for nodes
14 *
15 ***********************************************************************/
16
17#pragma once
18
19#include "souffle/utility/FunctionalUtil.h"
20#include "souffle/utility/Types.h"
21#include "souffle/utility/VisitorFwd.h"
22#include <cassert>
23#include <type_traits>
24#include <utility>
25
26namespace souffle {
27
/**
 * Extension point for visitors: yields the children of `node` by calling its
 * `getChildNodes()` member.
 * Can be overloaded in the namespace of Node (callers in this header call it unqualified).
 *
 * Compile-time check: a const node must hand out const children and a mutable node
 * mutable ones (constness of the pointee of the child range's element type).
 */
template <class Node>
auto getChildNodes(Node&& node) -> decltype(node.getChildNodes()) {
    using NodeObject = std::remove_reference_t<Node>;
    using ChildObject =
            std::remove_pointer_t<std::remove_reference_t<decltype(*node.getChildNodes().begin())>>;
    static_assert(std::is_const_v<NodeObject> == std::is_const_v<ChildObject>);
    return node.getChildNodes();
}
40
41namespace detail {
42
43/** A tag type required for the is_visitor type trait to identify Visitors */
44struct visitor_tag {};
45
46template <typename T>
47constexpr bool is_visitor = std::is_base_of_v<visitor_tag, remove_cvref_t<T>>;
48
49template <typename T, typename U = void>
50struct is_visitable_impl : std::false_type {};
51
52template <typename T>
53struct is_visitable_impl<T, std::void_t<decltype(getChildNodes(std::declval<T const&>()))>>
54 : is_range<decltype(getChildNodes(std::declval<T&>()))> {};
55
56template <typename A, typename = void>
57struct is_visitable_pointer_t : std::false_type {};
58
59template <typename A, typename = void>
60struct is_visitable_container_t : std::false_type {};
61
62template <typename A>
63constexpr bool is_visitable = is_visitable_impl<remove_cvref_t<A>>::value;
64
65template <typename A>
66constexpr bool is_visitable_pointer = is_visitable_pointer_t<A>::value;
67
68template <typename A>
69constexpr bool is_visitable_element = is_visitable<A> || is_visitable_pointer<A>;
70
71template <typename A>
72constexpr bool is_visitable_container = is_visitable_container_t<A>::value;
73
74template <typename A>
75struct is_visitable_pointer_t<A, std::enable_if_t<souffle::is_pointer_like<A>>>
76 : std::bool_constant<is_visitable<decltype(*std::declval<A>())>> {};
77
78template <typename A>
79struct is_visitable_container_t<A, std::void_t<decltype(std::begin(std::declval<A>()))>>
80 : std::bool_constant<is_visitable_element<decltype(*std::begin(std::declval<A>()))>> {};
81
82template <typename A>
83constexpr bool is_visitable_dispatch_root = is_visitable_element<A> || is_visitable_container<A>;
84
85// Actual guts of `souffle::Visitor`. Intended to be inherited by the partial specialisations
86// for a given node type. (e.g. AST, or RAM.)
87// This allows us to use partial template specialisation to pick the correct implementation for a node type.
88template <typename R, class NodeType, typename... Params>
89struct VisitorBase : visitor_tag {
90 using Root = NodeType;
91 using RootNoConstQual = std::remove_const_t<NodeType>;
92 static_assert(std::is_class_v<Root>);
93 static_assert(std::is_same_v<RootNoConstQual, visit_root_type_or_void<RootNoConstQual>>,
94 "`RootConstQualified_` must refer to a (const-qual) node root type."
95 "(Enables selecting correct visitor impl based on partial template specialisation)");
96
97 virtual ~VisitorBase() = default;
98
99 /** The main entry for the user allowing visitors to be utilized as functions */
100 R operator()(NodeType& node, Params const&... args) {
101 return dispatch(node, args...);
102 }
103
104 /**
105 * The main entry for a visit process conducting the dispatching of
106 * a visit to the various sub-types of Nodes. Sub-classes may override
107 * this implementation to conduct pre-visit operations.
108 *
109 * @param node the node to be visited
110 * @param args a list of extra parameters to be forwarded
111 */
112 virtual R dispatch(NodeType& node, Params const&... args) = 0;
113
114 /** The base case for all visitors -- if no more specific overload was defined */
115 virtual R visit_(type_identity<RootNoConstQual>, NodeType& /*node*/, Params const&... /*args*/) {
116 if constexpr (std::is_same_v<void, R>) {
117 return;
118 } else {
119 R res{};
120 return res;
121 }
122 }
123};
124
125} // namespace detail
126
127/**
128 * The generic base type of all Visitors realizing the dispatching of
129 * visitor calls. Each visitor may define a return type R and a list of
130 * extra parameters to be passed along with the visited Nodes to the
131 * corresponding visitor function.
132 *
133 * You must define a partial specialisation of this class for your specific Node-type.
134 * This specialisation must derive from `detail::VisitorBase`.
135 *
136 * @tparam R the result type produced by a visit call
137 * @tparam Node the type of the node being visited (can be const qualified)
138 * @tparam Params extra parameters to be passed to the visit call
139 */
140template <typename R, class NodeType, typename... Params>
141struct Visitor {
142 static_assert(unhandled_dispatch_type<NodeType>,
143 "No partial specialisation for node type found. Likely causes:\n"
144 "\t* `NodeType` isn't a root type\n"
145 "\t* `SOUFFLE_VISITOR_DEFINE_PARTIAL_SPECIALISATION` wasn't applied for `NodeType`");
146};
147
/**
 * Defines the two `souffle::Visitor` partial specialisations for root node type `NODE_TYPE`
 * (mutable and const-qualified), each deriving from `INNER_VISITOR<R, NODE_TYPE [const], Params...>`.
 */
#define SOUFFLE_VISITOR_DEFINE_PARTIAL_SPECIALISATION(NODE_TYPE, INNER_VISITOR)                      \
    template <typename R, typename... Params>                                                      \
    struct souffle::Visitor<R, NODE_TYPE, Params...> : INNER_VISITOR<R, NODE_TYPE, Params...> {};  \
    template <typename R, typename... Params>                                                      \
    struct souffle::Visitor<R, NODE_TYPE const, Params...> : INNER_VISITOR<R, NODE_TYPE const, Params...> {};

/**
 * If `node` can be `as`-cast to `KIND`, returns the result of the `visit_` overload for `KIND`.
 * Expects `node` and `args` to be in scope at the expansion site (e.g. inside `dispatch`).
 */
#define SOUFFLE_VISITOR_FORWARD(KIND) \
    if (auto* n = as<KIND>(node)) return visit_(type_identity<KIND>(), *n, args...);

/**
 * Declares the virtual `visit_` overload for `KIND`, which by default delegates to the
 * overload for `PARENT`. Expects `R`, `NodeType` and `Params` to be in scope.
 */
#define SOUFFLE_VISITOR_LINK(KIND, PARENT)                                                            \
    virtual R visit_(type_identity<KIND>, copy_const<NodeType, KIND>& n, Params const&... args) { \
        return visit_(type_identity<PARENT>(), n, args...);                                         \
    }
161
162namespace detail {
163
164template <typename A>
165using copy_const_visit_root = copy_const<std::remove_reference_t<A>, visit_root_type<A>>;
166
167template <typename F>
168using infer_visit_arg_type = std::remove_reference_t<typename lambda_traits<F>::template arg<0>>;
169
170template <typename F>
171using infer_visit_node_type = copy_const_visit_root<infer_visit_arg_type<F>>;
172
173template <typename F, typename A>
174using visitor_enable_if_arg0 = std::enable_if_t<std::is_same_v<A, remove_cvref_t<infer_visit_arg_type<F>>>>;
175
176// useful for implicitly-as-casting continuations
177template <typename CastType = void, typename F, typename A>
178SOUFFLE_ALWAYS_INLINE auto matchApply([[maybe_unused]] F&& f, [[maybe_unused]] A& x) {
179 using FInfo = lambda_traits<F>;
180 using Arg = std::remove_reference_t<typename FInfo::template arg<0>>;
181 using Result = typename FInfo::result_type;
182 constexpr bool can_cast = std::is_base_of_v<A, Arg> || std::is_base_of_v<Arg, A> ||
183 std::is_same_v<CastType, AllowCrossCast>;
184
185 if constexpr (std::is_same_v<Result, void>) {
186 if constexpr (can_cast) {
187 if (auto y = as<Arg, CastType>(&x)) f(*y);
188 }
189 } else {
190 Result result = {};
191 if constexpr (can_cast) {
192 if (auto y = as<Arg, CastType>(&x)) result = f(*y);
193 }
194 return result;
195 }
196}
197
198/**
199 * A specialized visitor wrapping a lambda function -- an auxiliary type required
200 * for visitor convenience functions.
201 */
202template <typename F, typename Node, typename CastType>
203struct LambdaVisitor : public Visitor<void, Node> {
204 F lambda;
205 LambdaVisitor(F lam) : lambda(std::move(lam)) {}
206
207 void dispatch(Node& node) override {
208 matchApply<CastType>(lambda, node);
209 }
210};
211
212/**
213 * A factory function for creating LambdaVisitor instances.
214 */
215template <typename CastType = void, typename F, typename Node = infer_visit_node_type<F>>
216auto makeLambdaVisitor(F&& f) {
217 return LambdaVisitor<std::decay_t<F>, copy_const_visit_root<Node>, CastType>(std::forward<F>(f));
218}
219
220} // namespace detail
221
222/**
223 * A utility function visiting all nodes within a given container of root nodes
224 * recursively in a depth-first pre-order fashion applying the given function to each
225 * encountered node.
226 *
227 * @param xs the root(s) of the ASTs to be visited
228 * @param fun the function to be applied
229 * @param args a list of extra parameters to be forwarded to the visitor
230 */
231template <typename CC, typename Visitor, typename... Args,
232 typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>,
233 typename = std::enable_if_t<detail::is_visitor<Visitor>>>
234void visit(CC&& xs, Visitor&& visitor, Args&&... args) {
235 if constexpr (detail::is_visitable<CC>) {
236 visitor(xs, args...);
237 souffle::visit(getChildNodes(xs), std::forward<Visitor>(visitor), std::forward<Args>(args)...);
238 } else if constexpr (detail::is_visitable_pointer<CC>) {
239 static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
240 "root isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
241 "root means the tree will be destroyed after the call");
242 assert(xs);
243 souffle::visit(*xs, std::forward<Visitor>(visitor), std::forward<Args>(args)...);
244 } else if constexpr (detail::is_visitable_container<CC>) {
245 for (auto&& x : xs) {
246 // FIXME: Remove this once nodes are converted to references
247 if constexpr (is_pointer_like<decltype(x)>) {
248 assert(x && "found null node during traversal");
249 }
250
251 souffle::visit(x, visitor, args...);
252 }
253 } else
254 static_assert(unhandled_dispatch_type<CC>);
255}
256
257/**
258 * A utility function visiting all nodes within the given root
259 * recursively in a depth-first pre-order fashion, applying the given visitor to each
260 * encountered node.
261 *
262 * @param root the root of the structure to be visited
263 * @param visitor the visitor to be applied on each node
264 */
265template <typename CastType = void, class CC, typename F,
266 typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>,
267 typename = std::enable_if_t<!detail::is_visitor<F>>>
268void visit(CC&& root, F&& fun) {
269 using Arg = std::remove_reference_t<typename lambda_traits<F>::template arg<0>>;
270 using Node = detail::visit_root_type_or_void<Arg>;
271 static_assert(!std::is_same_v<Node, void>, "unable to infer node type from function's first argument");
272
273 souffle::visit(std::forward<CC>(root),
274 detail::makeLambdaVisitor<CastType,
275 std::conditional_t<std::is_rvalue_reference_v<F>, std::remove_reference_t<F>, F>,
276 copy_const<Arg, Node>>(std::forward<F>(fun)));
277}
278
279/**
280 * Traverses tree in pre-order, stopping immediately if the predicate matches and returns true.
281 * Returns true IFF the predicate matched and returned true at least once.
282 *
283 * @param xs the root(s) of the ASTs to be visited
284 * @param f the predicate to be applied
285 */
286template <typename CC, typename F /* exists B <: Node. B& -> bool */,
287 typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
288bool visitExists(CC&& xs, F&& f) {
289 if constexpr (detail::is_visitable_node<CC>) {
290 if (detail::matchApply(f, xs)) return true;
291
292 return visitExists(getChildNodes(xs), std::forward<F>(f));
293 } else if constexpr (detail::is_visitable_pointer<CC>) {
294 static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
295 "arg isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
296 "root means the tree will be destroyed after the call");
297 assert(xs);
298 return visitExists(*xs, f);
299 } else if constexpr (detail::is_visitable_container<CC>) {
300 for (auto&& x : xs)
301 if (visitExists(x, f)) return true;
302
303 return false;
304 } else
305 static_assert(unhandled_dispatch_type<CC>);
306}
307
308/**
309 * Traverses tree in pre-order, stopping immediately if the predicate matches and returns false.
310 * Returns true IFF the predicate never matched and returned false.
311 *
312 * @param xs the root(s) of the ASTs to be visited
313 * @param f the predicate to be applied
314 */
315template <typename CC, typename F /* exists B <: Node. B& -> bool */,
316 typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
317bool visitForAll(CC&& xs, F&& pred) {
318 using Arg = remove_cvref_t<typename lambda_traits<F>::template arg<0>>;
319 return !visitExists(std::forward<CC>(xs), [&](const Arg& x) { return !pred(x); });
320}
321
322/**
323 * Traverse tree in pre-order. Skips sub-trees when predicate matches and returns true.
324 * Returns # of frontiers found. (i.e. # of times predicate returned true.)
325 *
326 * @param xs the root(s) of the ASTs to be visited
327 * @param f the `is-frontier` predicate
328 */
329template <typename CC, typename F /* exists B <: Node. B& -> bool */,
330 typename = std::enable_if_t<detail::is_visitable_dispatch_root<CC>>>
331size_t visitFrontier(CC&& xs, F&& f) {
332 if constexpr (detail::is_visitable_node<CC>) {
333 if (detail::matchApply(f, xs)) return 1;
334
335 return visitFrontier(getChildNodes(xs), f);
336 } else if constexpr (detail::is_visitable_pointer<CC>) {
337 static_assert(std::is_lvalue_reference_v<CC> || std::is_pointer_v<remove_cvref_t<CC>>,
338 "arg isn't an l-value `Own<Node>` or a raw pointer. this is likely a mistake: an r-value "
339 "root means the tree will be destroyed after the call");
340 assert(xs);
341 return visitFrontier(*xs, f);
342 } else if constexpr (detail::is_visitable_container<CC>) {
343 size_t count = 0;
344 for (auto&& x : xs)
345 count += visitFrontier(x, f);
346
347 return count;
348 } else
349 static_assert(unhandled_dispatch_type<CC>);
350}
351
352} // namespace souffle