This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch ib/jl-context-num-gpus
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 01de29fa0cca2107b0d92fffb24349f0271af03d
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Sun Sep 22 09:54:22 2019 +0000

    julia: implement `context.num_gpus`
    
    MXNET-1427 #resolve
---
 julia/src/MXNet.jl             |  3 ++-
 julia/src/context.jl           | 11 +++++++++++
 julia/test/unittest/context.jl | 34 ++++++++++++++++++++++++++++++++++
 python/mxnet/context.py        |  2 ++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl
index 89ec88b..5a14639 100644
--- a/julia/src/MXNet.jl
+++ b/julia/src/MXNet.jl
@@ -74,7 +74,8 @@ export Executor,
 # context.jl
 export Context,
        cpu,
-       gpu
+       gpu,
+       num_gpus
 
 # model.jl
 export AbstractModel,
diff --git a/julia/src/context.jl b/julia/src/context.jl
index 71aee30..bce67a5 100644
--- a/julia/src/context.jl
+++ b/julia/src/context.jl
@@ -57,3 +57,14 @@ Get a GPU context with a specific id. The K GPUs on a node is typically numbered
 * `dev_id::Integer = 0` the GPU device id.
 """
 gpu(dev_id::Integer = 0) = Context(GPU, dev_id)
+
+"""
+    num_gpus()
+
+Query CUDA (via the `MXGetGPUCount` C API) for the number of GPUs present; returns a `Cint`.
+"""
+function num_gpus()
+  ndev = Ref{Cint}()
+  @mxcall(:MXGetGPUCount, (Ref{Cint},), ndev)
+  return ndev[]
+end
diff --git a/julia/test/unittest/context.jl b/julia/test/unittest/context.jl
new file mode 100644
index 0000000..0a8f086
--- /dev/null
+++ b/julia/test/unittest/context.jl
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+module TestContext
+
+using MXNet
+using Test
+
+# `num_gpus()` must succeed and report a non-negative device count.
+function test_num_gpus()
+  @info "Context::num_gpus"
+  @test num_gpus() >= 0
+end
+
+@testset "Context Test" begin
+  test_num_gpus()
+end
+
+
+end  # module TestContext
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index f284e00..f2b0137 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -276,6 +276,7 @@ def num_gpus():
     check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
     return count.value
 
+
 def gpu_memory_info(device_id=0):
     """Query CUDA for the free and total bytes of GPU global memory.
 
@@ -300,6 +301,7 @@ def gpu_memory_info(device_id=0):
     check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free), ctypes.byref(total)))
     return (free.value, total.value)
 
+
 def current_context():
     """Returns the current context.
 

Reply via email to