jorisvandenbossche commented on code in PR #33948: URL: https://github.com/apache/arrow/pull/33948#discussion_r1108600699
########## python/pyarrow/tests/test_extension_type.py: ########## @@ -1079,3 +1082,282 @@ def test_array_constructor_from_pandas(): pd.Series([1, 2, 3], dtype="category"), type=IntegerType() ) assert result.equals(expected) + + +class FixedShapeTensorType(pa.ExtensionType): + """ + Canonical extension type class for fixed shape tensors. + + Parameters + ---------- + value_type : DataType or Field + The data type of an individual tensor + shape : tuple + Shape of the tensors + dim_names : tuple, default: None + Explicit names of the dimensions. + permutation : tuple, default: None + Indices of the dimensions ordering. + + Examples + -------- + >>> import pyarrow as pa + >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2)) + >>> tensor_type + FixedShapeTensorType(FixedSizeListType(fixed_size_list<item: int32>[4])) + >>> pa.register_extension_type(tensor_type) + """ + + def __init__(self, value_type, shape, dim_names=None, permutation=None): + self._value_type = value_type + self._shape = shape + size = math.prod(shape) + self._dim_names = dim_names + self._permutation = permutation + pa.ExtensionType.__init__(self, pa.list_(self._value_type, size), + 'arrow.fixed_size_tensor') + + @property + def value_type(self): + """ + Data type of an individual tensor. + """ + return self._value_type + + @property + def shape(self): + """ + Shape of the tensors. + """ + return self._shape + + @property + def dim_names(self): + """ + Explicit names of the dimensions. + """ + return self._dim_names + + @property + def permutation(self): + """ + Indices of the dimensions ordering. + """ + return self._permutation + + def __arrow_ext_serialize__(self): + metadata = {"shape": str(self._shape), + "dim_names": str(self._dim_names), + "permutation": str(self._permutation)} + return json.dumps(metadata).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + # return an instance of this subclass given the serialized + # metadata. + assert serialized.decode().startswith('{"shape":') + + metadata = json.loads(serialized.decode()) + shape = ast.literal_eval(metadata['shape']) + dim_names = ast.literal_eval(metadata['dim_names']) + permutation = ast.literal_eval(metadata['permutation']) + + return FixedShapeTensorType(storage_type.value_type, shape, + dim_names, permutation) + + def __arrow_ext_class__(self): + return FixedShapeTensorArray + + +class FixedShapeTensorArray(pa.ExtensionArray): + """ + Canonical extension array class for fixed shape tensors. + + Examples + -------- + Define and register extension type for tensor array + + >>> import pyarrow as pa + >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2)) + >>> pa.register_extension_type(tensor_type) + + Create an extension array + + >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]] + >>> storage = pa.array(arr, pa.list_(pa.int32(), 4)) + >>> pa.ExtensionArray.from_storage(tensor_type, storage) + <__main__.FixedShapeTensorArray object at ...> + [ + [ + 1, + 2, + 3, + 4 + ], + [ + 10, + 20, + 30, + 40 + ], + [ + 100, + 200, + 300, + 400 + ] + ] + """ + + def to_numpy_tensor(self): Review Comment: I would personally keep "numpy" in the name, given that it's explicitly a numpy ndarray, and not from some other package -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org