Just to follow up on what @tqchen summarized previously, here's my
understanding:
### frontend converters
We want users who write frontend converters to be aware that certain operators
are stateful. We can encourage them to write these operators in A1 style. For
instance:
```
def _mx_dropout_train(inputs, attrs, module):
    rate = attrs.get_float("p", 0.5)
    global_state = module['prng']
    state_ref = relay.RefCreate(global_state)
    read_state = relay.RefRead(state_ref)
    # the dropout_train operator outputs both y and the new state
    y_state = _op.nn.dropout_train(inputs[0], read_state, rate=rate)
    # write back new state, return y
    write_state = relay.RefWrite(state_ref, y_state[1])
    y = relay.Let(relay.var('ref_write'), write_state, y_state[0])
    return y
```
where `module['prng']` is a global variable representing the PRNG state in the
module. Currently, global variables are only used to represent functions, so we
need to extend them to represent the random state, too.
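
To make the A1 read/compute/write-back pattern concrete, here is a minimal plain-Python sketch. Everything in it is an assumption for illustration: `Ref` only mimics what `RefCreate`/`RefRead`/`RefWrite` do, `dropout_train` is a toy stand-in for the proposed op, the state is just an integer fed to a small LCG, and the `module` dict models the global as the mutable cell itself.

```python
def lcg_next(state):
    # One step of a 32-bit linear congruential generator (toy PRNG).
    return (1664525 * state + 1013904223) % 2**32


def dropout_train(x, state, rate):
    """Toy stand-in for the proposed op: returns (y, new_state)."""
    keep = 1.0 - rate
    y = []
    for v in x:
        state = lcg_next(state)
        u = state / 2**32  # uniform sample in [0, 1)
        y.append(v / keep if u >= rate else 0.0)
    return y, state


class Ref:
    """Mutable cell playing the role of RefCreate / RefRead / RefWrite."""

    def __init__(self, value):
        self.value = value

    def read(self):
        return self.value

    def write(self, value):
        self.value = value


# Hypothetical module: the 'prng' entry is a global holding the random state.
module = {"prng": Ref(42)}


def func1(x):
    state_ref = module["prng"]
    read_state = state_ref.read()
    # the op outputs both y and the new state
    y, new_state = dropout_train(x, read_state, rate=0.7)
    # write back new state, return y
    state_ref.write(new_state)
    return y
```

Calling `func1` twice advances the hidden state between calls, so the result depends on what ran before; resetting `module["prng"]` to the same value reproduces the first output.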
### rewriting A1-style programs to A2-style ones
Let's say we have the function below, which contains a stateful op:
```
def @func1(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  let %ref_write: () = (%0 := %3);
  %2.0
}
```
In the rewriting pass, we detect that the global random state is used, and
replace the references to it with an explicit state parameter, giving the
following:
```
def @func1_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  (%2.0, %2.1)
}
```
Note that the function output type is changed to a tuple containing the new
state. All CallNodes that call this function must be updated accordingly.
Here is another example:
```
def @long_func(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  %4 = (
    let %ref_write1: () = (%0 := %3);
    %2.0
  );
  %5 = %0^;
  %6 = nn.dropout_train(%4, %5, rate=0.1f);
  %7 = %6.1;
  let %ref_write: () = (%0 := %7);
  %6.0
}
===>
def @long_func_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  %3 = %2.1;
  %4 = %2.0;
  %6 = nn.dropout_train(%4, %3, rate=0.1f);
  (%6.0, %6.1)
}
```
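
For readers who prefer Python, the A2-style result of the two rewrites above can be sketched roughly as follows. This is only an illustration under assumptions: `dropout_train` is a toy stand-in for the proposed op, the state is an integer driving a small LCG, and `caller` is a hypothetical call site showing how CallNodes would change once the callee returns `(output, state)`.

```python
def lcg_next(state):
    # One step of a 32-bit linear congruential generator (toy PRNG).
    return (1664525 * state + 1013904223) % 2**32


def dropout_train(x, state, rate):
    """Toy stand-in for the proposed op: returns (y, new_state)."""
    keep = 1.0 - rate
    y = []
    for v in x:
        state = lcg_next(state)
        y.append(v / keep if state / 2**32 >= rate else 0.0)
    return y, state


def func1_rewritten(x, state):
    # No global is touched: the state comes in and goes out explicitly.
    y, new_state = dropout_train(x, state, rate=0.7)
    return y, new_state


def long_func_rewritten(x, state):
    # The state produced by the first op is what the second op consumes.
    y1, s1 = dropout_train(x, state, rate=0.7)
    y2, s2 = dropout_train(y1, s1, rate=0.1)
    return y2, s2


def caller(x, state):
    # Hypothetical updated call site: unpack the tuple and keep threading
    # the latest state into the next stateful call.
    y, state = long_func_rewritten(x, state)
    z, state = func1_rewritten(y, state)
    return z, state
```

Since every function here is pure, calling it with the same input and the same state always gives the same `(output, state)` pair.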
Note that the pass implementation requires tracking the latest value of the
global variable within each scope. For instance, the program below:
```
def @func2(%x, %y) { # returns tensor
  if (%x) {
    add(%x, %y)
  } else {
    @func1(%y)
  }
}
```
would be rewritten to:
```
def @func2(%x, %y, %state) { # returns (tensor, state) for both branches
  if (%x) {
    (add(%x, %y), %state) # the original state is also returned
  } else {
    @func1_rewritten(%y, %state) # returns the new state
  }
}
```
Since the pass needs to track the state within each scope, it would be easier
to implement once the program has already been transformed to basic block
(bblock) form.
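
A rough sketch of what "tracking the latest state per scope" could look like, written against a made-up mini-IR of nested Python tuples rather than Relay's actual data structures. The node encodings (`"var"`, `"call"`, `"if"`, `"let"`, `"tuple"`, `"proj"`), the `STATEFUL` set and the `_rewritten` naming are all assumptions for illustration. It only handles expressions in tail position, and assumes that conditions and call arguments are pure and that let-bound names are unique.

```python
# Hypothetical: names of functions known to use the global PRNG state.
STATEFUL = {"func1"}


def is_stateful_call(e):
    return e[0] == "call" and e[1] in STATEFUL


def thread_call(e, state):
    """Pass the current state variable to a stateful call as an extra arg."""
    _, fn, args = e
    return ("call", fn + "_rewritten", args + [("var", state)])


def rewrite(e, state):
    """Rewrite `e` so that it evaluates to a (value, state) pair.

    `state` names the variable holding the latest PRNG state in this scope.
    """
    kind = e[0]
    if kind == "if":
        _, cond, then_e, else_e = e
        # Each branch is its own scope and starts from the same state.
        return ("if", cond, rewrite(then_e, state), rewrite(else_e, state))
    if kind == "let":
        _, name, value, body = e
        if is_stateful_call(value):
            pair, new_state = name + "_pair", name + "_state"
            # Bind the pair, split it, and continue with the new state.
            return ("let", pair, thread_call(value, state),
                    ("let", name, ("proj", ("var", pair), 0),
                     ("let", new_state, ("proj", ("var", pair), 1),
                      rewrite(body, new_state))))
        return ("let", name, value, rewrite(body, state))
    if is_stateful_call(e):
        return thread_call(e, state)  # already returns (value, state)
    # Pure expression: return it together with the unchanged state.
    return ("tuple", [e, ("var", state)])


# Body of @func2 from the example above.
func2_body = ("if", ("var", "x"),
              ("call", "add", [("var", "x"), ("var", "y")]),
              ("call", "func1", [("var", "y")]))
func2_rewritten_body = rewrite(func2_body, "state")
```

In a let chain, the state variable passed to the body is the one produced by the most recent stateful call, which is the bookkeeping the pass has to do in each scope.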
### discussions
What type should we use for the random state?
- option a: use the empty tuple type. The runtime actually uses the global
state, and it relies on the deterministic execution order of the program to
ensure reproducibility.
- option b: add a new type (e.g. TypeRandState), and the random state Object
actually carries the data structure used for generating random numbers (e.g.
`std::mt19937`). The state is passed around in the program, and invoking an
operator with the same state object always leads to the same deterministic
outputs.
@junrushao1994 @haichen @MarisaKirisame @ziheng would you like to provide some
suggestions/comments?