Fokko commented on code in PR #6997:
URL: https://github.com/apache/iceberg/pull/6997#discussion_r1124333058


##########
python/pyiceberg/io/pyarrow.py:
##########
@@ -476,6 +483,217 @@ def expression_to_pyarrow(expr: BooleanExpression) -> 
pc.Expression:
     return boolean_expression_visit(expr, _ConvertToArrowExpression())
 
 
+def pyarrow_to_schema(schema: pa.Schema) -> Schema:
+    return visit_arrow_schema(schema, _ConvertToIceberg())
+
+
+def visit_arrow_schema(obj: pa.Schema, visitor: ArrowSchemaVisitor[T]) -> 
Schema:
+    struct_results = []
+    for i in range(len(obj.names)):
+        field = obj.field(i)
+        visitor.before_field(field)
+        struct_result = visit_arrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.schema(obj, struct_results)
+
+
+def visit_arrow(obj: pa.DataType, visitor: ArrowSchemaVisitor[T]) -> T:
+    if pa.types.is_struct(obj):
+        return visit_arrow_struct(obj, visitor)
+    elif pa.types.is_list(obj):
+        return visit_arrow_list(obj, visitor)
+    elif pa.types.is_map(obj):
+        return visit_arrow_map(obj, visitor)
+    else:
+        return visit_arrow_primitive(obj, visitor)
+
+
+def visit_arrow_struct(obj: pa.DataType, visitor: ArrowSchemaVisitor[T]) -> T:
+    if not pa.types.is_struct(obj):
+        raise TypeError(f"Expected struct type, got {type(obj)}")
+    obj = cast(pa.StructType, obj)
+    struct_results = []
+    for field in obj:
+        visitor.before_field(field)
+        struct_result = visit_arrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.struct(obj, struct_results)
+
+
+def visit_arrow_list(obj: pa.DataType, visitor: ArrowSchemaVisitor[T]) -> T:
+    if not pa.types.is_list(obj):
+        raise TypeError(f"Expected list type, got {type(obj)}")
+    obj = cast(pa.ListType, obj)
+    visitor.before_list_element(obj.value_field)
+    list_result = visit_arrow(obj.value_field.type, visitor)
+    visitor.after_list_element(obj.value_field)
+    return visitor.list(obj, list_result)
+
+
+def visit_arrow_map(obj: pa.DataType, visitor: ArrowSchemaVisitor[T]) -> T:
+    if not pa.types.is_map(obj):
+        raise TypeError(f"Expected map type, got {type(obj)}")
+    obj = cast(pa.MapType, obj)
+    visitor.before_map_key(obj.key_field)
+    key_result = visit_arrow(obj.key_field.type, visitor)
+    visitor.after_map_key(obj.key_field)
+    visitor.before_map_value(obj.item_field)
+    value_result = visit_arrow(obj.item_field.type, visitor)
+    visitor.after_map_value(obj.item_field)
+    return visitor.map(obj, key_result, value_result)
+
+
+def visit_arrow_primitive(obj: pa.DataType, visitor: ArrowSchemaVisitor[T]) -> 
T:
+    if pa.types.is_nested(obj):
+        raise TypeError(f"Expected primitive type, got {type(obj)}")
+    return visitor.primitive(obj)
+
+
+class ArrowSchemaVisitor(Generic[T], ABC):
+    def before_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately before 
visiting a field."""
+
+    def after_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately after 
visiting a field."""
+
+    def before_list_element(self, element: pa.Field) -> None:
+        """Override this method to perform an action immediately before 
visiting a list element."""
+
+    def after_list_element(self, element: pa.Field) -> None:
+        """Override this method to perform an action immediately after 
visiting a list element."""
+
+    def before_map_key(self, key: pa.Field) -> None:
+        """Override this method to perform an action immediately before 
visiting a map key."""
+
+    def after_map_key(self, key: pa.Field) -> None:
+        """Override this method to perform an action immediately after 
visiting a map key."""
+
+    def before_map_value(self, value: pa.Field) -> None:
+        """Override this method to perform an action immediately before 
visiting a map value."""
+
+    def after_map_value(self, value: pa.Field) -> None:
+        """Override this method to perform an action immediately after 
visiting a map value."""
+
+    @abstractmethod
+    def schema(self, schema: pa.Schema, field_results: List[T]) -> Schema:
+        """visit a schema"""
+
+    @abstractmethod
+    def struct(self, struct: pa.StructType, field_results: List[T]) -> T:
+        """visit a struct"""
+
+    @abstractmethod
+    def list(self, list_type: pa.ListType, element_result: T) -> T:
+        """visit a list"""
+
+    @abstractmethod
+    def map(self, map_type: pa.MapType, key_result: T, value_result: T) -> T:
+        """visit a map"""
+
+    @abstractmethod
+    def primitive(self, primitive: pa.DataType) -> T:
+        """visit a primitive type"""
+
+
+def _get_field_id(field: pa.Field) -> int:
+    field_metadata = {k.decode(): v.decode() for k, v in 
field.metadata.items()}
+    if field_id := field_metadata.get("PARQUET:field_id"):
+        return field_id
+    raise ValueError(f"Field {field.name} does not have a field_id")
+
+
+class _ConvertToIceberg(ArrowSchemaVisitor[IcebergType], ABC):
+    def schema(self, schema: pa.Schema, field_results: List[IcebergType]) -> 
Schema:
+        fields = []
+        for i in range(len(schema.names)):
+            field = schema.field(i)
+            field_id = _get_field_id(field)
+            field_type = field_results[i]
+            if field_id is not None and field_type is not None:
+                if field.nullable:
+                    fields.append(NestedField(field_id, field.name, 
field_type, False))
+                else:
+                    fields.append(NestedField(field_id, field.name, 
field_type, True))
+        return Schema(*fields)
+
+    def struct(self, struct: pa.StructType, field_results: List[IcebergType]) 
-> IcebergType:
+        fields = []
+        for i in range(struct.num_fields):
+            field = struct[i]
+            field_id = _get_field_id(field)
+            # may need to check doc strings
+            field_type = field_results[i]
+            if field_id is not None and field_type is not None:
+                if field.nullable:
+                    fields.append(NestedField(field_id, field.name, 
field_type, False))
+                else:
+                    fields.append(NestedField(field_id, field.name, 
field_type, True))
+        return StructType(*fields)
+
+    def list(self, list_type: pa.ListType, element_result: IcebergType) -> 
IcebergType:
+        element_field = list_type.value_field
+        element_id = _get_field_id(element_field)
+        if element_id is not None and element_result is not None:
+            if element_field.nullable:
+                return ListType(element_id, element_result, False)
+            else:
+                return ListType(element_id, element_result, True)
+        raise ValueError("List type must have element field")

Review Comment:
   What do you think of including the input type in the exception message to help the user:
   ```suggestion
           raise ValueError(f"List type must have element field: {list_type}")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to