FreeTensor
Loading...
Searching...
No Matches
rand_ctx.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_RAND_CTX_H
2#define FREE_TENSOR_RAND_CTX_H
3
4#include <map>
5#include <mutex>
6#include <regex>
7#include <string>
8#include <unordered_map>
9
10#include <container_utils.h>
11#include <func_utils.h>
14
15namespace freetensor {
16
19
/**
 * Identify the expansion site of this macro at runtime.
 *
 * Expands to an immediately-invoked lambda that owns a function-local static
 * ProgramPositionHelper, and evaluates to that object's address (a
 * ProgramPosition). Every textual expansion has its own closure type and
 * therefore its own static object, so the value is stable across repeated
 * executions of one site and distinct between different sites. (Inside a
 * template, each instantiation likewise gets its own object.)
 */
#define PROGRAM_POSITION                                                       \
    ([]() -> ProgramPosition {                                                 \
        static ProgramPositionHelper siteMarker;                               \
        return &siteMarker;                                                    \
    }())
28
29typedef std::vector<DiscreteObservation> RandTrace;
30
35 protected:
36 std::unordered_map<
38 std::unordered_map<Ref<RandCondInterface>, Ref<DiscreteRandVar>,
41 randVars_; // {pos -> {conds -> var}}
42
43 std::unordered_map<ProgramPosition, Ref<std::vector<int>>> totCnt_;
44
45 std::multimap<Ref<RandTrace>, std::pair<double, double>,
47 traces_; // Ordered map
48
49 bool isInfer_ = true;
50 std::regex toLearn_{".*"};
51
52 std::mutex lock_;
53
54 public:
65 void observeTrace(const Ref<RandTrace> &trace, double value, double stddev);
66
76 void setLearnFilter(const std::regex &toLearn) { toLearn_ = toLearn; }
77
83 void setLearn() { isInfer_ = false; }
84
90 void setInfer() { isInfer_ = true; }
91};
92
101template <std::uniform_random_bit_generator RNG>
102class RandCtx : public RandCtxImpl {
103 RNG &rng_;
104
105 public:
106 RandCtx(RNG &rng) : rng_(rng) {}
107
117 int decide(ProgramPosition pos, const std::string &name,
118 const RandCondStack &condStack,
119 const std::vector<double> &priori, const Ref<RandTrace> &trace,
120 const std::string &message = "") {
121 std::lock_guard<std::mutex> guard(lock_);
122
123 auto INIT_OBS = 4;
124
125 if (!totCnt_.count(pos)) {
126 totCnt_[pos] = Ref<std::vector<int>>::make(priori.size(), INIT_OBS);
127 }
128 auto &&totCnt = totCnt_.at(pos);
129
130 std::vector<double> prob{totCnt->begin(), totCnt->end()};
131 for (auto it = condStack; !it.empty(); it = it.pop()) {
132 auto &&cond = it.top();
133 if (!randVars_.count(pos) || !randVars_.at(pos).count(cond)) {
134 std::vector<int> initObs;
135 initObs.reserve(priori.size());
136 for (auto &&p : priori) {
137 initObs.emplace_back((int)(INIT_OBS * p));
138 }
139 randVars_[pos][cond] =
140 Ref<DiscreteRandVar>::make(name, cond, totCnt, initObs);
141 }
142 auto &&var = randVars_.at(pos).at(cond);
143 auto localProb = var->prob();
144 for (auto &&[p, q] : views::zip(prob, localProb)) {
145 p *= q;
146 }
147 }
148
149 int value;
150 if (isInfer_ || !std::regex_match(name, toLearn_)) { // Most likely
151 value = std::max_element(prob.begin(), prob.end()) - prob.begin();
152 } else { // Sample
153 value =
154 std::discrete_distribution<int>(prob.begin(), prob.end())(rng_);
155 }
156
157 if (trace.isValid()) {
158 std::vector<Ref<DiscreteRandVar>> vars;
159 for (auto it = condStack; !it.empty(); it = it.pop()) {
160 auto &&cond = it.top();
161 auto &&var = randVars_.at(pos).at(cond);
162 vars.emplace_back(var);
163 }
164 trace->emplace_back(vars, totCnt, value, message);
165 }
166
167 return value;
168 }
169};
170
171} // namespace freetensor
172
173#endif // FREE_TENSOR_RAND_CTX_H
Definition: func_utils.h:9
Definition: rand_ctx.h:34
std::regex toLearn_
Definition: rand_ctx.h:50
std::mutex lock_
Definition: rand_ctx.h:52
void setLearnFilter(const std::regex &toLearn)
Definition: rand_ctx.h:76
bool isInfer_
Definition: rand_ctx.h:49
void setLearn()
Definition: rand_ctx.h:83
void observeTrace(const Ref< RandTrace > &trace, double value, double stddev)
Definition: rand_ctx.cc:5
void setInfer()
Definition: rand_ctx.h:90
std::unordered_map< ProgramPosition, std::unordered_map< Ref< RandCondInterface >, Ref< DiscreteRandVar >, PtrInvocable< std::hash< RandCondInterface > >, PtrInvocable< std::equal_to< RandCondInterface > > > > randVars_
Definition: rand_ctx.h:41
std::unordered_map< ProgramPosition, Ref< std::vector< int > > > totCnt_
Definition: rand_ctx.h:43
std::multimap< Ref< RandTrace >, std::pair< double, double >, PtrInvocable< std::less< RandTrace > > > traces_
Definition: rand_ctx.h:47
Definition: rand_ctx.h:102
RandCtx(RNG &rng)
Definition: rand_ctx.h:106
int decide(ProgramPosition pos, const std::string &name, const RandCondStack &condStack, const std::vector< double > &priori, const Ref< RandTrace > &trace, const std::string &message="")
Definition: rand_ctx.h:117
Definition: ref.h:24
static Ref make()
Definition: ref.h:105
bool isValid() const
Definition: ref.h:89
Definition: shared_linked_list.h:21
bool empty() const
Definition: shared_linked_list.h:36
Definition: allocator.h:9
const ProgramPositionHelper * ProgramPosition
Definition: rand_ctx.h:18
std::vector< DiscreteObservation > RandTrace
Definition: rand_ctx.h:29
Definition: rand_ctx.h:17