FreeTensor
Loading...
Searching...
No Matches
z3_simplify.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_Z3_SIMPLIFY
2#define FREE_TENSOR_Z3_SIMPLIFY
3
4#include <deque>
5#include <optional>
6#include <unordered_map>
7
8#include <z3++.h>
9
11#include <func.h>
12#include <hash.h>
13#include <mutator.h>
14#include <visitor.h>
15
16namespace freetensor {
17
35class Z3Simplify : public Mutator {
36 typedef Mutator BaseClass;
37
38 private:
39 int varCnt_ = 0;
41
42 z3::context ctx_;
43 z3::solver solver_;
44
51 struct ExprInfo {
53 std::optional<z3::expr> self_;
54
57 std::vector<std::optional<z3::expr>> conds_;
58 };
59
60 // We use std::optional because there is no z3::expr::expr()
61 std::unordered_map<Expr, ExprInfo> z3Exprs_;
62
63 public:
64 Z3Simplify() : solver_(ctx_) {}
65
66 protected:
67 int getVarId(const Expr &op);
68
69 void put(const Expr &key, const z3::expr &expr,
70 const std::vector<std::optional<z3::expr>> &conds = {});
71 bool exists(const Expr &key);
72 const z3::expr &get(const Expr &key);
73 const std::vector<std::optional<z3::expr>> &conds(const Expr &key);
74
75 void push(const Expr &op);
76 void pop();
77
78 bool prove(const Expr &op);
79
80 using Mutator::visit;
81
82 Expr visit(const Var &op) override;
83 Expr visit(const Load &op) override;
84 // TODO: Cast can also be treated as Load
85 Expr visit(const IntConst &op) override;
86 Expr visit(const BoolConst &op) override;
87 Expr visit(const Add &op) override;
88 Expr visit(const Sub &op) override;
89 Expr visit(const Mul &op) override;
90 Expr visit(const FloorDiv &op) override;
91 Expr visit(const CeilDiv &op) override;
92 Expr visit(const Mod &op) override;
93 Expr visit(const Min &op) override;
94 Expr visit(const Max &op) override;
95 Expr visit(const LT &op) override;
96 Expr visit(const LE &op) override;
97 Expr visit(const GT &op) override;
98 Expr visit(const GE &op) override;
99 Expr visit(const EQ &op) override;
100 Expr visit(const NE &op) override;
101 Expr visit(const LAnd &op) override;
102 Expr visit(const LOr &op) override;
103 Expr visit(const LNot &op) override;
104 Expr visit(const IfExpr &op) override;
105
106 Stmt visit(const If &op) override;
107 Stmt visit(const Assert &op) override;
108 Stmt visit(const Assume &op) override;
109 Stmt visit(const For &op) override;
110};
111
116 public SymbolTableInterface {
117 SymbolTableData symbols_;
118
119 public:
120 const std::unordered_set<std::string> &names() const override {
121 return symbols_.names();
122 }
123 const std::unordered_map<std::string, VarDef> &defs() const override {
124 return symbols_.defs();
125 }
126 const std::unordered_map<std::string, For> &loops() const override {
127 return symbols_.loops();
128 }
129
130 bool hasDef(const std::string &name) const override {
131 return symbols_.hasDef(name);
132 }
133 const VarDef &def(const std::string &name) const override {
134 return symbols_.def(name);
135 }
136 Ref<Buffer> buffer(const std::string &name) const override {
137 return symbols_.buffer(name);
138 }
139
140 bool hasLoop(const std::string &name) const override {
141 return symbols_.hasLoop(name);
142 }
143 const For &loop(const std::string &name) const override {
144 return symbols_.loop(name);
145 }
146
147 void pushDef(const VarDef &op) override { symbols_.pushDef(op); }
148 void popDef(const VarDef &op) override { symbols_.popDef(op); }
149
150 void pushFor(const For &op) override { symbols_.pushFor(op); }
151 void popFor(const For &op) override { symbols_.popFor(op); }
152
153 protected:
154 using Z3Simplify::visit;
155 Stmt visit(const VarDef &op) override;
156 Stmt visit(const For &op) override;
157};
158
159Stmt z3Simplify(const Stmt &op);
160
162
163} // namespace freetensor
164
165#endif // FREE_TENSOR_Z3_SIMPLIFY
Definition: mutator.h:11
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
Definition: symbol_table.h:33
void pushDef(const VarDef &op) override
Definition: symbol_table.h:79
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: symbol_table.h:42
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:62
const std::unordered_map< std::string, For > & loops() const override
Definition: symbol_table.h:45
void pushFor(const For &op) override
Definition: symbol_table.h:93
bool hasDef(const std::string &name) const override
Definition: symbol_table.h:49
virtual const For & loop(const std::string &name) const override
Definition: symbol_table.h:70
void popFor(const For &op) override
Definition: symbol_table.h:102
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:39
virtual bool hasLoop(const std::string &name) const override
Definition: symbol_table.h:66
void popDef(const VarDef &op) override
Definition: symbol_table.h:88
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:53
Definition: symbol_table.h:13
Definition: z3_simplify.h:116
const VarDef & def(const std::string &name) const override
Definition: z3_simplify.h:133
const std::unordered_map< std::string, For > & loops() const override
Definition: z3_simplify.h:126
const std::unordered_set< std::string > & names() const override
Definition: z3_simplify.h:120
Ref< Buffer > buffer(const std::string &name) const override
Definition: z3_simplify.h:136
bool hasDef(const std::string &name) const override
Definition: z3_simplify.h:130
const For & loop(const std::string &name) const override
Definition: z3_simplify.h:143
void pushDef(const VarDef &op) override
Definition: z3_simplify.h:147
Stmt visit(const VarDef &op) override
Definition: z3_simplify.cc:588
bool hasLoop(const std::string &name) const override
Definition: z3_simplify.h:140
void popFor(const For &op) override
Definition: z3_simplify.h:151
void popDef(const VarDef &op) override
Definition: z3_simplify.h:148
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: z3_simplify.h:123
void pushFor(const For &op) override
Definition: z3_simplify.h:150
Definition: z3_simplify.h:35
Z3Simplify()
Definition: z3_simplify.h:64
void pop()
Definition: z3_simplify.cc:65
const z3::expr & get(const Expr &key)
Definition: z3_simplify.cc:35
const std::vector< std::optional< z3::expr > > & conds(const Expr &key)
Definition: z3_simplify.cc:39
void put(const Expr &key, const z3::expr &expr, const std::vector< std::optional< z3::expr > > &conds={})
Definition: z3_simplify.cc:28
bool exists(const Expr &key)
Definition: z3_simplify.cc:33
Expr visit(const Var &op) override
Definition: z3_simplify.cc:67
void push(const Expr &op)
Definition: z3_simplify.cc:58
bool prove(const Expr &op)
Definition: z3_simplify.cc:43
int getVarId(const Expr &op)
Definition: z3_simplify.cc:21
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Ref< AssumeNode > Assume
Definition: stmt.h:425
Ref< BoolConstNode > BoolConst
Definition: expr.h:134
std::unordered_map< K, V, Hasher, HashComparator > ASTHashMap
Definition: hash.h:114
Ref< VarNode > Var
Definition: expr.h:40
Ref< CeilDivNode > CeilDiv
Definition: expr.h:257
Ref< MaxNode > Max
Definition: expr.h:352
Ref< LAndNode > LAnd
Definition: expr.h:450
Ref< LoadNode > Load
Definition: expr.h:61
Ref< SubNode > Sub
Definition: expr.h:186
Ref< IfNode > If
Definition: stmt.h:352
Ref< IfExprNode > IfExpr
Definition: expr.h:705
Stmt z3Simplify(const Stmt &op)
Definition: z3_simplify.cc:602
Ref< ForNode > For
Definition: stmt.h:308
Ref< LNotNode > LNot
Definition: expr.h:488
Ref< GENode > GE
Definition: expr.h:408
Ref< GTNode > GT
Definition: expr.h:394
Ref< FloorDivNode > FloorDiv
Definition: expr.h:237
Ref< MulNode > Mul
Definition: expr.h:200
Ref< LTNode > LT
Definition: expr.h:366
Ref< NENode > NE
Definition: expr.h:436
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< MinNode > Min
Definition: expr.h:338
Ref< IntConstNode > IntConst
Definition: expr.h:100
Ref< EQNode > EQ
Definition: expr.h:422
Ref< AddNode > Add
Definition: expr.h:172
Ref< ExprNode > Expr
Definition: ast.h:184
Ref< ModNode > Mod
Definition: expr.h:301
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< LOrNode > LOr
Definition: expr.h:464
Ref< LENode > LE
Definition: expr.h:380