FreeTensor
Loading...
Searching...
No Matches
tensor.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_TENSOR_H
2#define FREE_TENSOR_TENSOR_H
3
4#include <string>
5#include <vector>
6
7#include <except.h>
8#include <expr.h>
9#include <type/data_type.h>
10
11namespace freetensor {
12
13class Tensor : public ASTPart {
14 template <class T> friend Ref<Tensor> makeTensor(T &&, DataType);
15 friend Ref<Tensor> makeTensor(std::initializer_list<Expr>, DataType);
16
17 SubTreeList<ExprNode> shape_ = ChildOf{this};
18 DataType dtype_;
19
20 public:
21 auto &shape() { return shape_; }
22 const auto &shape() const { return shape_; }
23 void setShape(SubTreeList<ExprNode> &&shape) { shape_ = std::move(shape); }
24 void setShape(const SubTreeList<ExprNode> &shape) { shape_ = shape; }
25 void setShape(const std::vector<Expr> &shape) { shape_ = shape; }
26 void setShape(std::initializer_list<Expr> shape) { shape_ = shape; }
27
28 DataType dtype() const { return dtype_; }
29 void setDType(const DataType &dtype) { dtype_ = dtype; }
30
31 bool isScalar() const;
32
33 void compHash() override;
34};
35
36template <class T> Ref<Tensor> makeTensor(T &&shape, DataType dtype) {
37 auto t = Ref<Tensor>::make();
38 t->shape_ = std::forward<T>(shape);
39 t->dtype_ = dtype;
40 return t;
41}
42inline Ref<Tensor> makeTensor(std::initializer_list<Expr> shape,
43 DataType dtype) {
44 auto t = Ref<Tensor>::make();
45 t->shape_ = shape;
46 t->dtype_ = dtype;
47 return t;
48}
49
51 return makeTensor(t->shape(), t->dtype());
52}
53
54} // namespace freetensor
55
56#endif // FREE_TENSOR_TENSOR_H
Definition: sub_tree.h:50
Definition: data_type.h:106
Definition: ref.h:24
static Ref make()
Definition: ref.h:105
Definition: sub_tree.h:305
Definition: tensor.h:13
void setShape(std::initializer_list< Expr > shape)
Definition: tensor.h:26
void compHash() override
Definition: tensor.cc:17
void setShape(const std::vector< Expr > &shape)
Definition: tensor.h:25
auto & shape()
Definition: tensor.h:21
friend Ref< Tensor > makeTensor(T &&, DataType)
Definition: tensor.h:36
DataType dtype() const
Definition: tensor.h:28
void setShape(const SubTreeList< ExprNode > &shape)
Definition: tensor.h:24
const auto & shape() const
Definition: tensor.h:22
bool isScalar() const
Definition: tensor.cc:6
void setDType(const DataType &dtype)
Definition: tensor.h:29
void setShape(SubTreeList< ExprNode > &&shape)
Definition: tensor.h:23
Definition: allocator.h:9
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
Ref< Tensor > makeTensor(T &&shape, DataType dtype)
Definition: tensor.h:36
Definition: sub_tree.h:20