Github user mengxr commented on a diff in the pull request: https://github.com/apache/spark/pull/22085#discussion_r209846015 --- Diff: python/pyspark/taskcontext.py --- @@ -95,3 +95,92 @@ def getLocalProperty(self, key): Get a local property set upstream in the driver, or None if it is missing. """ return self._localProperties.get(key, None) + + +class BarrierTaskContext(TaskContext): + + """ + .. note:: Experimental + + A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext + for a running task, use: + L{BarrierTaskContext.get()}. + + .. versionadded:: 2.4.0 + """ + + _barrierContext = None + + def __init__(self): + """Construct a BarrierTaskContext, use get instead""" + pass + + @classmethod + def _getOrCreate(cls): + """Internal function to get or create global BarrierTaskContext.""" + if cls._taskContext is None: + cls._taskContext = BarrierTaskContext() + return cls._taskContext + + @classmethod + def get(cls): + """ + Return the currently active BarrierTaskContext. This can be called inside of user functions + to access contextual information about running tasks. + + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. + """ + return cls._taskContext + + @classmethod + def _initialize(cls, ctx): + """ + Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called + after BarrierTaskContext is initialized. + """ + cls._barrierContext = ctx + + def barrier(self): + """ + .. note:: Experimental + + Sets a global barrier and waits until all tasks in this stage hit this barrier. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._barrierContext is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + self._barrierContext.barrier() + + def getTaskInfos(self): + """ + .. 
note:: Experimental + + Returns the all task infos in this barrier stage, the task infos are ordered by + partitionId. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._barrierContext is None: + raise Exception("Not supported to call getTaskInfos() before initialize " + + "BarrierTaskContext.") + else: + java_list = self._barrierContext.getTaskInfos() + return [BarrierTaskInfo(h) for h in java_list] + + +class BarrierTaskInfo(object): + """ + .. note:: Experimental + + Carries all task infos of a barrier task. + + .. versionadded:: 2.4.0 + """ + + def __init__(self, info): + self.address = info.address --- End diff --

* should be `info.address`
* better to rename `info` to `jobj` to make it clear this is from Java
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org