FreeTensor
Loading...
Searching...
No Matches
rand_var.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_RAND_VAR_H
2#define FREE_TENSOR_RAND_VAR_H
3
4#include <algorithm>
5#include <iostream>
6#include <random>
7#include <string>
8#include <vector>
9
11#include <ref.h>
12
13namespace freetensor {
14
// Name of this variable, returned by name().
29 std::string name_;
// Per-value denominators used by prob(). Held through a Ref handle, so the
// same vector may be shared with other holders; clone() gives the copy its
// own private vector.
31 Ref<std::vector<int>> totCnt_;
// Per-value observation counts: incremented by observe() and used as the
// numerators in prob().
32 std::vector<int> obs_;
33
34 public:
35 DiscreteRandVar(const std::string &name, const Ref<RandCondInterface> &cond,
36 const Ref<std::vector<int>> totCnt,
37 const std::vector<int> &initObs)
38 : name_(name), cond_(cond), totCnt_(totCnt), obs_(initObs) {}
39
40 void observe(int value, int cnt = 1) { obs_.at(value) += cnt; }
41
42 std::vector<double> prob() const {
43 std::vector<double> ret;
44 ret.reserve(obs_.size());
45 for (auto &&[p, q] : views::zip(obs_, *totCnt_)) {
46 ret.emplace_back((double)p / q);
47 }
48 return ret;
49 }
50
55 auto ret = Ref<DiscreteRandVar>::make(*this);
56 ret->totCnt_ = Ref<std::vector<int>>::make(*ret->totCnt_);
57 return ret;
58 }
59
60 const std::string &name() const { return name_; }
61
62 friend std::ostream &operator<<(std::ostream &os,
63 const DiscreteRandVar &var);
64};
65
// Variables this observation refers to. Together with value_, these are the
// only fields compared by operator== and operator<=>.
74 std::vector<Ref<DiscreteRandVar>> vars_;
// Clones of vars_, taken one by one in the constructor.
75 std::vector<Ref<DiscreteRandVar>>
76 varsSnapshot_; // Freeze the distribution at the time of
77 // observation, used for debugging
// The observed value. NOTE(review): presumably an outcome index, as in
// DiscreteRandVar::observe — confirm against callers.
79 int value_;
80 std::string message_; // Debug info; not compared by operator== / operator<=>
81
82 DiscreteObservation(const std::vector<Ref<DiscreteRandVar>> &vars,
83 const Ref<std::vector<int>> &totCnt, int value,
84 const std::string &message = "")
85 : vars_(vars), totCnt_(totCnt), value_(value), message_(message) {
86 varsSnapshot_.reserve(vars_.size());
87 for (auto &&var : vars_) {
88 varsSnapshot_.emplace_back(var->clone());
89 }
90 }
91
93 const DiscreteObservation &rhs) {
94 if (auto cmp = lhs.vars_ <=> rhs.vars_; cmp != 0) {
95 return cmp;
96 }
97 return lhs.value_ <=> rhs.value_;
98 }
100 const DiscreteObservation &rhs) {
101 return lhs.vars_ == rhs.vars_ && lhs.value_ == rhs.value_;
102 }
103
104 friend std::ostream &operator<<(std::ostream &os,
105 const DiscreteObservation &obs);
106};
107
108} // namespace freetensor
109
110#endif // FREE_TENSOR_RAND_VAR_H
Definition: rand_var.h:28
friend std::ostream & operator<<(std::ostream &os, const DiscreteRandVar &var)
Definition: rand_var.cc:6
DiscreteRandVar(const std::string &name, const Ref< RandCondInterface > &cond, const Ref< std::vector< int > > totCnt, const std::vector< int > &initObs)
Definition: rand_var.h:35
void observe(int value, int cnt=1)
Definition: rand_var.h:40
Ref< DiscreteRandVar > clone() const
Definition: rand_var.h:54
std::vector< double > prob() const
Definition: rand_var.h:42
const std::string & name() const
Definition: rand_var.h:60
Definition: ref.h:24
static Ref make()
Definition: ref.h:105
Expr cond_
Definition: invert_stmts.cc:58
Definition: allocator.h:9
auto && lhs
Definition: const_fold.cc:70
auto auto && rhs
Definition: const_fold.cc:70
Definition: rand_var.h:73
std::string message_
Definition: rand_var.h:80
int value_
Definition: rand_var.h:79
friend std::ostream & operator<<(std::ostream &os, const DiscreteObservation &obs)
Definition: rand_var.cc:15
friend bool operator==(const DiscreteObservation &lhs, const DiscreteObservation &rhs)
Definition: rand_var.h:99
DiscreteObservation(const std::vector< Ref< DiscreteRandVar > > &vars, const Ref< std::vector< int > > &totCnt, int value, const std::string &message="")
Definition: rand_var.h:82
std::vector< Ref< DiscreteRandVar > > vars_
Definition: rand_var.h:74
Ref< std::vector< int > > totCnt_
Definition: rand_var.h:78
friend auto operator<=>(const DiscreteObservation &lhs, const DiscreteObservation &rhs)
Definition: rand_var.h:92
std::vector< Ref< DiscreteRandVar > > varsSnapshot_
Definition: rand_var.h:76