Fokko commented on code in PR #6525:
URL: https://github.com/apache/iceberg/pull/6525#discussion_r1064448040
##########
python/pyiceberg/avro/resolver.py:
##########
@@ -109,38 +109,46 @@ def resolve(
class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
- read_types: Optional[Dict[int, Callable[[Schema], StructProtocol]]]
+ read_types: Dict[int, Type[StructProtocol]]
+ context: List[int]
- def __init__(self, read_types: Optional[Dict[int, Callable[[Schema],
StructProtocol]]]):
+ def __init__(self, read_types: Dict[int, Type[StructProtocol]] =
EMPTY_DICT) -> None:
self.read_types = read_types
+ self.context = []
def schema(self, schema: Schema, expected_schema: Optional[IcebergType],
result: Reader) -> Reader:
return result
+ def before_field(self, field: NestedField, field_partner:
Optional[NestedField]) -> None:
+ self.context.append(field.field_id)
+
+ def after_field(self, field: NestedField, field_partner:
Optional[NestedField]) -> None:
+ self.context.pop()
+
def struct(self, struct: StructType, expected_struct:
Optional[IcebergType], field_readers: List[Reader]) -> Reader:
+ # -1 indicates the struct root
+ read_struct_id = self.context[-1] if len(self.context) > 0 else -1
+ struct_callable = self.read_types.get(read_struct_id, Record)
+
if not expected_struct:
- return StructReader(tuple(enumerate(field_readers)))
+ return StructReader(tuple(enumerate(field_readers)),
struct_callable)
if not isinstance(expected_struct, StructType):
raise ResolveError(f"File/read schema are not aligned for struct,
got {expected_struct}")
- results: List[Tuple[Optional[int], Reader]] = []
expected_positions: Dict[int, int] = {field.field_id: pos for pos,
field in enumerate(expected_struct.fields)}
# first, add readers for the file fields that must be in order
- for field, result_reader in zip(struct.fields, field_readers):
- read_pos = expected_positions.get(field.field_id)
- results.append((read_pos, result_reader))
+ results: List[Tuple[Optional[int], Reader]] = [
+ (expected_positions.get(field.field_id), result_reader) for field,
result_reader in zip(struct.fields, field_readers)
+ ]
file_fields = {field.field_id: field for field in struct.fields}
- for pos, read_field in enumerate(expected_struct.fields):
- if read_field.field_id not in file_fields:
- if read_field.required:
- raise ResolveError(f"{read_field} is non-optional, and not
part of the file schema")
- # Just set the new field to None
- results.append((pos, NoneReader()))
Review Comment:
Also added a test
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]