grygielski opened a new issue #19361:
URL: https://github.com/apache/incubator-mxnet/issues/19361


   ## Problem statement
   Currently MXNet has no mechanism for handling denormal floating point
values ([wikipedia](https://en.wikipedia.org/wiki/Denormal_number)) in
parameters/inputs/outputs. Such numbers are problematic for computation
because adding or multiplying them requires more CPU instructions than the
same operations on normal floating point numbers. However, they are so close
to zero (e.g. ~1e-30) that most of the time they can be rounded to 0 without
any loss in the model's accuracy.
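   
   For illustration only (not part of the original issue), here is a minimal
C++ sketch that usually makes the slow-down visible on x86 CPUs when DAZ/FTZ
are not set; exact timings depend on the CPU and compiler flags (note that
`-ffast-math` may already enable FTZ/DAZ):
   ```Cpp
    #include <chrono>
    #include <cstdio>
    #include <vector>
    
    // Multiply a buffer by a factor close to 1 so the magnitude is preserved:
    // normal inputs stay normal, denormal inputs stay denormal the whole time.
    static double run(float init, int iters) {
      std::vector<float> v(1 << 16, init);
      auto t0 = std::chrono::steady_clock::now();
      for (int i = 0; i < iters; ++i)
        for (float &x : v) x *= 0.999f;
      auto t1 = std::chrono::steady_clock::now();
      // Print a checksum so the compiler cannot optimize the loop away.
      std::printf("(checksum %g) ", static_cast<double>(v[0]));
      return std::chrono::duration<double>(t1 - t0).count();
    }
    
    int main() {
      std::printf("normal:   %.3f s\n", run(1.0f, 200));
      // 1e-39f is below FLT_MIN (~1.18e-38), i.e. a denormal float32 value.
      std::printf("denormal: %.3f s\n", run(1e-39f, 200));
    }
   ```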
   
   One simple way to do it is to check every parameter of the model against
some small threshold and round all parameters below that threshold to 0. This
adds some overhead to saving/loading parameters and is not perfect, because
denormal values can also be created on inputs/outputs during inference.
   
   A cleaner solution would be to use hardware features of modern CPUs. Since
the SSE2 extension there are CPU flags that handle denormals automatically:
DAZ (denormals-are-zero) and FTZ (flush-to-zero). They can be set inside C++
code using intrinsic instructions, as sketched below.
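   
   For reference, a minimal sketch of setting both flags with the SSE
intrinsics (exposed through `<immintrin.h>`); the MXCSR register is per-thread
state, so a framework would have to do this on every worker thread:
   ```Cpp
    #include <immintrin.h>  // _MM_SET_FLUSH_ZERO_MODE / _MM_SET_DENORMALS_ZERO_MODE
    
    // Enable FTZ (denormal results are flushed to 0) and DAZ (denormal inputs
    // are treated as 0) in the MXCSR register of the calling thread.
    void enable_daz_ftz() {
      _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
      _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
    }
   ```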
   
   An important point is that denormal values are rather rare, since most
modern NN architectures do not operate on values asymptotically close to 0.
However, they can show up in RNN models (because of sigmoid gate activations)
or when using layers like PReLU
(https://github.com/apache/incubator-mxnet/issues/19218).
   
   My question here is: what way of handling such cases does the community
prefer? I would love to hear your suggestions and opinions on the proposed
solutions below.
   
   ## Proposed solutions
   - Simplest one: leave handling denormals as the user's responsibility.
Users can iterate through the parameters themselves or use an external package
for setting the CPU flags, such as https://github.com/chainer/daz.
   **Example** code zeroing denormal values in a PReLU gamma parameter:
   ```Python
    def fix_denorm_params():
        # arg_params is assumed to be the dict of parameter arrays loaded
        # earlier, e.g. via mx.model.load_checkpoint().
        global arg_params
        for key in arg_params.keys():
            if 'prelu' in key:
                gammas = arg_params[key]
                for index, gamma in enumerate(gammas):
                    # round near-zero (denormal) gammas to exactly 0
                    if abs(gamma) < 1e-20:
                        arg_params[key][index] = 0.
   ```
   **Pros:** simple solution, no change in framework behavior
   **Cons:** users may not be aware of the denormal slow-down, requires extra
code or an external library, not user-friendly
   
   - Enable the DAZ and FTZ flags by default and do not expose a Python API
for them. This is the TensorFlow-like solution: TensorFlow enables these two
flags and does not allow the user to change that.
   **Example** code used during execution in TensorFlow:
   ```Cpp
   ScopedFlushDenormal::ScopedFlushDenormal() {
     SetDenormalState(/*flush_zero_mode=*/true, /*denormals_zero_mode=*/true);
   }
   ```
   **Usage:**
   ```Cpp
   EnvThread* CreateThread(std::function<void()> f) {
       return env_->StartThread(thread_options_, name_, [=]() {
         // Set the processor flag to flush denormals to zero.
         port::ScopedFlushDenormal flush;
         // Set the processor rounding mode to ROUND TO NEAREST.
         port::ScopedSetRound round(FE_TONEAREST);
         if (thread_options_.numa_node != port::kNUMANoAffinity) {
           port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
         }
         f();
       });
     }
   ```
   **Pros:** users do not have to worry about denormal cases, no change in the
external API
   **Cons:** it may sometimes lead to wrong results (?), and it cannot be
switched off if needed
   
   - Create a Python API function that enables the DAZ and FTZ flags. This is
the PyTorch-like solution: PyTorch does not handle denormals by default, but
the user can call a Python function to treat denormals as 0s.
   **Example** from the PyTorch documentation:
   https://pytorch.org/docs/stable/generated/torch.set_flush_denormal.html
   **Pros:** users can control the behavior of the framework, simple one-line
API
   **Cons:** users have to be aware of the existence of denormals, additional
functionality in the API
   
   - Combination of the two previous ones: enable DAZ and FTZ by default and
expose a Python API function for disabling them (a rough sketch of such a
toggle is shown below).
   **Pros:** user-friendly, while still allowing the user to control framework
behavior if needed
   **Cons:** the most complex solution in terms of implementation
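   
   A rough sketch of what the backend side of such a toggle could look like;
the function name `SetFlushDenormal` and the Python binding are hypothetical,
not existing MXNet API. Since the MXCSR state is per-thread, a real
implementation would have to apply the setting on every engine worker thread:
   ```Cpp
    #include <immintrin.h>
    
    // Hypothetical backend helper: enable/disable DAZ + FTZ for the calling
    // thread. A Python binding (e.g. a hypothetical mx.set_flush_denormal)
    // could forward to it; with the combined proposal the engine would call
    // SetFlushDenormal(true) on every worker thread at startup.
    void SetFlushDenormal(bool enable) {
      _MM_SET_FLUSH_ZERO_MODE(enable ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF);
      _MM_SET_DENORMALS_ZERO_MODE(enable ? _MM_DENORMALS_ZERO_ON
                                         : _MM_DENORMALS_ZERO_OFF);
    }
   ```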

