Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22085#discussion_r209846054
  
    --- Diff: python/pyspark/taskcontext.py ---
    @@ -95,3 +95,92 @@ def getLocalProperty(self, key):
             Get a local property set upstream in the driver, or None if it is 
missing.
             """
             return self._localProperties.get(key, None)
    +
    +
    +class BarrierTaskContext(TaskContext):
    +
    +    """
    +    .. note:: Experimental
    +
    +    A TaskContext with extra info and tooling for a barrier stage. To 
access the BarrierTaskContext
    +    for a running task, use:
    +    L{BarrierTaskContext.get()}.
    +
    +    .. versionadded:: 2.4.0
    +    """
    +
    +    _barrierContext = None
    +
    +    def __init__(self):
    +        """Construct a BarrierTaskContext, use get instead"""
    +        pass
    +
    +    @classmethod
    +    def _getOrCreate(cls):
    +        """Internal function to get or create global BarrierTaskContext."""
    +        if cls._taskContext is None:
    +            cls._taskContext = BarrierTaskContext()
    +        return cls._taskContext
    +
    +    @classmethod
    +    def get(cls):
    +        """
    +        Return the currently active BarrierTaskContext. This can be called 
inside of user functions
    +        to access contextual information about running tasks.
    +
    +        .. note:: Must be called on the worker, not the driver. Returns 
None if not initialized.
    +        """
    +        return cls._taskContext
    +
    +    @classmethod
    +    def _initialize(cls, ctx):
    +        """
    +        Initialize BarrierTaskContext, other methods within 
BarrierTaskContext can only be called
    +        after BarrierTaskContext is initialized.
    +        """
    +        cls._barrierContext = ctx
    +
    +    def barrier(self):
    +        """
    +        .. note:: Experimental
    +
    +        Sets a global barrier and waits until all tasks in this stage hit 
this barrier.
    +        Note this method is only allowed for a BarrierTaskContext.
    +
    +        .. versionadded:: 2.4.0
    +        """
    +        if self._barrierContext is None:
    +            raise Exception("Not supported to call barrier() before 
initialize " +
    +                            "BarrierTaskContext.")
    +        else:
    +            self._barrierContext.barrier()
    +
    +    def getTaskInfos(self):
    +        """
    +        .. note:: Experimental
    +
    +        Returns the all task infos in this barrier stage, the task infos 
are ordered by
    +        partitionId.
    +        Note this method is only allowed for a BarrierTaskContext.
    +
    +        .. versionadded:: 2.4.0
    +        """
    +        if self._barrierContext is None:
    +            raise Exception("Not supported to call getTaskInfos() before 
initialize " +
    +                            "BarrierTaskContext.")
    +        else:
    +            java_list = self._barrierContext.getTaskInfos()
    +            return [BarrierTaskInfo(h) for h in java_list]
    +
    +
    +class BarrierTaskInfo(object):
    +    """
    +    .. note:: Experimental
    +
    +    Carries all task infos of a barrier task.
    +
    +    .. versionadded:: 2.4.0
    +    """
    +
    +    def __init__(self, info):
    +        self.address = info.address
    --- End diff --
    
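    * should be `info.address()`: `info` is a Py4J handle to a JVM object, so `address` is a Java method that must be called; reading it as a plain attribute gives back the method handle rather than the address string
    * better to rename `info` to `jobj` to make it clear this object comes from Java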


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to