Automatic Differentiation
Automatic Differentiation (AD) transforms a program into another program that computes the derivative or gradient of the original one. FreeTensor supports Reverse-Mode AD, and Forward-Mode AD is planned for the future.
Reverse-Mode AD
Suppose there is a program `x -> y -> z -> w` that computes an output `w` from intermediate variables `z` and `y`, and an input variable `x`. Reverse-Mode AD generates a gradient program `dw/dw=1 -> dw/dz -> dw/dy -> dw/dx` that computes `dw/dx` by the Chain Rule. `y`, `z` and `w` may be saved in a "tape" when evaluating the original program, to be reused in the gradient program.
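A plain-Python sketch (not FreeTensor code) can make the tape concrete. It assumes the chain y = x², z = sin(y), w = 3z, and the names `TAPE`, `forward` and `backward` are hypothetical. The forward pass records what the backward pass needs, and the backward pass starts from dw/dw = 1 and applies the Chain Rule in reverse order. It also shows why the forward function must run before the backward one: otherwise the tape is empty.

```python
import math

# Hypothetical global tape: forward() writes to it, backward() reads from it
TAPE = {}

def forward(x):
    """Evaluate the chain x -> y -> z -> w, saving what backward() needs."""
    y = x * x            # y = x^2
    z = math.sin(y)      # z = sin(y)
    w = 3.0 * z          # w = 3z
    TAPE["x"] = x        # dy/dx = 2x needs x
    TAPE["y"] = y        # dz/dy = cos(y) needs y
    return w

def backward(dw_dw=1.0):
    """Apply the Chain Rule from w back to x, reusing the saved values."""
    if "y" not in TAPE:
        raise RuntimeError("forward() must run before backward()")
    dw_dz = dw_dw * 3.0                   # dw/dz = 3
    dw_dy = dw_dz * math.cos(TAPE["y"])   # dz/dy = cos(y)
    dw_dx = dw_dy * 2.0 * TAPE["x"]       # dy/dx = 2x
    return dw_dx

w = forward(0.5)
dw_dx = backward()
# Closed form for comparison: dw/dx = 6x cos(x^2)
assert math.isclose(dw_dx, 6 * 0.5 * math.cos(0.25))
```

Only `x` and `y` are saved here because this particular backward pass never reads `z` or `w`; which intermediates a real gradient program needs depends on the derivative of each step.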
If FreeTensor is built with `WITH_PYTORCH=ON`, you can skip this section and turn to the `@optimize_to_pytorch` integration, which integrates seamlessly with PyTorch's autograd mechanism but incurs some runtime overhead.
Here is an example of Reverse-Mode AD in FreeTensor:
```python
import freetensor as ft
import numpy as np

n = 4

@ft.optimize  # Compile both `test` and the attached `test.backward`
@ft.grad(requires=['a', 'b'], provides=[ft.Return()], attach_backward=True)
def test(a: ft.Var[(n,), "float32"], b: ft.Var[(n,), "float32"]):
    y = ft.zeros((), "float32")
    for i in range(n):
        y[()] += a[i] * b[i]
    return y

a = np.array([0, 1, 2, 3], dtype="float32")
b = np.array([3, 2, 1, 0], dtype="float32")
y = test(a, b)  # Run the forward function first: it may write to the tape
dzdy = np.array(1, dtype='float32')
input_grads = test.input_name_to_gradient_name
output_grads = test.output_name_to_gradient_name
dzda, dzdb = test.backward(
    **{output_grads[ft.Return()]: dzdy})[input_grads['a'], input_grads['b']]
```
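As a sanity check on what `test.backward` should return: since y = Σ a_i b_i, the Chain Rule gives dy/da_i = b_i and dy/db_i = a_i, so with `dzdy = 1` the expected gradients are simply `b` and `a`. The plain-Python reference below does not use FreeTensor; the helpers `dot` and `numeric_grad` are hypothetical, and the gradients are estimated with central finite differences.

```python
def dot(a, b):
    """The forward computation of the example: y = sum_i a[i] * b[i]."""
    return sum(ai * bi for ai, bi in zip(a, b))

def numeric_grad(f, xs, eps=1e-3):
    """Central finite differences of a scalar function f w.r.t. each entry of xs."""
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        plus[i] += eps
        minus = list(xs)
        minus[i] -= eps
        grads.append((f(plus) - f(minus)) / (2 * eps))
    return grads

a = [0.0, 1.0, 2.0, 3.0]
b = [3.0, 2.0, 1.0, 0.0]
dzdy = 1.0

y = dot(a, b)                                                     # 4.0
dzda = [dzdy * g for g in numeric_grad(lambda v: dot(v, b), a)]   # close to b
dzdb = [dzdy * g for g in numeric_grad(lambda v: dot(a, v), b)]   # close to a
```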
You need to call `ft.grad` (or the in-place version `ft.grad_`) to generate a forward function and a backward function. In this example, the backward function is attached as the `test.backward` property because `attach_backward` is set to `True`. You can set it to `False`, and `ft.grad` will return both functions. Please note that `test` is updated by `ft.grad` and becomes different from the original function, as it may save some intermediate tensors to a global `tape`, and it must be executed before the backward function `test.backward`.
Note on JIT: JIT is only supported when `attach_backward=True`.
After that, you call `ft.optimize` to optimize and compile the program just as in previous examples. This time it is done for both `test` and `test.backward`.
Finally, you execute `test` and `test.backward`. The parameters of `test.backward` are the gradients of the outputs (here `y`), and its return values are the gradients of `a` and `b`; each of them has its own name. To set these parameters and get these return values, you look them up in two dictionaries, `test.input_name_to_gradient_name` and `test.output_name_to_gradient_name` (of type `ft.ParamRetDict`). These two dictionaries accept either the name of a parameter or a special `ft.Return` to specify a return value. When invoking `test.backward`, parameters can be set via keyword arguments, and return values can be collected via brackets (from a special type `ft.ReturnValuesPack`). These two maps are attached to `test` because `attach_backward` is `True`. Otherwise, they are returned as return values from `ft.grad`.
Intermediate variables do not always have to be saved to the "tape" by the forward function. If a variable is needed in the backward function but not saved, it will be recomputed, which is sometimes even faster than saving it due to better locality. By default, FreeTensor uses heuristics to determine which variables to save. To get better performance, you may want to control which intermediate variables are saved by setting the optional `tapes` parameter of `ft.grad`. `tapes` can either be a different mode, or an explicit list of AST node IDs of all `VarDef` nodes of the variables you want to save.
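The trade-off between saving and recomputing can be sketched in plain Python. Here `tapes` is a hypothetical set of intermediate names to save: it only mimics the idea behind the `tapes` parameter and is not its actual format. Any intermediate that the backward pass needs but that was not saved is recomputed from the input, so both choices give the same gradient and differ only in memory traffic versus extra arithmetic.

```python
import math

def forward(x, tapes):
    """Compute w = sin(exp(x)); keep the intermediate t only if 't' is in tapes."""
    t = math.exp(x)
    w = math.sin(t)
    tape = {"x": x}                 # the input is always available to backward
    if "t" in tapes:
        tape["t"] = t               # save: costs memory, avoids recomputation
    return w, tape

def backward(tape, dw_dw=1.0):
    """dw/dx = cos(t) * exp(x) = cos(t) * t; recompute t if it was not saved."""
    t = tape["t"] if "t" in tape else math.exp(tape["x"])   # recompute on demand
    return dw_dw * math.cos(t) * t

_, saved = forward(0.3, tapes={"t"})
_, recomputed = forward(0.3, tapes=set())
assert math.isclose(backward(saved), backward(recomputed))
```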
Providing Your Custom Gradients
Why and When We Need Custom Gradients
Sometimes neither reverse-mode nor forward-mode AD produces the most elegant form of gradients. FreeTensor allows you to provide your own gradients for part of the program.
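Take softmax as an example. The \(\mathbf{y} = \mathrm{softmax}(\mathbf{x})\) function is mathematically defined by the following steps:

$$
\begin{align}
e_i &= \exp(x_i) \label{eq:softmax-1} \\
s &= \sum_i e_i \label{eq:softmax-2} \\
y_i &= \frac{e_i}{s} \label{eq:softmax-3}
\end{align}
$$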
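Suppose the final output of the program (the loss) is \(z\). With reverse-mode AD, the gradient of the input, \(\frac{\partial z}{\partial x}\), can be computed by the following steps:

$$
\begin{align}
\frac{\partial z}{\partial s} &= -\sum_i \frac{\partial z}{\partial y_i} \frac{y_i}{s} \label{eq:softmax-grad-1} \\
\frac{\partial z}{\partial e_i} &= \frac{\partial z}{\partial y_i} \frac{1}{s} + \frac{\partial z}{\partial s} \label{eq:softmax-grad-2} \\
\frac{\partial z}{\partial x_i} &= \frac{\partial z}{\partial e_i} e_i \label{eq:softmax-grad-3}
\end{align}
$$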
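However, we usually can NOT compute softmax by Equation \(\eqref{eq:softmax-1}\eqref{eq:softmax-2}\eqref{eq:softmax-3}\) because of numerical stability issues: \(\exp(x_i)\) overflows for large \(x_i\). In practice, we compute softmax with an additional normalization on \(\mathbf{x}\):

$$
\begin{align}
m &= \max_i x_i \label{eq:softmax-norm-1} \\
e_i &= \exp(x_i - m) \label{eq:softmax-norm-2} \\
s &= \sum_i e_i \label{eq:softmax-norm-3} \\
y_i &= \frac{e_i}{s} \label{eq:softmax-norm-4}
\end{align}
$$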
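If we directly apply reverse-mode AD on Equation \(\eqref{eq:softmax-norm-1}\eqref{eq:softmax-norm-2}\eqref{eq:softmax-norm-3}\eqref{eq:softmax-norm-4}\), the backward program will look like:

$$
\begin{align*}
\frac{\partial z}{\partial s} &= -\sum_i \frac{\partial z}{\partial y_i} \frac{y_i}{s} \\
\frac{\partial z}{\partial e_i} &= \frac{\partial z}{\partial y_i} \frac{1}{s} + \frac{\partial z}{\partial s} \\
\frac{\partial z}{\partial m} &= -\sum_i \frac{\partial z}{\partial e_i} e_i \\
\frac{\partial z}{\partial x_i} &= \frac{\partial z}{\partial e_i} e_i + \frac{\partial z}{\partial m} \frac{\partial m}{\partial x_i}
\end{align*}
$$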
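You may have noticed that there is an extra \(\frac{\partial z}{\partial m}\) involved. Clearly, the gradient should be the same whether or not we do the normalization. This is because \(\frac{\partial z}{\partial m}\) always equals \(0\): since \(\frac{\partial z}{\partial e_i} = \frac{\partial z}{\partial y_i} \frac{1}{s} + \frac{\partial z}{\partial s}\) and \(s \frac{\partial z}{\partial s} = -\sum_i \frac{\partial z}{\partial y_i} y_i\), we get \(\sum_i \frac{\partial z}{\partial e_i} e_i = \sum_i \frac{\partial z}{\partial y_i} y_i + s \frac{\partial z}{\partial s} = 0\). However, FreeTensor cannot dig out this mathematical property, so the computation of \(\frac{\partial z}{\partial m}\) will remain and will be wasted.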
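The claim that \(\frac{\partial z}{\partial m}\) vanishes can also be checked numerically. The plain-Python sketch below is not FreeTensor code, and its function names and input values are arbitrary assumptions: it runs the normalized forward pass, applies the Chain Rule by hand for an arbitrary upstream gradient `dzdy`, and computes \(\frac{\partial z}{\partial m}\), which comes out as zero up to floating-point round-off. Because of this, the gradient through \(e_i\) alone is already the full \(\frac{\partial z}{\partial x_i}\).

```python
import math

def softmax_forward(x):
    """Normalized softmax: m = max(x), e_i = exp(x_i - m), s = sum(e), y_i = e_i / s."""
    m = max(x)
    e = [math.exp(xi - m) for xi in x]
    s = sum(e)
    y = [ei / s for ei in e]
    return m, e, s, y

def softmax_backward(e, s, y, dzdy):
    """Reverse-mode steps, including the extra gradient with respect to m."""
    dzds = -sum(gi * yi for gi, yi in zip(dzdy, y)) / s
    dzde = [gi / s + dzds for gi in dzdy]
    dzdm = -sum(gi * ei for gi, ei in zip(dzde, e))   # because e_i = exp(x_i - m)
    dzdx = [gi * ei for gi, ei in zip(dzde, e)]       # contribution through e only
    return dzdm, dzdx

x = [0.3, -1.2, 2.5, 0.7]
dzdy = [0.1, -0.4, 0.9, 0.2]      # an arbitrary upstream gradient dz/dy
m, e, s, y = softmax_forward(x)
dzdm, dzdx = softmax_backward(e, s, y, dzdy)
# dzdm is zero up to round-off, so dzdx is the complete gradient dz/dx
```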
How to Write Custom Gradients in FreeTensor
The following examples demonstrate how to provide your own custom gradients to override the default AD behaviour. Please note that this is only for demonstration. If you just need softmax, call `libop.softmax`, which already implements the following code.
First we show a softmax implementation with full AD:
```python
import freetensor as ft
import torch

n = 4

@ft.optimize  # Set verbose=1 to see the code
@ft.grad(requires=['x'], provides=[ft.Return()], attach_backward=True)
def test(x: ft.Var[(n,), "float32"]):
    # Automatically decide gradients for this statement
    m = ft.reduce_max(x, axes=[-1])
    e = ft.exp(x - m)
    s = ft.reduce_sum(e, axes=[-1])
    y = e / s
    return y

# Check forward result
x = torch.rand(n, dtype=torch.float32)
x.requires_grad = True
y_ft = test(x).torch()
y_torch = torch.softmax(x, dim=-1)
assert torch.all(torch.isclose(y_ft, y_torch))

# Check backward result
dzdy = torch.rand(n, dtype=torch.float32)
y_torch.backward(dzdy.clone())  # PyTorch writes the reference gradient to x.grad
input_grads = test.input_name_to_gradient_name
output_grads = test.output_name_to_gradient_name
dzdx_ft = test.backward(**{output_grads[ft.Return()]: dzdy}).torch()
dzdx_torch = x.grad
assert torch.all(torch.isclose(dzdx_ft, dzdx_torch, 1e-4, 1e-7))
```
Then, we add our own gradient to it:
```python
import freetensor as ft
import torch

n = 4

@ft.optimize  # Set verbose=1 to see the code
@ft.grad(requires=['x'], provides=[ft.Return()], attach_backward=True)
def test(x: ft.Var[(n,), "float32"]):
    # Mark the range that you want to provide gradient for, with `StmtRange`
    with ft.StmtRange() as rng:
        m = ft.reduce_max(x, axes=[-1])
        e = ft.exp(x - m)
        s = ft.reduce_sum(e, axes=[-1])
        y = e / s

        # Call `push_for_backward` so we can use forward values in backward
        e_now = ft.push_for_backward(e)
        s_now = ft.push_for_backward(s)
        y_now = ft.push_for_backward(y)
    # Define gradient in `UserGrad`
    with ft.UserGrad(x, y, stmt_range=rng) as (dzdx, dzdy):
        # Retrieve forward value from `y_now`, NOT `y`
        dzds = -ft.reduce_sum(dzdy * y_now, axes=[-1]) / s_now
        dzde = dzdy / s_now + dzds
        dzdx[...] += dzde * e_now  # Use `+=` here
    return y

# Check forward result
x = torch.rand(n, dtype=torch.float32)
x.requires_grad = True
y_ft = test(x).torch()
y_torch = torch.softmax(x, dim=-1)
assert torch.all(torch.isclose(y_ft, y_torch))

# Check backward result
dzdy = torch.rand(n, dtype=torch.float32)
y_torch.backward(dzdy.clone())  # PyTorch writes the reference gradient to x.grad
input_grads = test.input_name_to_gradient_name
output_grads = test.output_name_to_gradient_name
dzdx_ft = test.backward(**{output_grads[ft.Return()]: dzdy}).torch()
dzdx_torch = x.grad
assert torch.all(torch.isclose(dzdx_ft, dzdx_torch, 1e-4, 1e-7))
```
First, we mark the range of code that we want to provide a gradient for with `ft.StmtRange`, naming it `rng`. In the range, we write the code to compute softmax as usual. Additionally, for the values that we want to reuse in the gradient, we call `ft.push_for_backward` to save them. `push_for_backward` returns a handle that you can use as a usual tensor in the gradient code. If your `StmtRange` is inside an outer loop, the handle will always reflect the correct iteration (see the next example). Besides, `push_for_backward` does not mean the value will be physically saved in the tape: it only means the value will be logically reused in the backward function, whether by saving or by recomputing. `push_for_backward` is orthogonal to the `tapes` parameter of `ft.grad`.
Next, we define our custom gradient with an `ft.UserGrad` scope. The scope receives a special parameter `stmt_range`, which should be set to the `StmtRange` we have just defined. Besides `stmt_range`, `UserGrad` receives an arbitrary number of parameters, in this case `x` and `y`, and returns the same number of variables, `dzdx` and `dzdy`, so that we have the mapping between each variable and its gradient. What we are going to do is update `dzdx` from `dzdy`.
Inside the `UserGrad` scope, we write the code of Equation \(\eqref{eq:softmax-grad-1}\eqref{eq:softmax-grad-2}\eqref{eq:softmax-grad-3}\). We want to use the forward values of `y`, `s` and `e`, but do NOT use their names directly; use the `push_for_backward` handles `y_now`, `s_now` and `e_now` instead. Finally, please note that we update `dzdx` with `+=` instead of `=`, because we may only be computing a partial derivative: there may be functions of `x` other than `y`.
And that is all.
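To see why `+=` matters, consider a variable that feeds two computations: its total gradient is the sum of the contributions along both paths. The plain-Python sketch below only mimics the accumulation pattern (it is not FreeTensor's `UserGrad`, and the two gradient functions are hypothetical): for z = Σ_i (sin x_i + x_i²), each gradient block adds its own contribution to a shared `dzdx` buffer, whereas assigning with `=` would silently discard whichever contribution was written first.

```python
import math

def grad_of_sin(x, dzdy, dzdx):
    """Gradient block for y1 = sin(x): accumulates into dzdx in place."""
    for i in range(len(x)):
        dzdx[i] += dzdy[i] * math.cos(x[i])

def grad_of_square(x, dzdy, dzdx):
    """Gradient block for y2 = x ** 2: also accumulates into dzdx."""
    for i in range(len(x)):
        dzdx[i] += dzdy[i] * 2.0 * x[i]

x = [0.5, -1.0, 2.0]
ones = [1.0] * len(x)          # z = sum(y1) + sum(y2), so dz/dy1 = dz/dy2 = 1
dzdx = [0.0] * len(x)          # gradient buffers start at zero
grad_of_sin(x, ones, dzdx)
grad_of_square(x, ones, dzdx)
# Each entry now holds cos(x_i) + 2 * x_i, the sum over both paths
```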
Additional Descriptions on push_for_backward
We have mentioned that `push_for_backward` automatically handles multiple versions of a variable. If you are familiar with PyTorch, you may have noticed that the name is similar to PyTorch's `save_for_backward`. Here, versioning is the major difference: `ft.push_for_backward` can be called multiple times on a variable to save multiple versions (or snapshots) of it, while the variable keeps changing.
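Versioning matters whenever a variable is overwritten. In the plain-Python sketch below (hypothetical names; a conceptual model, not how FreeTensor implements its tape), `acc` is updated in place to compute p = x_0 · x_1 · x_2 · x_3. The backward pass needs the value `acc` had before each update, so the forward loop records one snapshot per iteration, much like calling `push_for_backward` inside a loop. Keeping only a single reference to `acc` would leave the backward pass with nothing but its final value.

```python
def product_forward(x):
    """p = x[0] * x[1] * ...; `acc` is overwritten, so keep a snapshot each step."""
    snapshots = []                  # one version of `acc` per iteration
    acc = 1.0
    for xk in x:
        snapshots.append(acc)       # value of acc *before* this update
        acc = acc * xk
    return acc, snapshots

def product_backward(x, snapshots, dzdp=1.0):
    """Walk the loop in reverse, reading the snapshot of the matching iteration."""
    dzdx = [0.0] * len(x)
    g = dzdp                        # gradient w.r.t. the current version of acc
    for k in reversed(range(len(x))):
        dzdx[k] = g * snapshots[k]  # d(acc_old * x[k]) / d x[k] = acc_old
        g = g * x[k]                # d(acc_old * x[k]) / d acc_old = x[k]
    return dzdx

x = [2.0, 3.0, 5.0, 7.0]
p, snaps = product_forward(x)       # p = 210.0, snaps = [1.0, 2.0, 6.0, 30.0]
dzdx = product_backward(x, snaps)   # p / x[k]: [105.0, 70.0, 42.0, 30.0]
```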
Here is an additional example: a softmax written in loop form, where we receive a 2-D input and apply softmax along the second dimension. Again, this is only for demonstration, and there are multiple ways to implement a softmax.
```python
import freetensor as ft
import torch

n = 4

@ft.optimize  # Set verbose=1 to see the code
@ft.grad(requires=['x'], provides=[ft.Return()], attach_backward=True)
def test(x: ft.Var[(n, n), "float32"]):
    y = ft.empty((n, n), "float32")
    for i in range(n):
        # Mark the range that you want to provide gradient for, with `StmtRange`
        with ft.StmtRange() as rng:
            # `m`, `e` and `s` are local to `i`
            m = ft.reduce_max(x[i], axes=[-1])
            e = ft.exp(x[i] - m)
            s = ft.reduce_sum(e, axes=[-1])
            y[i] = e / s

            # Call `push_for_backward` so we can use forward values in backward
            e_now = ft.push_for_backward(e)
            s_now = ft.push_for_backward(s)
            y_now = ft.push_for_backward(y)
        # Define gradient in `UserGrad`
        with ft.UserGrad(x, y, stmt_range=rng) as (dzdx, dzdy):
            # Retrieve forward value from `y_now`, NOT `y`
            dzds = -ft.reduce_sum(dzdy[i] * y_now[i], axes=[-1]) / s_now
            dzde = dzdy[i] / s_now + dzds
            dzdx[i] += dzde * e_now  # Use `+=` here
    return y

# Check forward result
x = torch.rand(n, n, dtype=torch.float32)
x.requires_grad = True
y_ft = test(x).torch()
y_torch = torch.softmax(x, dim=-1)
assert torch.all(torch.isclose(y_ft, y_torch))

# Check backward result
dzdy = torch.rand(n, n, dtype=torch.float32)
y_torch.backward(dzdy.clone())  # PyTorch writes the reference gradient to x.grad
input_grads = test.input_name_to_gradient_name
output_grads = test.output_name_to_gradient_name
dzdx_ft = test.backward(**{output_grads[ft.Return()]: dzdy}).torch()
dzdx_torch = x.grad
assert torch.all(torch.isclose(dzdx_ft, dzdx_torch, 1e-4, 1e-7))
```
Here our gradient scope is inside a loop, where `m`, `e` and `s` are local to the loop iteration. When we load values from their `push_for_backward` handles, we get the version of each value at the exact iteration we need.