ZhennanQin commented on a change in pull request #17265: Add bfloat16 
floating-point format support based on AMP 
URL: https://github.com/apache/incubator-mxnet/pull/17265#discussion_r365519649
 
 

 ##########
 File path: python/mxnet/contrib/amp/amp.py
 ##########
 @@ -43,14 +44,17 @@
 from ... import optimizer as opt
 from .loss_scaler import LossScaler
 
+bfloat16 = np.dtype([('bfloat16', np.uint16)])
 
 Review comment:
   This is a good topic, and I'd like to have a discussion about it. Currently, MXNet doesn't have its own type system; it simply uses numpy.dtype. NumPy doesn't natively support bfloat16, so we define bfloat16 as a custom NumPy structured type.
   Pros: compatible with the current design; isinstance(bfloat16, np.dtype) returns True.
   Cons: bfloat16.name doesn't work (it reports the underlying void storage type rather than 'bfloat16'), so we have to use bfloat16.names[0] instead (see the sketch below).
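   A minimal sketch of this trade-off, using plain NumPy only (no MXNet required), illustrating how the structured-dtype approach behaves:

```python
import numpy as np

# The PR defines bfloat16 as a NumPy structured dtype wrapping uint16 storage.
bfloat16 = np.dtype([('bfloat16', np.uint16)])

# Pro: it is a genuine np.dtype instance, so existing dtype checks still pass.
assert isinstance(bfloat16, np.dtype)

# Con: .name describes the void storage, not the logical type...
print(bfloat16.name)      # e.g. 'void16', not 'bfloat16'

# ...so the label has to be recovered from the structured field name instead.
print(bfloat16.names[0])  # 'bfloat16'
```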
   Another solution is to create MXNet's own data type system, as PyTorch and TensorFlow do. That is a big API change, so we'd prefer it to be done as part of the MXNet 2.0 upgrade.
   
   For now, we prefer this approach to enable bfloat16 in MXNet 1.x, and to refactor it in MXNet 2.0.
