joddiy commented on issue #691: URL: https://github.com/apache/singa/issues/691#issuecomment-629147694
## Conclusion first

Good news:

> ONNX can now define the loss and optimizer within its format. However, the only loss nodes so far are `NegativeLogLikelihoodLoss` and `SoftmaxCrossEntropyLoss`, and the only optimizers it can store are `Adagrad`, `Adam`, and `Momentum` (SGD with standard momentum).

Bad news:

> We need to upgrade ONNX to 1.7, which was released last week and may not be very stable yet. In this release, ONNX defines a complicated node called `GraphCall` to specify which gradients should be computed and how the tensors are updated using those gradients. Since we update the weights right after the backward pass, this part may not be useful for us.

## ONNX Training Preview (TrainingInfoProto)

Last week the ONNX team released a new version, [1.7.0](https://github.com/onnx/onnx/releases/tag/v1.7.0), which upgrades the opset version to 12. This release adds a new feature called [`TrainingInfoProto`](https://github.com/onnx/onnx/blob/3368834cf0b1f0ab9838cf6bdf78a27299d08187/onnx/onnx.proto3#L211-L316), which stores training information. It has two main parts: the `initialization-step` and the `training-algorithm-step`.

### initialization-step

In the `initialization-step`, the developer can define an `initialization`, which is itself a formal ONNX graph. It has no inputs but several outputs. The developer can put nodes such as `RandomNormal` or `RandomUniform` (currently the only supported random methods) into this graph, and then, in another field called `initialization_binding`, assign these outputs to specific tensors in the inference graph.

### training-algorithm-step

The `training-algorithm-step` defines a field called `algorithm`: an inference graph that represents one step of a training algorithm. Given the required inputs, it computes outputs that update tensors, either its own or those of the main computation graph. `update_binding` contains key-value pairs of strings that bind these outputs to specific tensors. In general, this graph contains the loss node, the gradient node, the optimizer node, an increment of the iteration count, and some calls to the inference graph. The field `algorithm.node` is the only place where the `GraphCall` operator may be used.

#### Loss node

- `NegativeLogLikelihoodLoss`
- `SoftmaxCrossEntropyLoss`

#### Optimizer node

- `Adagrad`
- `Adam`
- `Momentum` (SGD with standard momentum)

#### Gradient node

The Gradient node only declares the information necessary to compute the gradients for the whole graph; it does not define any logic for how `dY/dW` and `dY/dZ` are actually computed. In the following graph, for example, the Gradient node's inputs contain the `xs` (intermediate weights) and `zs` (inputs of the graph), together with `y` (the output of the graph), and its outputs are `dY/dW` and `dY/dZ`, whose order corresponds to the names in `xs`.

```
W --> Conv --> H --> Gemm --> Y
|      ^              ^
|      |              |
|      X              Z
|      |              |
|      |   .----------'
|      |   |  (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
|      |   |   "xs" followed by "zs")
|      v   v
'---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
       |   |
       |   '-----------------------------------> dY/dW (1st output of Gradient)
       |
       '---------------------------------------> dY/dZ (2nd output of Gradient)
```
#### GraphCall node

The `GraphCall` operator invokes a graph from inside `TrainingInfoProto`'s `algorithm` field; its inputs and outputs are bound to those of the invoked graph by position. Based on the above inference graph, `GraphCall` can be used like this:

```
.-------- W (a global and mutable variable from
|         |  the inference graph)
|         |
|   .-----'-----------.
|   |                 |
|   |                 v
|   | .-- X_1 --> GraphCall(graph_name="MyInferenceGraph")
|   | |            |  |
|   | |            |  |
|   | |   Z_1 -----'  |
|   | |    |          V
|   | |    |         Y_1 ---> Loss ---> O
|   | |    |                   ^
|   | |    |                   |
|   | `--. |                   C
|   |    | |                   |
|   |    | |  .----------------'
|   |    | |  |
|   |    v v  v
|   `--> Gradient(xs=["W"], zs=["X_1", "Z_1", "C"], y="O")
|        |
|        v
|      dO_dW (gradient of W)   1 (a scalar one)
|        |                     |
|        V                     v
|       Div <--- T ------------> Add ---> T_new
|        |    (T is the number of training iterations.
|        |     T is also globally visible and mutable.)
|        v
`-----> Sub ----> W_new
```

The inference graph from the previous section is invoked by `GraphCall(graph_name="MyInferenceGraph")`, using a new batch of inputs (`X_1`, `Z_1`) to compute `Y_1`. `Gradient` declares which gradients the graph should compute. Finally, the step produces `W_new` and `T_new`, and the following `update_binding` is used to write them back to the tensors:

```
update_binding: {"W": "W_new", "T": "T_new"}
```
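To make the proto plumbing concrete, here is a rough sketch (my illustration, not code from the ONNX docs) of assembling an initialization graph, an algorithm graph, and the two bindings with the onnx 1.7 Python API. The shapes and names are made up, the `Sub` node stands in for a real loss, and a plain learning-rate-scaled SGD update replaces the diagram's `Div`/`Add` bookkeeping on `T`:

```python
import onnx
from onnx import helper, TensorProto

# --- initialization graph: no inputs, outputs bound to state tensors ---
init_graph = helper.make_graph(
    [helper.make_node("RandomNormal", [], ["W_init"],
                      dtype=TensorProto.FLOAT, shape=[4, 4])],
    "init_step", inputs=[],
    outputs=[helper.make_tensor_value_info("W_init", TensorProto.FLOAT, [4, 4])],
)

# --- algorithm graph: one training step ---
# W is not declared as an input here: like in the diagram, it is a global,
# mutable state tensor of the inference graph, visible to this graph.
nodes = [
    # Re-run the main inference graph on a fresh batch; GraphCall's
    # inputs/outputs are bound to the called graph's by position.
    helper.make_node("GraphCall", ["X_1", "Z_1"], ["Y_1"],
                     domain="ai.onnx.preview.training",
                     graph_name="MyInferenceGraph"),
    # Placeholder for the diagram's Loss box; a real model would use
    # e.g. SoftmaxCrossEntropyLoss here.
    helper.make_node("Sub", ["Y_1", "C"], ["O"]),
    # Declare the gradient of O with respect to W.
    helper.make_node("Gradient", ["W", "X_1", "Z_1", "C"], ["dO_dW"],
                     domain="ai.onnx.preview.training",
                     xs=["W"], zs=["X_1", "Z_1", "C"], y="O"),
    # Plain SGD-style update: W_new = W - lr * dO_dW.
    helper.make_node("Mul", ["lr", "dO_dW"], ["step"]),
    helper.make_node("Sub", ["W", "step"], ["W_new"]),
]
algo_graph = helper.make_graph(
    nodes, "training_step",
    inputs=[helper.make_tensor_value_info(n, TensorProto.FLOAT, s)
            for n, s in [("X_1", [1, 4]), ("Z_1", [1, 4]),
                         ("C", [1, 4]), ("lr", [])]],
    outputs=[helper.make_tensor_value_info("W_new", TensorProto.FLOAT, [4, 4])],
)

# --- wire both graphs and their bindings into a TrainingInfoProto ---
info = onnx.TrainingInfoProto()
info.initialization.CopyFrom(init_graph)
info.algorithm.CopyFrom(algo_graph)
info.initialization_binding.add(key="W", value="W_init")
info.update_binding.add(key="W", value="W_new")

# A ModelProto carries a list of these: model.training_info.extend([info])
```

After each evaluation of the algorithm graph, the runtime is expected to copy the output named in `update_binding`'s value (`W_new`) into the state tensor named in its key (`W`), which is what the arrows back into `W` in the diagram express.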
