ptrendx opened a new pull request #18622:
URL: https://github.com/apache/incubator-mxnet/pull/18622


   ## Description ##
   As described in 
https://github.com/apache/incubator-mxnet/issues/18280#issuecomment-627010252, 
MXNet currently contains too many CUDA kernels, which negatively affects compile 
time, the size of the resulting binary (leading to issues like #17045 and 
#18205), and GPU memory consumption (since all of those kernels are loaded into 
GPU memory during the first GPU context creation).
   
   The root cause of these problems is the number of templates that need to be 
instantiated, especially for the NumPy operators, which need to accept many 
combinations of input/output types. This results in multiple nested 
`MSHADOW_TYPE_SWITCH` macros and a large increase in the number of generated 
kernels, most of which are never used (illustrated below). For example, 
executing this command:
   ```
   cuobjdump -symbols -arch sm_70 
/usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so | grep GLOBAL | wc -l
   ```
   on the nightly build of mxnet-cu102 from 6/25 shows 69169 kernels (the same 
command executed on the library built with this PR at the time of writing gives 
51511 kernels).
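   
   To illustrate where the blow-up comes from, consider this hypothetical 
kernel (not MXNet's actual operator code) templated on both input and output 
type. With mshadow supporting roughly half a dozen dtypes, two nested 
`MSHADOW_TYPE_SWITCH` macros instantiate one copy per (input, output) pair at 
compile time, whether or not that pair is ever launched:
   ```
   // Hypothetical kernel templated on input and output dtype. Nesting two
   // type switches over N supported dtypes instantiates N*N copies of this
   // kernel ahead of time, even if only a handful are ever used.
   template <typename IType, typename OType>
   __global__ void cast_add_one(const IType* in, OType* out, int n) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < n) out[i] = static_cast<OType>(in[i]) + static_cast<OType>(1);
   }
   ```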
   
   The proposed approach is to use RTC (runtime compilation) to generate the 
needed kernels at runtime. This saves ahead-of-time compilation time and 
binary size, and reduces GPU memory utilization, since only the kernels that 
are actually needed are compiled rather than all possible combinations.
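   
   To make the idea concrete, below is a minimal sketch of runtime compilation 
with NVRTC and the CUDA driver API. This illustrates the technique only, not 
the actual compilation and caching machinery added by this PR; 
`compile_vec_add` and its kernel are hypothetical, error checking is omitted, 
and a current CUDA context is assumed.
   ```
   // Hypothetical sketch: compile a vector-add kernel for one dtype at runtime.
   // Only the requested type combination is ever compiled and loaded.
   #include <cuda.h>
   #include <nvrtc.h>
   #include <string>
   #include <vector>

   CUfunction compile_vec_add(const std::string& dtype) {
     // Assemble the kernel source for the requested dtype.
     std::string src =
         "extern \"C\" __global__ void vec_add(const " + dtype + "* a,\n"
         "                                     const " + dtype + "* b,\n"
         "                                     " + dtype + "* out, int n) {\n"
         "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
         "  if (i < n) out[i] = a[i] + b[i];\n"
         "}\n";
     // Compile the source to PTX with NVRTC.
     nvrtcProgram prog;
     nvrtcCreateProgram(&prog, src.c_str(), "vec_add.cu", 0, nullptr, nullptr);
     const char* opts[] = {"--gpu-architecture=compute_70"};
     nvrtcCompileProgram(prog, 1, opts);
     size_t ptx_size;
     nvrtcGetPTXSize(prog, &ptx_size);
     std::vector<char> ptx(ptx_size);
     nvrtcGetPTX(prog, ptx.data());
     nvrtcDestroyProgram(&prog);
     // Load the PTX and fetch the kernel handle via the driver API
     // (assumes a CUDA context is already current on this thread).
     CUmodule mod;
     CUfunction func;
     cuModuleLoadData(&mod, ptx.data());
     cuModuleGetFunction(&func, mod, "vec_add");
     return func;
   }
   ```
   A handle obtained this way is launched through `cuLaunchKernel`, and the 
generated source lends itself to use as a cache key, which is presumably what 
the cache-lookup item in the TODO list below refers to.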
   
   This PR applies that approach to the elementwise and broadcast kernels (as 
well as their backward passes), which constitute a large portion of the total 
number of kernels in MXNet.
   
   FYI @leezu @sxjscience @eric-haibin-lin 
   
   ## Checklist ##
   ### Essentials ###
   Please feel free to remove inapplicable items for your PR.
   - [ ] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   - [ ] Code is well-documented: 
     - For new C++ functions in header files, their functionality and arguments 
are documented. 
   - [x] To the best of my knowledge, examples are either not affected by this 
change, or have been fixed to be compatible with this change
   
   ### Changes ###
    - RTC is now required for using CUDA in MXNet
    - Unary, binary, binary with scalar, binary broadcast ops and their 
backward counterparts were changed to use RTC 
   
   ## Comments ##
   - Things left to do:
     - [ ] Test the performance impact of the cache lookup for kernel code
     - [ ] Convert `MixedUnaryBackward` functions
     - [ ] Update the PR description with the change in GPU memory utilization 
and binary size resulting from this PR
    - After this PR, the next step would be to use the same approach for the 
reduce kernels. This PR already contains the groundwork for that, since 
reduction was needed for the backward of broadcast ops, but it does not yet 
apply that path to standalone reduction ops. Grepping for `reduce_kernel` in 
the symbols visible in libmxnet.so after applying this PR:
   ```
   cuobjdump -symbols -arch sm_70 
/usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so | grep GLOBAL | grep 
reduce_kernel | wc -l
   ```
   gives 12057 entries. Moving the standalone reductions to RTC would also 
help reduce the code duplication that this PR introduces in order to maintain 
both the RTC and non-RTC paths.
    

