torch.autograd.Function
https://pytorch.org/docs/master/notes/extending.html
>>> class Exp(Function):
>>>
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
- Starting with PyTorch 1.3, both forward() and backward() must be decorated with @staticmethod.
- backward(ctx, *grad_outputs)
  - The first argument is always ctx.
  - After ctx come as many gradient arguments as forward() returned outputs: each argument is the gradient w.r.t. (with respect to) the corresponding output.
  - backward() must return as many values as forward() took inputs: each returned value is the gradient w.r.t. the corresponding input (see the sketch after this list).
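A sketch to make the input/output correspondence concrete (AddMul is a hypothetical example name, not a PyTorch API): forward() takes two tensors and returns two outputs, so backward() receives two gradient arguments and must return two gradients.

from torch.autograd import Function

class AddMul(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a + b, a * b                    # two outputs

    @staticmethod
    def backward(ctx, grad_sum, grad_prod):    # one gradient argument per output
        a, b = ctx.saved_tensors
        # one returned gradient per input of forward()
        grad_a = grad_sum + grad_prod * b      # d(a+b)/da = 1, d(a*b)/da = b
        grad_b = grad_sum + grad_prod * a      # d(a+b)/db = 1, d(a*b)/db = a
        return grad_a, grad_b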
Example: ReLU, from the PyTorch tutorial "PyTorch: Defining New autograd Functions".
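A sketch along the lines of that tutorial (the class name MyReLU follows the tutorial; details may differ between versions):

from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)       # keep the input for the backward pass
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0          # gradient is 0 where the input was negative
        return grad_input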
When forward() needs to take non-Tensor arguments:
class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None
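A quick check sketch (assuming torch is imported and MulConstant is defined as above): the tensor input receives grad_output * constant, while the non-Tensor constant gets no gradient.

x = torch.randn(3, requires_grad=True)
y = MulConstant.apply(x, 5.0)   # the non-Tensor argument is passed positionally
y.sum().backward()
print(x.grad)                   # all entries are 5.0; the constant itself gets None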