This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f84cca2  [SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources
f84cca2 is described below

commit f84cca2d84d67cee2877092d0354cf111c95eb8e
Author: Thomas Graves <tgra...@nvidia.com>
AuthorDate: Thu Jul 11 09:32:58 2019 +0900

    [SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources

    ## What changes were proposed in this pull request?

    Add python api support and JavaSparkContext support for resources(). I needed the JavaSparkContext support for it to properly translate into python with the py4j stuff.

    ## How was this patch tested?

    Unit tests added and manually tested in local cluster mode and on yarn.

    Closes #25087 from tgravescs/SPARK-28234-python.

    Authored-by: Thomas Graves <tgra...@nvidia.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../apache/spark/api/java/JavaSparkContext.scala   |  3 ++
 .../org/apache/spark/api/python/PythonRunner.scala | 10 +++++
 python/pyspark/__init__.py                         |  3 +-
 python/pyspark/context.py                          | 12 ++++++
 python/pyspark/resourceinformation.py              | 43 ++++++++++++++++++++++
 python/pyspark/taskcontext.py                      |  8 ++++
 python/pyspark/tests/test_context.py               | 35 +++++++++++++++++-
 python/pyspark/tests/test_taskcontext.py           | 38 +++++++++++++++++++
 python/pyspark/worker.py                           | 11 ++++++
 9 files changed, 161 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index c5ef190..330c2f6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD}
+import org.apache.spark.resource.ResourceInformation

 /**
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
@@ -114,6 +115,8 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {

   def appName: String = sc.appName

+  def resources: JMap[String, ResourceInformation] = sc.resources.asJava
+
   def jars: util.List[String] = sc.jars.asJava

   def startTime: java.lang.Long = sc.startTime
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 414d208..dc6c596 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -281,6 +281,16 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        val resources = context.resources()
+        dataOut.writeInt(resources.size)
+        resources.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v.name, dataOut)
+          dataOut.writeInt(v.addresses.size)
+          v.addresses.foreach { case addr =>
+            PythonRDD.writeUTF(addr, dataOut)
+          }
+        }
         val localProps = context.getLocalProperties.asScala
         dataOut.writeInt(localProps.size)
         localProps.foreach { case (k, v) =>
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index ee153af..70c0b27 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -54,6 +54,7 @@ from pyspark.files import SparkFiles
 from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.serializers import MarshalSerializer, PickleSerializer
 from pyspark.status import *
 from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo
@@ -118,5 +119,5 @@ __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
     "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
-    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo",
+    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "ResourceInformation",
 ]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 69020e6..8d28488 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -37,6 +37,7 @@ from pyspark.java_gateway import launch_gateway, local_connect_and_auth
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
@@ -1107,6 +1108,17 @@ class SparkContext(object):
         conf.setAll(self._conf.getAll())
         return conf

+    @property
+    def resources(self):
+        resources = {}
+        jresources = self._jsc.resources()
+        for x in jresources:
+            name = jresources[x].name()
+            jaddresses = jresources[x].addresses()
+            addrs = [addr for addr in jaddresses]
+            resources[name] = ResourceInformation(name, addrs)
+        return resources
+

 def _test():
     import atexit
diff --git a/python/pyspark/resourceinformation.py b/python/pyspark/resourceinformation.py
new file mode 100644
index 0000000..aaed213
--- /dev/null
+++ b/python/pyspark/resourceinformation.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class ResourceInformation(object):
+
+    """
+    .. note:: Evolving
+
+    Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc.
+    The array of addresses are resource specific and its up to the user to interpret the address.
+
+    One example is GPUs, where the addresses would be the indices of the GPUs
+
+    @param name the name of the resource
+    @param addresses an array of strings describing the addresses of the resource
+    """
+
+    def __init__(self, name, addresses):
+        self._name = name
+        self._addresses = addresses
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def addresses(self):
+        return self._addresses
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 6d28491..790de0b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -38,6 +38,7 @@ class TaskContext(object):
     _stageId = None
     _taskAttemptId = None
     _localProperties = None
+    _resources = None

     def __new__(cls):
         """Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -95,6 +96,13 @@ class TaskContext(object):
         """
         return self._localProperties.get(key, None)

+    def resources(self):
+        """
+        Resources allocated to the task. The key is the resource name and the value is information
+        about the resource.
+        """
+        return self._resources
+

 BARRIER_FUNCTION = 1
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 4048ac5..bcd5d06 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -16,13 +16,14 @@
 #
 import os
 import shutil
+import stat
 import tempfile
 import threading
 import time
 import unittest
 from collections import namedtuple

-from pyspark import SparkFiles, SparkContext
+from pyspark import SparkConf, SparkFiles, SparkContext
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME


@@ -256,6 +257,38 @@ class ContextTests(unittest.TestCase):
                 SparkContext(gateway=mock_insecure_gateway)
             self.assertIn("insecure Py4j gateway", str(context.exception))

+    def test_resources(self):
+        """Test the resources are empty by default."""
+        with SparkContext() as sc:
+            resources = sc.resources
+            self.assertEqual(len(resources), 0)
+
+
+class ContextTestsWithResources(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
+        self.tempFile.close()
+        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
+                 stat.S_IROTH | stat.S_IXOTH)
+        conf = SparkConf().set("spark.driver.resource.gpu.amount", "1")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+
+    def test_resources(self):
+        """Test the resources are available."""
+        resources = self.sc.resources
+        self.assertEqual(len(resources), 1)
+        self.assertTrue('gpu' in resources)
+        self.assertEqual(resources['gpu'].name, 'gpu')
+        self.assertEqual(resources['gpu'].addresses, ['0'])
+
+    def tearDown(self):
+        os.unlink(self.tempFile.name)
+        self.sc.stop()
+

 if __name__ == "__main__":
     from pyspark.tests.test_context import *
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index d7d1d80..66357b6 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -16,7 +16,9 @@
 #
 import os
 import random
+import stat
 import sys
+import tempfile
 import time
 import unittest

@@ -43,6 +45,15 @@ class TaskContextTests(PySparkTestCase):
         self.assertEqual(stage1 + 2, stage3)
         self.assertEqual(stage2 + 1, stage3)

+    def test_resources(self):
"""Test the resources are empty by default.""" + rdd = self.sc.parallelize(range(10)) + resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0] + # Test using the constructor directly rather than the get() + resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0] + self.assertEqual(len(resources1), 0) + self.assertEqual(len(resources2), 0) + def test_partition_id(self): """Test the partition id.""" rdd1 = self.sc.parallelize(range(10), 1) @@ -174,6 +185,33 @@ class TaskContextTestsWithWorkerReuse(unittest.TestCase): self.sc.stop() +class TaskContextTestsWithResources(unittest.TestCase): + + def setUp(self): + class_name = self.__class__.__name__ + self.tempFile = tempfile.NamedTemporaryFile(delete=False) + self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}') + self.tempFile.close() + os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | + stat.S_IROTH | stat.S_IXOTH) + conf = SparkConf().set("spark.task.resource.gpu.amount", "1") + conf = conf.set("spark.executor.resource.gpu.amount", "1") + conf = conf.set("spark.executor.resource.gpu.discoveryScript", self.tempFile.name) + self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf) + + def test_resources(self): + """Test the resources are available.""" + rdd = self.sc.parallelize(range(10)) + resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0] + self.assertEqual(len(resources), 1) + self.assertTrue('gpu' in resources) + self.assertEqual(resources['gpu'].name, 'gpu') + self.assertEqual(resources['gpu'].addresses, ['0']) + + def tearDown(self): + os.unlink(self.tempFile.name) + self.sc.stop() + if __name__ == "__main__": import unittest from pyspark.tests.test_taskcontext import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b34abd0..7f38c27 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -35,6 +35,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles +from pyspark.resourceinformation import ResourceInformation from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ @@ -435,6 +436,16 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._resources = {} + for r in range(read_int(infile)): + key = utf8_deserializer.loads(infile) + name = utf8_deserializer.loads(infile) + addresses = [] + taskContext._resources = {} + for a in range(read_int(infile)): + addresses.append(utf8_deserializer.loads(infile)) + taskContext._resources[key] = ResourceInformation(name, addresses) + taskContext._localProperties = dict() for i in range(read_int(infile)): k = utf8_deserializer.loads(infile) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org