This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d6786e0 [SPARK-36711][PYTHON] Support multi-index in new syntax d6786e0 is described below commit d6786e036d610476a3be0fca5b16ba819dcbc013 Author: dchvn nguyen <dgd_contribu...@viettel.com.vn> AuthorDate: Tue Oct 5 12:45:16 2021 +0900 [SPARK-36711][PYTHON] Support multi-index in new syntax ### What changes were proposed in this pull request? Support multi-index in new syntax to specify index data type ### Why are the changes needed? Support multi-index in new syntax to specify index data type https://issues.apache.org/jira/browse/SPARK-36707 ### Does this PR introduce _any_ user-facing change? After this PR user can use ``` python >>> ps.DataFrame[[int, int],[int, int]] typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType] >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']] >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color')) >>> pdf = pd.DataFrame([[1,2,3],[2,3,4],[4,5,6]], index=idx, columns=["a", "b", "c"]) >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes] typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType] >>> ps.DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]] typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType] >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)] typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, 
pyspark.pandas.typedef.typehints.NameType] ``` ### How was this patch tested? existing tests Closes #34176 from dchvn/SPARK-36711. Authored-by: dchvn nguyen <dgd_contribu...@viettel.com.vn> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/pandas/accessors.py | 54 ++++-- python/pyspark/pandas/frame.py | 24 ++- python/pyspark/pandas/groupby.py | 23 ++- python/pyspark/pandas/tests/test_dataframe.py | 26 +++ python/pyspark/pandas/typedef/typehints.py | 246 +++++++++++++++++--------- 5 files changed, 252 insertions(+), 121 deletions(-) diff --git a/python/pyspark/pandas/accessors.py b/python/pyspark/pandas/accessors.py index 4d40aab..afb3424 100644 --- a/python/pyspark/pandas/accessors.py +++ b/python/pyspark/pandas/accessors.py @@ -34,7 +34,7 @@ from pyspark.pandas.internal import ( InternalFrame, SPARK_INDEX_NAME_FORMAT, SPARK_DEFAULT_SERIES_NAME, - SPARK_DEFAULT_INDEX_NAME, + SPARK_INDEX_NAME_PATTERN, ) from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType from pyspark.pandas.utils import ( @@ -384,8 +384,8 @@ class PandasOnSparkFrameMethods(object): "The given function should specify a frame as its type " "hints; however, the return type was %s."
% return_sig ) - index_field = cast(DataFrameType, return_type).index_field - should_retain_index = index_field is not None + index_fields = cast(DataFrameType, return_type).index_fields + should_retain_index = index_fields is not None return_schema = cast(DataFrameType, return_type).spark_type output_func = GroupBy._make_pandas_df_builder_func( @@ -397,12 +397,19 @@ class PandasOnSparkFrameMethods(object): index_spark_columns = None index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None - index_fields = None + if should_retain_index: - index_spark_columns = [scol_for(sdf, index_field.struct_field.name)] - index_fields = [index_field] - if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME: - index_names = [(index_field.struct_field.name,)] + index_spark_columns = [ + scol_for(sdf, index_field.struct_field.name) for index_field in index_fields + ] + + if not any( + [ + SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name) + for index_field in index_fields + ] + ): + index_names = [(index_field.struct_field.name,) for index_field in index_fields] internal = InternalFrame( spark_frame=sdf, index_names=index_names, @@ -680,17 +687,19 @@ class PandasOnSparkFrameMethods(object): ) return first_series(DataFrame(internal)) else: - index_field = cast(DataFrameType, return_type).index_field - index_field = ( - index_field.normalize_spark_type() if index_field is not None else None + index_fields = cast(DataFrameType, return_type).index_fields + index_fields = ( + [index_field.normalize_spark_type() for index_field in index_fields] + if index_fields is not None + else None ) data_fields = [ field.normalize_spark_type() for field in cast(DataFrameType, return_type).data_fields ] - normalized_fields = ([index_field] if index_field is not None else []) + data_fields + normalized_fields = (index_fields if index_fields is not None else []) + data_fields return_schema = StructType([field.struct_field for field in normalized_fields]) - should_retain_index = 
index_field is not None + should_retain_index = index_fields is not None self_applied = DataFrame(self._psdf._internal.resolved_copy) @@ -711,12 +720,21 @@ class PandasOnSparkFrameMethods(object): index_spark_columns = None index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None - index_fields = None + if should_retain_index: - index_spark_columns = [scol_for(sdf, index_field.struct_field.name)] - index_fields = [index_field] - if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME: - index_names = [(index_field.struct_field.name,)] + index_spark_columns = [ + scol_for(sdf, index_field.struct_field.name) for index_field in index_fields + ] + + if not any( + [ + SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name) + for index_field in index_fields + ] + ): + index_names = [ + (index_field.struct_field.name,) for index_field in index_fields + ] internal = InternalFrame( spark_frame=sdf, index_names=index_names, diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 24ae6b6..9c0f857 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -114,6 +114,7 @@ from pyspark.pandas.internal import ( SPARK_INDEX_NAME_FORMAT, SPARK_DEFAULT_INDEX_NAME, SPARK_DEFAULT_SERIES_NAME, + SPARK_INDEX_NAME_PATTERN, ) from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame from pyspark.pandas.ml import corr @@ -2511,7 +2512,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})] return_type = infer_return_type(func) require_index_axis = isinstance(return_type, SeriesType) require_column_axis = isinstance(return_type, DataFrameType) - index_field = None + index_fields = None if require_index_axis: if axis != 0: @@ -2536,8 +2537,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})] "hints when axis is 1 or 'column'; however, the return type " "was %s" % return_sig ) - index_field = cast(DataFrameType, return_type).index_field - should_retain_index = index_field is not None + index_fields = 
cast(DataFrameType, return_type).index_fields + should_retain_index = index_fields is not None data_fields = cast(DataFrameType, return_type).data_fields return_schema = cast(DataFrameType, return_type).spark_type else: @@ -2565,12 +2566,19 @@ defaultdict(<class 'list'>, {'col..., 'col...})] index_spark_columns = None index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None - index_fields = None + if should_retain_index: - index_spark_columns = [scol_for(sdf, index_field.struct_field.name)] - index_fields = [index_field] - if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME: - index_names = [(index_field.struct_field.name,)] + index_spark_columns = [ + scol_for(sdf, index_field.struct_field.name) for index_field in index_fields + ] + + if not any( + [ + SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name) + for index_field in index_fields + ] + ): + index_names = [(index_field.struct_field.name,) for index_field in index_fields] internal = InternalFrame( spark_frame=sdf, index_names=index_names, diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index a61a024..097afb6 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -76,7 +76,7 @@ from pyspark.pandas.internal import ( NATURAL_ORDER_COLUMN_NAME, SPARK_INDEX_NAME_FORMAT, SPARK_DEFAULT_SERIES_NAME, - SPARK_DEFAULT_INDEX_NAME, + SPARK_INDEX_NAME_PATTERN, ) from pyspark.pandas.missing.groupby import ( MissingPandasLikeDataFrameGroupBy, @@ -1252,9 +1252,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): if isinstance(return_type, DataFrameType): data_fields = cast(DataFrameType, return_type).data_fields return_schema = cast(DataFrameType, return_type).spark_type - index_field = cast(DataFrameType, return_type).index_field - should_retain_index = index_field is not None - index_fields = [index_field] + index_fields = cast(DataFrameType, return_type).index_fields + should_retain_index = index_fields is not None 
psdf_from_pandas = None else: should_return_series = True @@ -1329,10 +1328,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): ) else: index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None - index_field = index_fields[0] - index_spark_columns = [scol_for(sdf, index_field.struct_field.name)] - if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME: - index_names = [(index_field.struct_field.name,)] + + index_spark_columns = [ + scol_for(sdf, index_field.struct_field.name) for index_field in index_fields + ] + + if not any( + [ + SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name) + for index_field in index_fields + ] + ): + index_names = [(index_field.struct_field.name,) for index_field in index_fields] internal = InternalFrame( spark_frame=sdf, index_names=index_names, diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 32a427a..1ae009c 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -4678,6 +4678,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): actual.columns = ["a", "b"] self.assert_eq(actual, pdf) + arrays = [[1, 2, 3, 4, 5, 6, 7, 8, 9], ["a", "b", "c", "d", "e", "f", "g", "h", "i"]] + idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color")) + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]}, + index=idx, + ) + psdf = ps.from_pandas(pdf) + + def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]: + return x + + actual = psdf.pandas_on_spark.apply_batch(identify4) + actual.index.names = ["number", "color"] + actual.columns = ["a", "b"] + self.assert_eq(actual, pdf) + + def identify5( + x, + ) -> ps.DataFrame[ + [("number", int), ("color", str)], [("a", int), ("b", List[int])] # noqa: F405 + ]: + return x + + actual = psdf.pandas_on_spark.apply_batch(identify5) + self.assert_eq(actual, pdf) + def 
test_transform_batch(self): pdf = pd.DataFrame( { diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index 645e5d7..9fe6e3e 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -94,12 +94,12 @@ class SeriesType(Generic[T]): class DataFrameType(object): def __init__( self, - index_field: Optional["InternalField"], + index_fields: Optional[List["InternalField"]], data_fields: List["InternalField"], ): - self.index_field = index_field + self.index_fields = index_fields self.data_fields = data_fields - self.fields = [index_field] + data_fields if index_field is not None else data_fields + self.fields = index_fields + data_fields if isinstance(index_fields, List) else data_fields @property def dtypes(self) -> List[Dtype]: @@ -514,8 +514,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp [dtype('int64'), dtype('int64'), dtype('int64')] >>> inferred.spark_type.simpleString() 'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>' - >>> inferred.index_field - InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true)) + >>> inferred.index_fields + [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))] >>> def func() -> ps.DataFrame[pdf.index.dtype, pdf.dtypes]: ... pass @@ -524,8 +524,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)] >>> inferred.spark_type.simpleString() 'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>' - >>> inferred.index_field - InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true)) + >>> inferred.index_fields + [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))] >>> def func() -> ps.DataFrame[ ... 
("index", CategoricalDtype(categories=[3, 4, 5], ordered=False)), @@ -536,8 +536,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp [CategoricalDtype(categories=[3, 4, 5], ordered=False), dtype('int64'), dtype('int64')] >>> inferred.spark_type.simpleString() 'struct<index:bigint,id:bigint,A:bigint>' - >>> inferred.index_field - InternalField(dtype=category,struct_field=StructField(index,LongType,true)) + >>> inferred.index_fields + [InternalField(dtype=category,struct_field=StructField(index,LongType,true))] >>> def func() -> ps.DataFrame[ ... (pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]: @@ -547,13 +547,13 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)] >>> inferred.spark_type.simpleString() 'struct<__index_level_0__:bigint,a:bigint,b:bigint>' - >>> inferred.index_field - InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true)) + >>> inferred.index_fields + [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))] """ # We should re-import to make sure the class 'SeriesType' is not treated as a class # within this module locally. See Series.__class_getitem__ which imports this class # canonically. - from pyspark.pandas.internal import InternalField, SPARK_DEFAULT_INDEX_NAME + from pyspark.pandas.internal import InternalField, SPARK_INDEX_NAME_FORMAT from pyspark.pandas.typedef import SeriesType, NameTypeHolder, IndexNameTypeHolder from pyspark.pandas.utils import name_like_string @@ -595,20 +595,26 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp data_parameters = [p for p in parameters if p not in index_parameters] assert len(data_parameters) > 0, "Type hints for data must not be empty." 
- if len(index_parameters) == 1: - index_name = index_parameters[0].name - index_dtype, index_spark_type = pandas_on_spark_type(index_parameters[0].tpe) - index_field = InternalField( - dtype=index_dtype, - struct_field=types.StructField( - name=index_name if index_name is not None else SPARK_DEFAULT_INDEX_NAME, - dataType=index_spark_type, - ), - ) + index_fields = [] + if len(index_parameters) >= 1: + for level, index_parameter in enumerate(index_parameters): + index_name = index_parameter.name + index_dtype, index_spark_type = pandas_on_spark_type(index_parameter.tpe) + index_fields.append( + InternalField( + dtype=index_dtype, + struct_field=types.StructField( + name=index_name + if index_name is not None + else SPARK_INDEX_NAME_FORMAT(level), + dataType=index_spark_type, + ), + ) + ) else: assert len(index_parameters) == 0 # No type hint for index. - index_field = None + index_fields = None data_dtypes, data_spark_types = zip( *( @@ -636,7 +642,7 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp ) ) - return DataFrameType(index_field=index_field, data_fields=data_fields) + return DataFrameType(index_fields=index_fields, data_fields=data_fields) tpes = pandas_on_spark_type(tpe) if tpes is None: @@ -684,10 +690,10 @@ def create_tuple_for_frame_type(params: Any) -> object: Typing data columns only: - >>> ps.DataFrame[float, float] - typing.Tuple[float, float] - >>> ps.DataFrame[pdf.dtypes] - typing.Tuple[numpy.int64] + >>> ps.DataFrame[float, float] # doctest: +ELLIPSIS + typing.Tuple[...NameType, ...NameType] + >>> ps.DataFrame[pdf.dtypes] # doctest: +ELLIPSIS + typing.Tuple[...NameType] >>> ps.DataFrame["id": int, "A": int] # doctest: +ELLIPSIS typing.Tuple[...NameType, ...NameType] >>> ps.DataFrame[zip(pdf.columns, pdf.dtypes)] # doctest: +ELLIPSIS @@ -696,48 +702,42 @@ def create_tuple_for_frame_type(params: Any) -> object: Typing data columns with an index: >>> ps.DataFrame[int, [int, int]] # doctest: +ELLIPSIS - 
typing.Tuple[...IndexNameType, int, int] + typing.Tuple[...IndexNameType, ...NameType, ...NameType] >>> ps.DataFrame[pdf.index.dtype, pdf.dtypes] # doctest: +ELLIPSIS - typing.Tuple[...IndexNameType, numpy.int64] + typing.Tuple[...IndexNameType, ...NameType] >>> ps.DataFrame[("index", int), [("id", int), ("A", int)]] # doctest: +ELLIPSIS typing.Tuple[...IndexNameType, ...NameType, ...NameType] >>> ps.DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)] ... # doctest: +ELLIPSIS typing.Tuple[...IndexNameType, ...NameType] + + Typing data columns with an Multi-index: + >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']] + >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color')) + >>> pdf = pd.DataFrame({'a': range(3)}, index=idx) + >>> ps.DataFrame[[int, int], [int, int]] # doctest: +ELLIPSIS + typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType] + >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes] # doctest: +ELLIPSIS + typing.Tuple[...IndexNameType, ...NameType] + >>> ps.DataFrame[[("index-1", int), ("index-2", int)], [("id", int), ("A", int)]] + ... # doctest: +ELLIPSIS + typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType] + >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)] + ... # doctest: +ELLIPSIS + typing.Tuple[...IndexNameType, ...NameType] """ - return Tuple[extract_types(params)] + return Tuple[_extract_types(params)] -# TODO(SPARK-36708): numpy.typing (numpy 1.21+) support for nested types. 
-def extract_types(params: Any) -> Tuple: +def _extract_types(params: Any) -> Tuple: origin = params - if isinstance(params, zip): - # Example: - # DataFrame[zip(pdf.columns, pdf.dtypes)] - params = tuple(slice(name, tpe) for name, tpe in params) # type: ignore[misc, has-type] - if isinstance(params, Iterable): - params = tuple(params) - else: - params = (params,) + params = _to_tuple_of_params(params) - if all( - isinstance(param, slice) - and param.start is not None - and param.step is None - and param.stop is not None - for param in params - ): + if _is_named_params(params): # Example: # DataFrame["id": int, "A": int] - new_params = [] - for param in params: - new_param = type("NameType", (NameTypeHolder,), {}) # type: Type[NameTypeHolder] - new_param.name = param.start - # When the given argument is a numpy's dtype instance. - new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop - new_params.append(new_param) - + new_params = _address_named_type_hoders(params, is_index=False) return tuple(new_params) elif len(params) == 2 and isinstance(params[1], (zip, list, pd.Series)): # Example: @@ -745,49 +745,117 @@ def extract_types(params: Any) -> Tuple: # DataFrame[pdf.index.dtype, pdf.dtypes] # DataFrame[("index", int), [("id", int), ("A", int)]] # DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)] + # + # DataFrame[[int, int], [int, int]] + # DataFrame[pdf.index.dtypes, pdf.dtypes] + # DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]] + # DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)] - index_param = params[0] - index_type = type( - "IndexNameType", (IndexNameTypeHolder,), {} - ) # type: Type[IndexNameTypeHolder] - if isinstance(index_param, tuple): - if len(index_param) != 2: - raise TypeError( - "Type hints for index should be specified as " - "DataFrame[('name', type), ...]; however, got %s" % index_param - ) - name, tpe = index_param - else: - name, tpe 
= None, index_param + index_params = params[0] + + if isinstance(index_params, tuple) and len(index_params) == 2: + index_params = tuple([slice(*index_params)]) + + index_params = _convert_tuples_to_zip(index_params) + index_params = _to_tuple_of_params(index_params) - index_type.name = name - if isinstance(tpe, ExtensionDtype): - index_type.tpe = tpe + if _is_named_params(index_params): + # Example: + # DataFrame[[("id", int), ("A", int)], [int, int]] + new_index_params = _address_named_type_hoders(index_params, is_index=True) + index_types = tuple(new_index_params) else: - index_type.tpe = tpe.type if isinstance(tpe, np.dtype) else tpe + # Exaxmples: + # DataFrame[[float, float], [int, int]] + # DataFrame[pdf.dtypes, [int, int]] + index_types = _address_unnamed_type_holders(index_params, origin, is_index=True) data_types = params[1] - if ( - isinstance(data_types, list) - and len(data_types) >= 1 - and isinstance(data_types[0], tuple) - ): - # Example: - # DataFrame[("index", int), [("id", int), ("A", int)]] - data_types = zip((name for name, _ in data_types), (tpe for _, tpe in data_types)) - return (index_type,) + extract_types(data_types) - elif all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params): + data_types = _convert_tuples_to_zip(data_types) + + return index_types + _extract_types(data_types) + + else: # Exaxmples: # DataFrame[float, float] # DataFrame[pdf.dtypes] + return _address_unnamed_type_holders(params, origin, is_index=False) + + +def _is_named_params(params: Any) -> Any: + return all( + isinstance(param, slice) and param.step is None and param.stop is not None + for param in params + ) + + +def _address_named_type_hoders(params: Any, is_index: bool) -> Any: + # Example: + # params = (slice("id", int, None), slice("A", int, None)) + new_params = [] + for param in params: + new_param = ( + type("IndexNameType", (IndexNameTypeHolder,), {}) + if is_index + else type("NameType", (NameTypeHolder,), {}) + ) # type: 
Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]] + new_param.name = param.start + if isinstance(param.stop, ExtensionDtype): + new_param.tpe = param.stop + else: + # When the given argument is a numpy's dtype instance. + new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop + new_params.append(new_param) + return new_params + + +def _to_tuple_of_params(params: Any) -> Any: + """ + >>> _to_tuple_of_params(int) + (<class 'int'>,) + + >>> _to_tuple_of_params([int, int, int]) + (<class 'int'>, <class 'int'>, <class 'int'>) + + >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']] + >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color')) + >>> pdf = pd.DataFrame([[1, 2], [2, 3], [4, 5]], index=idx, columns=["a", "b"]) + + >>> _to_tuple_of_params(zip(pdf.columns, pdf.dtypes)) + (slice('a', dtype('int64'), None), slice('b', dtype('int64'), None)) + >>> _to_tuple_of_params(zip(pdf.index.names, pdf.index.dtypes)) + (slice('number', dtype('int64'), None), slice('color', dtype('O'), None)) + """ + if isinstance(params, zip): + params = tuple(slice(name, tpe) for name, tpe in params) # type: ignore[misc, has-type] + + if isinstance(params, Iterable): + params = tuple(params) + else: + params = (params,) + return params + + +def _convert_tuples_to_zip(params: Any) -> Any: + if isinstance(params, list) and len(params) >= 1 and isinstance(params[0], tuple): + return zip((name for name, _ in params), (tpe for _, tpe in params)) + return params + + +def _address_unnamed_type_holders(params: Any, origin: Any, is_index: bool) -> Any: + if all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params): new_types = [] for param in params: + new_type = ( + type("IndexNameType", (IndexNameTypeHolder,), {}) + if is_index + else type("NameType", (NameTypeHolder,), {}) + ) # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]] if isinstance(param, ExtensionDtype): - new_type = type("NameType", 
(NameTypeHolder,), {}) # type: Type[NameTypeHolder] new_type.tpe = param - new_types.append(new_type) else: - new_types.append(param.type if isinstance(param, np.dtype) else param) + new_type.tpe = param.type if isinstance(param, np.dtype) else param + new_types.append(new_type) return tuple(new_types) else: raise TypeError( @@ -799,7 +867,11 @@ def extract_types(params: Any) -> Tuple: - DataFrame[index_type, [type, ...]] - DataFrame[(index_name, index_type), [(name, type), ...]] - DataFrame[dtype instance, dtypes instance] - - DataFrame[(index_name, index_type), zip(names, types)]\n""" + - DataFrame[(index_name, index_type), zip(names, types)] + - DataFrame[[index_type, ...], [type, ...]] + - DataFrame[[(index_name, index_type), ...], [(name, type), ...]] + - DataFrame[dtypes instance, dtypes instance] + - DataFrame[zip(index_names, index_types), zip(names, types)]\n""" + "However, got %s." % str(origin) ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org