FreeTensor
Loading...
Searching...
No Matches
find_stmt.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FIND_STMT_H
2#define FREE_TENSOR_FIND_STMT_H
3
4#include <func.h>
5#include <selector.h>
6#include <visitor.h>
7
8namespace freetensor {
9
10class FindStmtById : public Visitor {
11 ID id_;
12 Stmt result_;
13
14 public:
15 FindStmtById(const ID &id) : id_(id) {}
16
17 const Stmt &result() const { return result_; }
18
19 protected:
20 void visitStmt(const Stmt &op) override;
21};
22
23class FindStmtByFilter : public Visitor {
24 const std::function<bool(const Stmt &)> &filter_;
25 std::vector<Stmt> results_;
26
27 public:
28 FindStmtByFilter(const std::function<bool(const Stmt &)> &filter)
29 : filter_(filter) {}
30 const std::vector<Stmt> &results() const { return results_; }
31
32 protected:
33 void visitStmt(const Stmt &op) override;
34};
35
43std::vector<Stmt> findAllStmt(const Stmt &ast, const ID &id);
44std::vector<Stmt> findAllStmt(const Stmt &ast,
45 const std::function<bool(const Stmt &)> &filter);
46std::vector<Stmt> findAllStmt(const Stmt &ast, const Ref<Selector> &selector);
47inline std::vector<Stmt> findAllStmt(const Stmt &ast,
48 const std::string &pattern) {
49 return findAllStmt(ast, parseSelector(pattern));
50}
51template <class T>
52std::vector<Stmt> findAllStmt(const Func &func, const T &filter) {
53 return findAllStmt(func->body_, filter);
54}
65Stmt findStmt(const Stmt &ast, const ID &id);
66Stmt findStmt(const Stmt &ast, const std::function<bool(const Stmt &)> &filter);
67Stmt findStmt(const Stmt &ast, const Ref<Selector> &selector);
68inline Stmt findStmt(const Stmt &ast, const std::string &pattern) {
69 return findStmt(ast, parseSelector(pattern));
70}
71template <class T> Stmt findStmt(const Func &func, const T &filter) {
72 return findStmt(func->body_, filter);
73}
76} // namespace freetensor
77
78#endif // FREE_TENSOR_FIND_STMT_H
Definition: find_stmt.h:23
const std::vector< Stmt > & results() const
Definition: find_stmt.h:30
FindStmtByFilter(const std::function< bool(const Stmt &)> &filter)
Definition: find_stmt.h:28
void visitStmt(const Stmt &op) override
Definition: find_stmt.cc:14
Definition: find_stmt.h:10
FindStmtById(const ID &id)
Definition: find_stmt.h:15
void visitStmt(const Stmt &op) override
Definition: find_stmt.cc:5
const Stmt & result() const
Definition: find_stmt.h:17
SubTree< StmtNode > body_
Definition: func.h:56
Definition: id.h:18
Definition: visitor.h:11
Definition: allocator.h:9
Stmt findStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:22
Ref< Selector > parseSelector(const std::string &str)
Definition: selector.cc:201
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
Ref< StmtNode > Stmt
Definition: ast.h:152
std::vector< Stmt > findAllStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:32