概述
PyTorch的成功归功于其简单易用性(与Python的用法相似)和动态灵活性。即使在PyTorch 2.0时代,它仍然保持着"Faster, more pythonic and dynamic as ever"的核心特性。
PyTorch的动态性源自内部的调度器(dispatcher),它可以根据不同的输入类型自动选择正确的运算方式。当调用Python函数时,调度器会根据传入的参数类型选择正确的操作实现,这个过程称为分派(dispatch)。
例如,当执行矩阵乘法(torch.matmul(a, b))时,调度器会根据输入张量a和b的类型(dtype、shape、device等)选择正确的BLAS库(CPU还是CUDA,float还是half,是否批量计算)来进行计算。对于PyTorch来说,模型的执行过程就是将各个操作(op)分派给本地方法(native function)执行的过程。
dispatcher 为每个 op 都维护了一张跳转表(它有点像 C++ 实现多态用的虚表),如上图所示,表中每个条目存储了一个本地方法,有些方法和输入张量所属的设备有关,比如 XLA/CUDA/CPU
,有的和 requires_grad
有关,比如 Autograd
(这图是从 ezyang’s blog 拿来的,他这篇博客详细讲解了分派机制,建议阅读)。
当 op 被执行时,e.g. aten::addmm
,调度器会在它的跳转表中找出一个方法来执行,而且一个 op 执行过程可能会调用多个方法,例如,输入张量需要求导(requires_grad = true),那会先调用 Autograd 方法来构建反向图,再调用 backend(CPU/CUDA/XLA)的方法来运算。
分派规则
跳转表里的条目是以键值对的形式来存调度方法,其中“键”称为 dispatch key
,以 bit 的形式存在,bit 值越大,优先级越高,调度器会从键集(dispatch key set
)中选取优先级最高的条目来执行。
从上图可以看到,键集不只有一个,每个输入张量都有自己的键集,还有 local(local include
和local exclude
) 和 global 键集,这些键集最终会合并,调度器从中选取优先级最高的键值对应的方法来执行。
输入张量的键集是比较好理解的,张量本身具有很多属性,如 layout (dense or sparse)、shape 和 device (CPU or CUDA),一个属性对应一个 dispatch key(可以从 DispatchKey.h 找到所有的 key)。对于不同类型的张量,我们希望能使用不同实现的操作以实现高性能计算的目标。
Local 键集 与张量个体无关,与模型的行为有关,表示模型运行在某模式中,比如 tracing。它可以允许用户在某个范围内开启或关闭模式。要开启模式就是往 local include 里添加键,要关闭模式就是往 local exclude 里添加要屏蔽的键。
Global 则表示无论什么操作都会添加的键集(图中 autograd 已经从 global 移到 tensor 键集)。
分派流程
前面也提到,一个 op 的执行是要经历多次分派的,上图就展示了这个过程:
- 首先,输入张量需要求导(requires_grad = true),调度器就分派给 Autograd key 的本地方法。它会为 op 生成一个反向计算操作,然后,再把控制权交给调度器做重新分派。
- 接着由于输入张量在CPU上,CPU的方法会被分派执行。
前面提到,调度器会调用优先级最高的 dispatch key,因此,重新分派的前提是将已经调度过的键从键集里清除,否则重新分派将会重复调用相同的方法。
Autograd 的本地方法通过在 local exclude 键集中添加要屏蔽的键(Autograd)来避免方法的重复调用。可以通过创建 AutoNonVariableTypeMode RAII guard 来实现:
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
at::AutoNonVariableTypeMode g;
return myadd(self, other);
}
...
};
注册自定义操作
回想一下分派规则:调度器首先找到 op 对应的跳转表,合并键集,并调用键值最大的条目中的函数。由于 dispatch key 是 PyTorch 固定且不可扩展的,因此注册自定义操作需要注册 op 以及跳转表中键的方法。
注册 op
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
PyTorch 提供 TORCH_LIBRARY
用于将 op(也称作 schema string
或 signature
)注册到一个库里,用户可以在 python 通过 c = torch._ops.myops.myadd(a, b)
调用该 op。
schema 与 TensorFlow 的 op_def
和 ONNX 的 node
一样,都用于描述一个操作,只是由于 PyTorch 是动态图的,schema 不需要也不能承载更多信息。
注册 dispatch function
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
}
注册完 op 后,接着就可以通过 TORCH_LIBRARY_IMPL
注册 dispatch key 对应的方法。上述代码片段通过将 myadd_cuda
注册到键:CUDA。
除了为每个键单独注册一个方法,还可以为所有的键注册一个共同的方法,这类方法称为 catch-all
:
TORCH_LIBRARY(myops, m) {
m.def("myadd", myadd_catchall);
}
此外,还可以为所有 op 的某个键注册一个共同的 fallback
方法:
TORCH_LIBRARY_IMPL(_, XLA, m) {
m.fallback(xla_fallback);
}
除了 dispatch key 具有优先级外,这些方法也有优先级:impl > catch-all > fallback:
END
PyTorch的调度器(dispatcher)和分派机制是其灵活性和高性能计算的关键。调度器根据输入类型自动选择适当的操作实现,通过分派流程将操作分派给本地方法执行。分派规则通过 dispatch key 和 keyset 确定执行方法的优先级。注册自定义操作的过程允许用户扩展PyTorch的功能。了解这些原理有助于深入理解PyTorch的内部工作机制,并为模型开发和优化提供指导。