Per Op Gradient Building
To build a gradient for an Op in onnxruntime, we need to define a subgraph: a list of Ops that operate on the forward op's inputs and outputs and on the gradients of its outputs. The set of Ops available in ONNX can be found in Operators.md. In many cases the gradient can be composed from existing ops, but some cases require defining a dedicated gradient op from scratch, e.g. ConvGrad. The gradient registry and gradient builder provide the functions and classes used to build the gradient subgraph.
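As a small illustration of composing a gradient from existing ops, consider elementwise Mul, Y = A * B. Its gradients are themselves Mul ops: dA = dY * B and dB = dY * A. The standalone C++ sketch below computes that composition on plain vectors. It uses no onnxruntime types, and all function names in it are hypothetical. A central finite-difference helper is included so the composed gradient can be checked numerically.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward op: elementwise Mul, Y = A * B (same shapes, no broadcasting in this sketch).
std::vector<double> Mul(const std::vector<double>& a, const std::vector<double>& b) {
  std::vector<double> y(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) y[i] = a[i] * b[i];
  return y;
}

// Gradients of both forward inputs.
struct MulGrads {
  std::vector<double> dA, dB;
};

// Hypothetical "gradient subgraph" built only from the existing Mul op:
//   dA = dY * B,  dB = dY * A
MulGrads MulGradientSubgraph(const std::vector<double>& a, const std::vector<double>& b,
                             const std::vector<double>& dy) {
  return {Mul(dy, b), Mul(dy, a)};
}

// Central finite difference of L = sum(dY * Mul(A, B)) with respect to a[i].
// dL/da[i] equals dA[i], so this checks the composed gradient for A.
double NumericGradA(std::vector<double> a, const std::vector<double>& b,
                    const std::vector<double>& dy, std::size_t i, double eps = 1e-6) {
  auto loss = [&](const std::vector<double>& x) {
    const std::vector<double> y = Mul(x, b);
    double s = 0.0;
    for (std::size_t k = 0; k < y.size(); ++k) s += dy[k] * y[k];
    return s;
  };
  const double orig = a[i];
  a[i] = orig + eps;
  const double up = loss(a);
  a[i] = orig - eps;
  const double down = loss(a);
  return (up - down) / (2 * eps);
}
```

The same finite-difference check can be applied to any hand-composed gradient before wiring it into a builder.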
Understand the Gradient Registry (gradient_builder_registry.cc)
The gradient registry contains all the gradients supported by onnxruntime. To register a gradient, add the following inside the RegisterGradientBuilders() function:
REGISTER_GRADIENT_BUILDER(<Op Name>, <GradientFuncName>);
For example: REGISTER_GRADIENT_BUILDER("Tile", GetTileGradient);
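To make the registry idea concrete, here is a minimal, self-contained C++ sketch of an op-name-to-builder lookup table. It is not the onnxruntime implementation. Registry(), FindBuilder(), SKETCH_REGISTER_GRADIENT_BUILDER and the stripped-down GradientBuilderBase are hypothetical stand-ins that only mirror the shape of the real registration call.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical, heavily simplified stand-in for GradientBuilderBase.
struct GradientBuilderBase {
  virtual ~GradientBuilderBase() = default;
  virtual std::string Name() const = 0;
};

// Hypothetical builder standing in for the real GetTileGradient class.
struct GetTileGradient : GradientBuilderBase {
  std::string Name() const override { return "GetTileGradient"; }
};

using BuilderFactory = std::function<std::unique_ptr<GradientBuilderBase>()>;

// Maps a forward op type (e.g. "Tile") to a factory that creates its gradient builder.
std::map<std::string, BuilderFactory>& Registry() {
  static std::map<std::string, BuilderFactory> registry;
  return registry;
}

// Hypothetical macro with the same argument shape as REGISTER_GRADIENT_BUILDER(op, builder).
#define SKETCH_REGISTER_GRADIENT_BUILDER(op, builder) \
  Registry()[op] = [] { return std::unique_ptr<GradientBuilderBase>(new builder()); }

void RegisterGradientBuilders() {
  SKETCH_REGISTER_GRADIENT_BUILDER("Tile", GetTileGradient);
}

// Returns nullptr when no gradient is registered for the op type.
std::unique_ptr<GradientBuilderBase> FindBuilder(const std::string& op_type) {
  auto it = Registry().find(op_type);
  if (it == Registry().end()) return nullptr;
  return it->second();
}
```

Calling RegisterGradientBuilders() once fills the table. After that, FindBuilder("Tile") returns a GetTileGradient builder, while an op with no registered gradient yields nullptr.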
Understand the Gradient Builder Declaration (gradient_builder.h)
DECLARE_GRADIENT_BUILDER() creates a class called <GradientFuncName> that inherits from the GradientBuilderBase class, which provides helper functions for implementing the gradient subgraph in onnxruntime. To define the gradient class for an Op, add the following to gradient_builder.h:
DECLARE_GRADIENT_BUILDER(<GradientFuncName>);
For example: DECLARE_GRADIENT_BUILDER(GetTileGradient);
Once the gradient is declared, its subgraph is implemented with IMPLEMENT_GRADIENT_BUILDER() in gradient_builder.cc. The implemented function returns a vector of nodes (NodeDef) that make up the gradient subgraph.
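The declare/implement pair can be pictured with the simplified sketch below. The SKETCH_ macros, ArgDef, NodeDef and GradientBuilderBase are hypothetical minimal stand-ins, not the real definitions in gradient_builder.h or gradient_builder_base.h. The Sqrt gradient uses the identity d sqrt(x)/dx = 0.5 / sqrt(x), so dX = dY * 0.5 / Y. This is one valid composition and may differ from the exact node list of onnxruntime's GetSqrtGradient. Tensor names are hard-coded strings here, where real code would use GO(0), O(0), IA() and GI(0).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins; not the real onnxruntime types.
struct ArgDef {
  std::string name;
};
struct NodeDef {
  std::string op_type;
  std::vector<ArgDef> inputs, outputs;
};

struct GradientBuilderBase {
  virtual ~GradientBuilderBase() = default;
  std::vector<NodeDef> GetGradientDefs() const { return GetGradientDefsImpl(); }

 protected:
  virtual std::vector<NodeDef> GetGradientDefsImpl() const = 0;
};

// Hypothetical macros with the same roles as DECLARE_/IMPLEMENT_GRADIENT_BUILDER.
#define SKETCH_DECLARE_GRADIENT_BUILDER(name)                  \
  class name : public GradientBuilderBase {                    \
   protected:                                                  \
    std::vector<NodeDef> GetGradientDefsImpl() const override; \
  };
#define SKETCH_IMPLEMENT_GRADIENT_BUILDER(name) \
  std::vector<NodeDef> name::GetGradientDefsImpl() const

SKETCH_DECLARE_GRADIENT_BUILDER(GetSqrtGradient)

// Y = Sqrt(X)  =>  dX = (dY / Y) * 0.5, composed from existing Constant, Div and Mul ops.
SKETCH_IMPLEMENT_GRADIENT_BUILDER(GetSqrtGradient) {
  return {
      NodeDef{"Constant", {}, {{"half"}}},  // value = 0.5; the attribute is omitted in this sketch
      NodeDef{"Div", {{"Y_grad"}, {"Y"}}, {{"tmp"}}},
      NodeDef{"Mul", {{"tmp"}, {"half"}}, {{"X_grad"}}},
  };
}
```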
Understand the shorthands of I, GI, O, GO (gradient_builder_base.h)
The GradientBuilderBase class provides some basic helper functions for writing gradients for Ops:
I() - i-th input of forward op
O() - i-th output of forward op
GI() - gradient of i-th input of forward op
GO() - gradient of i-th output of forward op
IA() - an intermediate argument, i.e. a tensor produced and consumed inside the gradient subgraph
ConstantVectorNode() - creates a Constant node from a vector of values
ConstantScalarNode() - creates a Constant node from a single scalar value
For more functions, refer to gradient_builder_base.h.
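A rough model of what these shorthands refer to is sketched below. It assumes that the gradient of a tensor named t is named "t_grad" and that intermediate args get a per-builder unique prefix. ShorthandSketch, ForwardNode and this GradientName are hypothetical, and the real helpers return ArgDef objects rather than plain strings.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical mock of a forward node: only its input and output tensor names.
struct ForwardNode {
  std::vector<std::string> inputs, outputs;
};

// Assumed naming convention for this sketch: the gradient of tensor "t" is "t_grad".
std::string GradientName(const std::string& name) { return name + "_grad"; }

// Hypothetical, string-based equivalents of the I / O / GI / GO / IA shorthands.
struct ShorthandSketch {
  ForwardNode node;
  std::string unique_prefix;  // keeps intermediate names unique per builder instance

  std::string I(std::size_t i) const { return node.inputs.at(i); }    // i-th forward input
  std::string O(std::size_t i) const { return node.outputs.at(i); }   // i-th forward output
  std::string GI(std::size_t i) const { return GradientName(I(i)); }  // gradient of i-th input
  std::string GO(std::size_t i) const { return GradientName(O(i)); }  // gradient of i-th output
  std::string IA(const std::string& name) const { return unique_prefix + "_" + name; }
};
```

For a forward Mul with inputs A and B and output Y, GO(0) names the incoming gradient "Y_grad" and GI(0) names the produced gradient "A_grad".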
It also helps to look into the data types available for declaring graph variables, such as ArgDef, NodeDef and OpDef, and how they are used to build a graph in onnxruntime.
The NodeDef struct defines a node in the subgraph. To build a node, we specify the Op to use (see Operators.md), its input and output args of type ArgDef, and optional attributes of type AttributeProto.
The ArgDef struct defines the arguments (inputs and outputs) used by the ONNX Ops.
The AttributeProto type defines the constant attributes required by a node, which do not change at runtime.
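The following sketch builds one NodeDef that carries an attribute: the gradient of Transpose(X, perm) is a Transpose of dY with the inverse permutation. Attribute, ArgDef and NodeDef here are hypothetical simplified structs (the real code uses onnx::AttributeProto, a protobuf message), and the names X_grad and Y_grad stand in for GI(0) and GO(0).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins; the real code uses onnx::AttributeProto and onnxruntime's ArgDef/NodeDef.
struct Attribute {
  std::string name;
  std::vector<std::int64_t> ints;  // e.g. the "perm" attribute of Transpose
};
struct ArgDef {
  std::string name;
};
struct NodeDef {
  std::string op_type;
  std::vector<ArgDef> inputs, outputs;
  std::vector<Attribute> attributes;  // optional; constant for the lifetime of the graph
};

// If Y = Transpose(X, perm), then Y's axis i is X's axis perm[i].
// Undoing that mapping needs inv with inv[perm[i]] = i.
std::vector<std::int64_t> InversePerm(const std::vector<std::int64_t>& perm) {
  std::vector<std::int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i) {
    inv[static_cast<std::size_t>(perm[i])] = static_cast<std::int64_t>(i);
  }
  return inv;
}

// dX = Transpose(dY, inverse(perm)): an existing op plus a constant attribute.
NodeDef TransposeGradNode(const std::vector<std::int64_t>& perm) {
  return NodeDef{"Transpose", {{"Y_grad"}}, {{"X_grad"}}, {Attribute{"perm", InversePerm(perm)}}};
}
```

For perm = {2, 0, 1}, the gradient node carries perm = {1, 2, 0}.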
The GetShape function returns the shape of an input. This shape can be static or symbolic. A symbolic shape can change at runtime, so in that case the gradient subgraph must call the Shape Op at runtime to obtain it.
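The choice between a static and a symbolic shape can be sketched as follows. Dim and ShapeSourceOp are hypothetical. The field names mirror the dim_value and dim_param fields of an ONNX TensorShapeProto dimension. The sketch only shows the rule: a fully static shape can be baked into a Constant node, while any symbolic dimension forces a runtime Shape op.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of one dimension, mirroring ONNX's dim_value / dim_param split.
struct Dim {
  std::int64_t dim_value;  // meaningful only when dim_param is empty
  std::string dim_param;   // non-empty for a symbolic dimension, e.g. "batch"
};

// Returns the op a gradient builder could use to obtain this shape:
// "Constant" if every dimension is static, "Shape" if any dimension is symbolic.
std::string ShapeSourceOp(const std::vector<Dim>& shape) {
  for (const Dim& d : shape) {
    if (!d.dim_param.empty()) return "Shape";  // value only known at runtime
  }
  return "Constant";  // shape can be embedded in the graph
}
```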
(Refer to graph_augmenter.h for more helper functions to build sub-graphs.)
Tips for writing your own gradient for an Op
- Read a few examples in the Gradient Builder Implementation (gradient_builder.cc). You can start with this PR, which implements the gradient for Tile. Go through how the gradient is declared and implemented.
- Understand how the gradient subgraph is composed from existing ops. The following are good examples:
  - Easy: GetDropoutGradient, GetSqrtGradient
  - Medium: GetAddSubGradient, GetMulGradient
  - Hard: GetMatMulGradient, GetGemmGradient
- Understand how broadcasting is handled when building the gradient graph: GradientBuilderBase::HandleBroadcasting()
- Action: Implement a gradient definition for an op to get hands-on experience.
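The core of broadcast handling is working out which axes of dY must be summed to recover an input's gradient. The sketch below computes those axes for static shapes under numpy-style broadcasting, which ONNX follows. BroadcastReduceAxes is a hypothetical helper and is not the implementation of GradientBuilderBase::HandleBroadcasting(), which also has to deal with symbolic shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Given a forward input's static shape and the broadcast output shape, return the output
// axes that must be summed so that dY can be reshaped back to the input's shape.
// Precondition (assumed, not checked): shapes are broadcast-compatible and
// in_shape.size() <= out_shape.size().
std::vector<std::int64_t> BroadcastReduceAxes(const std::vector<std::int64_t>& in_shape,
                                              const std::vector<std::int64_t>& out_shape) {
  std::vector<std::int64_t> axes;
  // Shapes are right-aligned, so the input lacks the first `offset` output axes.
  const std::size_t offset = out_shape.size() - in_shape.size();
  for (std::size_t i = 0; i < out_shape.size(); ++i) {
    if (i < offset) {
      axes.push_back(static_cast<std::int64_t>(i));  // axis implicitly prepended as size 1
    } else if (in_shape[i - offset] == 1 && out_shape[i] != 1) {
      axes.push_back(static_cast<std::int64_t>(i));  // size-1 axis was stretched
    }
  }
  return axes;
}
```

For example, an input of shape {4, 1} broadcast to {2, 4, 3} gives axes {0, 2}. A ReduceSum over those axes with keepdims=1 yields {1, 4, 1}, which a Reshape turns back into {4, 1}.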
Please use the learning roadmap on the home wiki page to build a general understanding of ORT.