PyTorch Custom Function Usage

torch.autograd.Function

https://pytorch.org/docs/master/notes/extending.html

from torch.autograd import Function

class Exp(Function):

    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
  • Starting with PyTorch 1.3, both forward() and backward() must be @staticmethod.
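
As the linked docs note, a custom Function is used through its apply method, not by instantiating the class or calling forward() directly. A quick usage sketch, assuming the Exp class above:

import torch

input = torch.randn(3, requires_grad=True)
output = Exp.apply(input)   # do NOT call Exp.forward() or instantiate Exp
output.sum().backward()
print(input.grad)           # equals output, since d/dx exp(x) = exp(x)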

backward(ctx, *grad_outputs)

  • The first argument is always ctx.
  • The number of arguments after ctx equals the number of outputs returned by forward(): backward() receives one gradient per forward() output (see the sketch after this list).
  • Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input.
  • "w.r.t." is short for "with respect to": grad_output is the gradient of the loss with respect to this Function's output.
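
The correspondence above can be checked with a minimal sketch. The class name SplitScale and the scaling factors are illustrative, not from the source: a forward() with two outputs means backward() takes two gradient arguments, and a forward() with one Tensor input means backward() returns one gradient.

import torch
from torch.autograd import Function

class SplitScale(Function):  # hypothetical example class
    @staticmethod
    def forward(ctx, x):
        # Two outputs -> backward() will receive two gradient arguments.
        return 2 * x, 3 * x

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # One gradient argument per forward() output,
        # one returned gradient per forward() input (here: just x).
        return 2 * grad_a + 3 * grad_b

x = torch.randn(4, requires_grad=True)
a, b = SplitScale.apply(x)
(a.sum() + b.sum()).backward()
print(x.grad)  # every element is 2 + 3 = 5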

Example: ReLU, from the PyTorch tutorial "Defining New autograd Functions"
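
That tutorial defines ReLU as a custom Function. A sketch in the same spirit (the class name MyReLU and the exact details are recalled from the tutorial, not taken from this post):

import torch
from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input: backward() needs it to know where ReLU was active.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0  # zero gradient where the input was negative
        return grad_input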

When non-Tensor arguments need to be passed to the function:

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None
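
A usage sketch for MulConstant (the constant 5.0 is arbitrary): the non-Tensor argument is passed positionally to apply, and its gradient slot in backward() is filled with None:

x = torch.randn(3, requires_grad=True)
y = MulConstant.apply(x, 5.0)
y.sum().backward()
print(x.grad)  # every element equals the constant, 5.0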