This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d6786e0  [SPARK-36711][PYTHON] Support multi-index in new syntax
d6786e0 is described below

commit d6786e036d610476a3be0fca5b16ba819dcbc013
Author: dchvn nguyen <dgd_contribu...@viettel.com.vn>
AuthorDate: Tue Oct 5 12:45:16 2021 +0900

    [SPARK-36711][PYTHON] Support multi-index in new syntax
    
    ### What changes were proposed in this pull request?
    Support multi-index in the new type-hint syntax for specifying index data types.
    
    ### Why are the changes needed?
    The new syntax previously allowed a type hint for only a single index level; this adds the same support for multi-index DataFrames.
    
    https://issues.apache.org/jira/browse/SPARK-36707
    
    ### Does this PR introduce _any_ user-facing change?
    After this PR, users can use:
    
    ``` python
    >>> ps.DataFrame[[int, int], [int, int]]
    typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]
    
    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
    >>> pdf = pd.DataFrame([[1,2,3],[2,3,4],[4,5,6]], index=idx, columns=["a", "b", "c"])
    >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]
    typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]
    
    >>> ps.DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", 
int)]]
    typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, 
pyspark.pandas.typedef.typehints.IndexNameType, 
pyspark.pandas.typedef.typehints.NameType, 
pyspark.pandas.typedef.typehints.NameType]
    
    >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]
    typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]
    
    ```
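
    For example, a function used with pandas_on_spark.apply_batch can now
    annotate every index level. A minimal sketch based on the new tests in
    this PR (the function name is illustrative):

    ``` python
    >>> import pandas as pd
    >>> import pyspark.pandas as ps
    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
    >>> psdf = ps.from_pandas(pd.DataFrame({'a': [1, 2, 3]}, index=idx))
    >>> def identity(x) -> ps.DataFrame[[("number", int), ("color", str)], [("a", int)]]:
    ...     return x
    >>> psdf.pandas_on_spark.apply_batch(identity)  # both index levels are retained
    ```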
    
    ### How was this patch tested?
    Existing tests, plus new unit tests added in python/pyspark/pandas/tests/test_dataframe.py.
    
    Closes #34176 from dchvn/SPARK-36711.
    
    Authored-by: dchvn nguyen <dgd_contribu...@viettel.com.vn>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/accessors.py            |  54 ++++--
 python/pyspark/pandas/frame.py                |  24 ++-
 python/pyspark/pandas/groupby.py              |  23 ++-
 python/pyspark/pandas/tests/test_dataframe.py |  26 +++
 python/pyspark/pandas/typedef/typehints.py    | 246 +++++++++++++++++---------
 5 files changed, 252 insertions(+), 121 deletions(-)

diff --git a/python/pyspark/pandas/accessors.py b/python/pyspark/pandas/accessors.py
index 4d40aab..afb3424 100644
--- a/python/pyspark/pandas/accessors.py
+++ b/python/pyspark/pandas/accessors.py
@@ -34,7 +34,7 @@ from pyspark.pandas.internal import (
     InternalFrame,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
 from pyspark.pandas.utils import (
@@ -384,8 +384,8 @@ class PandasOnSparkFrameMethods(object):
                     "The given function should specify a frame as its type "
                     "hints; however, the return type was %s." % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             return_schema = cast(DataFrameType, return_type).spark_type
 
             output_func = GroupBy._make_pandas_df_builder_func(
@@ -397,12 +397,19 @@ class PandasOnSparkFrameMethods(object):
 
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_fields = None
+
             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
@@ -680,17 +687,19 @@ class PandasOnSparkFrameMethods(object):
                 )
                 return first_series(DataFrame(internal))
             else:
-                index_field = cast(DataFrameType, return_type).index_field
-                index_field = (
-                    index_field.normalize_spark_type() if index_field is not None else None
+                index_fields = cast(DataFrameType, return_type).index_fields
+                index_fields = (
+                    [index_field.normalize_spark_type() for index_field in index_fields]
+                    if index_fields is not None
+                    else None
                 )
                 data_fields = [
                     field.normalize_spark_type()
                     for field in cast(DataFrameType, return_type).data_fields
                 ]
-                normalized_fields = ([index_field] if index_field is not None else []) + data_fields
+                normalized_fields = (index_fields if index_fields is not None else []) + data_fields
                 return_schema = StructType([field.struct_field for field in normalized_fields])
-                should_retain_index = index_field is not None
+                should_retain_index = index_fields is not None
 
                 self_applied = DataFrame(self._psdf._internal.resolved_copy)
 
@@ -711,12 +720,21 @@ class PandasOnSparkFrameMethods(object):
 
                 index_spark_columns = None
                 index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-                index_fields = None
+
                 if should_retain_index:
-                    index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                    index_fields = [index_field]
-                    if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                        index_names = [(index_field.struct_field.name,)]
+                    index_spark_columns = [
+                        scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                    ]
+
+                    if not any(
+                        [
+                            SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                            for index_field in index_fields
+                        ]
+                    ):
+                        index_names = [
+                            (index_field.struct_field.name,) for index_field in index_fields
+                        ]
                 internal = InternalFrame(
                     spark_frame=sdf,
                     index_names=index_names,
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 24ae6b6..9c0f857 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -114,6 +114,7 @@ from pyspark.pandas.internal import (
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_INDEX_NAME,
     SPARK_DEFAULT_SERIES_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
 from pyspark.pandas.ml import corr
@@ -2511,7 +2512,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             return_type = infer_return_type(func)
             require_index_axis = isinstance(return_type, SeriesType)
             require_column_axis = isinstance(return_type, DataFrameType)
-            index_field = None
+            index_fields = None
 
             if require_index_axis:
                 if axis != 0:
@@ -2536,8 +2537,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                         "hints when axis is 1 or 'column'; however, the return 
type "
                         "was %s" % return_sig
                     )
-                index_field = cast(DataFrameType, return_type).index_field
-                should_retain_index = index_field is not None
+                index_fields = cast(DataFrameType, return_type).index_fields
+                should_retain_index = index_fields is not None
                 data_fields = cast(DataFrameType, return_type).data_fields
                 return_schema = cast(DataFrameType, return_type).spark_type
             else:
@@ -2565,12 +2566,19 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_fields = None
+
             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index a61a024..097afb6 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -76,7 +76,7 @@ from pyspark.pandas.internal import (
     NATURAL_ORDER_COLUMN_NAME,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.groupby import (
     MissingPandasLikeDataFrameGroupBy,
@@ -1252,9 +1252,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             if isinstance(return_type, DataFrameType):
                 data_fields = cast(DataFrameType, return_type).data_fields
                 return_schema = cast(DataFrameType, return_type).spark_type
-                index_field = cast(DataFrameType, return_type).index_field
-                should_retain_index = index_field is not None
-                index_fields = [index_field]
+                index_fields = cast(DataFrameType, return_type).index_fields
+                should_retain_index = index_fields is not None
                 psdf_from_pandas = None
             else:
                 should_return_series = True
@@ -1329,10 +1328,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 )
             else:
                 index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-                index_field = index_fields[0]
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]
                 internal = InternalFrame(
                     spark_frame=sdf,
                     index_names=index_names,
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 32a427a..1ae009c 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -4678,6 +4678,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             actual.columns = ["a", "b"]
             self.assert_eq(actual, pdf)
 
+        arrays = [[1, 2, 3, 4, 5, 6, 7, 8, 9], ["a", "b", "c", "d", "e", "f", "g", "h", "i"]]
+        idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 
2, 1, 0, 0, 0]]},
+            index=idx,
+        )
+        psdf = ps.from_pandas(pdf)
+
+        def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify4)
+        actual.index.names = ["number", "color"]
+        actual.columns = ["a", "b"]
+        self.assert_eq(actual, pdf)
+
+        def identify5(
+            x,
+        ) -> ps.DataFrame[
+            [("number", int), ("color", str)], [("a", int), ("b", List[int])]  
# noqa: F405
+        ]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify5)
+        self.assert_eq(actual, pdf)
+
     def test_transform_batch(self):
         pdf = pd.DataFrame(
             {
diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index 645e5d7..9fe6e3e 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -94,12 +94,12 @@ class SeriesType(Generic[T]):
 class DataFrameType(object):
     def __init__(
         self,
-        index_field: Optional["InternalField"],
+        index_fields: Optional[List["InternalField"]],
         data_fields: List["InternalField"],
     ):
-        self.index_field = index_field
+        self.index_fields = index_fields
         self.data_fields = data_fields
-        self.fields = [index_field] + data_fields if index_field is not None else data_fields
+        self.fields = index_fields + data_fields if isinstance(index_fields, List) else data_fields
 
     @property
     def dtypes(self) -> List[Dtype]:
@@ -514,8 +514,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]
 
     >>> def func() -> ps.DataFrame[pdf.index.dtype, pdf.dtypes]:
     ...     pass
@@ -524,8 +524,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]
 
     >>> def func() -> ps.DataFrame[
     ...     ("index", CategoricalDtype(categories=[3, 4, 5], ordered=False)),
@@ -536,8 +536,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [CategoricalDtype(categories=[3, 4, 5], ordered=False), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<index:bigint,id:bigint,A:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=category,struct_field=StructField(index,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=category,struct_field=StructField(index,LongType,true))]
 
     >>> def func() -> ps.DataFrame[
     ...         (pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]:
@@ -547,13 +547,13 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,a:bigint,b:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]
     """
     # We should re-import to make sure the class 'SeriesType' is not treated as a class
     # within this module locally. See Series.__class_getitem__ which imports this class
     # canonically.
-    from pyspark.pandas.internal import InternalField, SPARK_DEFAULT_INDEX_NAME
+    from pyspark.pandas.internal import InternalField, SPARK_INDEX_NAME_FORMAT
     from pyspark.pandas.typedef import SeriesType, NameTypeHolder, IndexNameTypeHolder
     from pyspark.pandas.utils import name_like_string
 
@@ -595,20 +595,26 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
         data_parameters = [p for p in parameters if p not in index_parameters]
         assert len(data_parameters) > 0, "Type hints for data must not be empty."
 
-        if len(index_parameters) == 1:
-            index_name = index_parameters[0].name
-            index_dtype, index_spark_type = pandas_on_spark_type(index_parameters[0].tpe)
-            index_field = InternalField(
-                dtype=index_dtype,
-                struct_field=types.StructField(
-                    name=index_name if index_name is not None else SPARK_DEFAULT_INDEX_NAME,
-                    dataType=index_spark_type,
-                ),
-            )
+        index_fields = []
+        if len(index_parameters) >= 1:
+            for level, index_parameter in enumerate(index_parameters):
+                index_name = index_parameter.name
+                index_dtype, index_spark_type = pandas_on_spark_type(index_parameter.tpe)
+                index_fields.append(
+                    InternalField(
+                        dtype=index_dtype,
+                        struct_field=types.StructField(
+                            name=index_name
+                            if index_name is not None
+                            else SPARK_INDEX_NAME_FORMAT(level),
+                            dataType=index_spark_type,
+                        ),
+                    )
+                )
         else:
             assert len(index_parameters) == 0
             # No type hint for index.
-            index_field = None
+            index_fields = None
 
         data_dtypes, data_spark_types = zip(
             *(
@@ -636,7 +642,7 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
                 )
             )
 
-        return DataFrameType(index_field=index_field, data_fields=data_fields)
+        return DataFrameType(index_fields=index_fields, data_fields=data_fields)
 
     tpes = pandas_on_spark_type(tpe)
     if tpes is None:
@@ -684,10 +690,10 @@ def create_tuple_for_frame_type(params: Any) -> object:
 
     Typing data columns only:
 
-        >>> ps.DataFrame[float, float]
-        typing.Tuple[float, float]
-        >>> ps.DataFrame[pdf.dtypes]
-        typing.Tuple[numpy.int64]
+        >>> ps.DataFrame[float, float]  # doctest: +ELLIPSIS
+        typing.Tuple[...NameType, ...NameType]
+        >>> ps.DataFrame[pdf.dtypes]  # doctest: +ELLIPSIS
+        typing.Tuple[...NameType]
         >>> ps.DataFrame["id": int, "A": int]  # doctest: +ELLIPSIS
         typing.Tuple[...NameType, ...NameType]
         >>> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]  # doctest: +ELLIPSIS
@@ -696,48 +702,42 @@ def create_tuple_for_frame_type(params: Any) -> object:
     Typing data columns with an index:
 
         >>> ps.DataFrame[int, [int, int]]  # doctest: +ELLIPSIS
-        typing.Tuple[...IndexNameType, int, int]
+        typing.Tuple[...IndexNameType, ...NameType, ...NameType]
         >>> ps.DataFrame[pdf.index.dtype, pdf.dtypes]  # doctest: +ELLIPSIS
-        typing.Tuple[...IndexNameType, numpy.int64]
+        typing.Tuple[...IndexNameType, ...NameType]
         >>> ps.DataFrame[("index", int), [("id", int), ("A", int)]]  # 
doctest: +ELLIPSIS
         typing.Tuple[...IndexNameType, ...NameType, ...NameType]
         >>> ps.DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, 
pdf.dtypes)]
         ... # doctest: +ELLIPSIS
         typing.Tuple[...IndexNameType, ...NameType]
+
+    Typing data columns with a multi-index:
+        >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+        >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+        >>> pdf = pd.DataFrame({'a': range(3)}, index=idx)
+        >>> ps.DataFrame[[int, int], [int, int]]  # doctest: +ELLIPSIS
+        typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType]
+        >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]  # doctest: +ELLIPSIS
+        typing.Tuple[...IndexNameType, ...NameType]
+        >>> ps.DataFrame[[("index-1", int), ("index-2", int)], [("id", int), 
("A", int)]]
+        ... # doctest: +ELLIPSIS
+        typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, 
...NameType]
+        >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), 
zip(pdf.columns, pdf.dtypes)]
+        ... # doctest: +ELLIPSIS
+        typing.Tuple[...IndexNameType, ...NameType]
     """
-    return Tuple[extract_types(params)]
+    return Tuple[_extract_types(params)]
 
 
-# TODO(SPARK-36708): numpy.typing (numpy 1.21+) support for nested types.
-def extract_types(params: Any) -> Tuple:
+def _extract_types(params: Any) -> Tuple:
     origin = params
-    if isinstance(params, zip):
-        # Example:
-        #   DataFrame[zip(pdf.columns, pdf.dtypes)]
-        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
 
-    if isinstance(params, Iterable):
-        params = tuple(params)
-    else:
-        params = (params,)
+    params = _to_tuple_of_params(params)
 
-    if all(
-        isinstance(param, slice)
-        and param.start is not None
-        and param.step is None
-        and param.stop is not None
-        for param in params
-    ):
+    if _is_named_params(params):
         # Example:
         #   DataFrame["id": int, "A": int]
-        new_params = []
-        for param in params:
-            new_param = type("NameType", (NameTypeHolder,), {})  # type: 
Type[NameTypeHolder]
-            new_param.name = param.start
-            # When the given argument is a numpy's dtype instance.
-            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
-            new_params.append(new_param)
-
+        new_params = _address_named_type_hoders(params, is_index=False)
         return tuple(new_params)
     elif len(params) == 2 and isinstance(params[1], (zip, list, pd.Series)):
         # Example:
@@ -745,49 +745,117 @@ def extract_types(params: Any) -> Tuple:
         #   DataFrame[pdf.index.dtype, pdf.dtypes]
         #   DataFrame[("index", int), [("id", int), ("A", int)]]
         #   DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]
+        #
+        #   DataFrame[[int, int], [int, int]]
+        #   DataFrame[pdf.index.dtypes, pdf.dtypes]
+        #   DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", 
int)]]
+        #   DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, 
pdf.dtypes)]
 
-        index_param = params[0]
-        index_type = type(
-            "IndexNameType", (IndexNameTypeHolder,), {}
-        )  # type: Type[IndexNameTypeHolder]
-        if isinstance(index_param, tuple):
-            if len(index_param) != 2:
-                raise TypeError(
-                    "Type hints for index should be specified as "
-                    "DataFrame[('name', type), ...]; however, got %s" % 
index_param
-                )
-            name, tpe = index_param
-        else:
-            name, tpe = None, index_param
+        index_params = params[0]
+
+        if isinstance(index_params, tuple) and len(index_params) == 2:
+            index_params = tuple([slice(*index_params)])
+
+        index_params = _convert_tuples_to_zip(index_params)
+        index_params = _to_tuple_of_params(index_params)
 
-        index_type.name = name
-        if isinstance(tpe, ExtensionDtype):
-            index_type.tpe = tpe
+        if _is_named_params(index_params):
+            # Example:
+            #   DataFrame[[("id", int), ("A", int)], [int, int]]
+            new_index_params = _address_named_type_hoders(index_params, is_index=True)
+            index_types = tuple(new_index_params)
         else:
-            index_type.tpe = tpe.type if isinstance(tpe, np.dtype) else tpe
+            # Examples:
+            #   DataFrame[[float, float], [int, int]]
+            #   DataFrame[pdf.dtypes, [int, int]]
+            index_types = _address_unnamed_type_holders(index_params, origin, is_index=True)
 
         data_types = params[1]
-        if (
-            isinstance(data_types, list)
-            and len(data_types) >= 1
-            and isinstance(data_types[0], tuple)
-        ):
-            # Example:
-            #   DataFrame[("index", int), [("id", int), ("A", int)]]
-            data_types = zip((name for name, _ in data_types), (tpe for _, tpe in data_types))
-        return (index_type,) + extract_types(data_types)
-    elif all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
+        data_types = _convert_tuples_to_zip(data_types)
+
+        return index_types + _extract_types(data_types)
+
+    else:
         # Exaxmples:
         #   DataFrame[float, float]
         #   DataFrame[pdf.dtypes]
+        return _address_unnamed_type_holders(params, origin, is_index=False)
+
+
+def _is_named_params(params: Any) -> Any:
+    return all(
+        isinstance(param, slice) and param.step is None and param.stop is not None
+        for param in params
+    )
+
+
+def _address_named_type_hoders(params: Any, is_index: bool) -> Any:
+    # Example:
+    #   params = (slice("id", int, None), slice("A", int, None))
+    new_params = []
+    for param in params:
+        new_param = (
+            type("IndexNameType", (IndexNameTypeHolder,), {})
+            if is_index
+            else type("NameType", (NameTypeHolder,), {})
+        )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
+        new_param.name = param.start
+        if isinstance(param.stop, ExtensionDtype):
+            new_param.tpe = param.stop
+        else:
+            # When the given argument is a numpy's dtype instance.
+            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
+        new_params.append(new_param)
+    return new_params
+
+
+def _to_tuple_of_params(params: Any) -> Any:
+    """
+    >>> _to_tuple_of_params(int)
+    (<class 'int'>,)
+
+    >>> _to_tuple_of_params([int, int, int])
+    (<class 'int'>, <class 'int'>, <class 'int'>)
+
+    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+    >>> pdf = pd.DataFrame([[1, 2], [2, 3], [4, 5]], index=idx, columns=["a", 
"b"])
+
+    >>> _to_tuple_of_params(zip(pdf.columns, pdf.dtypes))
+    (slice('a', dtype('int64'), None), slice('b', dtype('int64'), None))
+    >>> _to_tuple_of_params(zip(pdf.index.names, pdf.index.dtypes))
+    (slice('number', dtype('int64'), None), slice('color', dtype('O'), None))
+    """
+    if isinstance(params, zip):
+        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
+
+    if isinstance(params, Iterable):
+        params = tuple(params)
+    else:
+        params = (params,)
+    return params
+
+
+def _convert_tuples_to_zip(params: Any) -> Any:
+    if isinstance(params, list) and len(params) >= 1 and isinstance(params[0], tuple):
+        return zip((name for name, _ in params), (tpe for _, tpe in params))
+    return params
+
+
+def _address_unnamed_type_holders(params: Any, origin: Any, is_index: bool) -> Any:
+    if all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
         new_types = []
         for param in params:
+            new_type = (
+                type("IndexNameType", (IndexNameTypeHolder,), {})
+                if is_index
+                else type("NameType", (NameTypeHolder,), {})
+            )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
             if isinstance(param, ExtensionDtype):
-                new_type = type("NameType", (NameTypeHolder,), {})  # type: 
Type[NameTypeHolder]
                 new_type.tpe = param
-                new_types.append(new_type)
             else:
-                new_types.append(param.type if isinstance(param, np.dtype) else param)
+                new_type.tpe = param.type if isinstance(param, np.dtype) else param
+            new_types.append(new_type)
         return tuple(new_types)
     else:
         raise TypeError(
@@ -799,7 +867,11 @@ def extract_types(params: Any) -> Tuple:
   - DataFrame[index_type, [type, ...]]
   - DataFrame[(index_name, index_type), [(name, type), ...]]
   - DataFrame[dtype instance, dtypes instance]
-  - DataFrame[(index_name, index_type), zip(names, types)]\n"""
+  - DataFrame[(index_name, index_type), zip(names, types)]
+  - DataFrame[[index_type, ...], [type, ...]]
+  - DataFrame[[(index_name, index_type), ...], [(name, type), ...]]
+  - DataFrame[dtypes instance, dtypes instance]
+  - DataFrame[zip(index_names, index_types), zip(names, types)]\n"""
             + "However, got %s." % str(origin)
         )
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
