viirya commented on code in PR #295:
URL: https://github.com/apache/iceberg-rust/pull/295#discussion_r1555120078


##########
crates/iceberg/src/arrow.rs:
##########
@@ -113,6 +143,405 @@ impl ArrowReader {
         // TODO: full implementation
         ProjectionMask::all()
     }
+
+    fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> 
Result<Option<RowFilter>> {
+        if let Some(predicates) = &self.predicates {
+            let field_id_map = self.build_field_id_map(parquet_schema)?;
+
+            // Collect Parquet column indices from field ids
+            let column_indices = predicates
+                .iter()
+                .map(|predicate| {
+                    let mut collector = CollectFieldIdVisitor { field_ids: 
vec![] };
+                    collector.visit_predicate(predicate).unwrap();
+                    collector
+                        .field_ids
+                        .iter()
+                        .map(|field_id| {
+                            field_id_map.get(field_id).cloned().ok_or_else(|| {
+                                Error::new(ErrorKind::DataInvalid, "Field id 
not found in schema")
+                            })
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            // Convert BoundPredicates to ArrowPredicates
+            let mut arrow_predicates = vec![];
+            for (predicate, columns) in 
predicates.iter().zip(column_indices.iter()) {
+                let mut converter = PredicateConverter {
+                    columns,
+                    projection_mask: ProjectionMask::leaves(parquet_schema, 
columns.clone()),
+                    parquet_schema,
+                    column_map: &field_id_map,
+                };
+                let arrow_predicate = converter.visit_predicate(predicate)?;
+                arrow_predicates.push(arrow_predicate);
+            }
+            Ok(Some(RowFilter::new(arrow_predicates)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Build the map of field id to Parquet column index in the schema.
+    fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> 
Result<HashMap<i32, usize>> {
+        let mut column_map = HashMap::new();
+        for (idx, field) in parquet_schema.columns().iter().enumerate() {
+            let field_type = field.self_type();
+            match field_type {
+                ParquetType::PrimitiveType { basic_info, .. } => {
+                    if !basic_info.has_id() {
+                        return Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            format!(
+                                "Leave column {:?} in schema doesn't have 
field id",
+                                field_type
+                            ),
+                        ));
+                    }
+                    column_map.insert(basic_info.id(), idx);
+                }
+                ParquetType::GroupType { .. } => {
+                    return Err(Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Leave column in schema should be primitive type 
but got {:?}",
+                            field_type
+                        ),
+                    ));
+                }
+            };
+        }
+
+        Ok(column_map)
+    }
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+    field_ids: Vec<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+    type T = ();
+    type U = ();
+
+    fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn not(&mut self, _predicate: Self::T) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_always_true(&mut self) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_always_false(&mut self) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> 
Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) 
-> Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> 
Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn bound_reference(&mut self, reference: &BoundReference) -> 
Result<Self::T> {
+        self.field_ids.push(reference.field().id);
+        Ok(())
+    }
+}
+
+struct PredicateConverter<'a> {
+    pub columns: &'a Vec<usize>,
+    pub projection_mask: ProjectionMask,
+    pub parquet_schema: &'a SchemaDescriptor,
+    pub column_map: &'a HashMap<i32, usize>,
+}
+
+fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> {
+    match datum.literal() {
+        PrimitiveLiteral::Boolean(value) => 
Box::new(BooleanArray::new_scalar(*value)),
+        PrimitiveLiteral::Int(value) => 
Box::new(Int32Array::new_scalar(*value)),
+        PrimitiveLiteral::Long(value) => 
Box::new(Int64Array::new_scalar(*value)),
+        PrimitiveLiteral::Float(value) => 
Box::new(Float32Array::new_scalar(value.as_f32())),
+        PrimitiveLiteral::Double(value) => 
Box::new(Float64Array::new_scalar(value.as_f64())),
+        _ => todo!("Unsupported literal type"),
+    }
+}
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+    type T = Box<dyn ArrowPredicate>;
+    type U = usize;
+
+    fn visit_always_true(&mut self) -> Result<Self::T> {
+        Ok(Box::new(ArrowPredicateFn::new(
+            self.projection_mask.clone(),
+            |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])),
+        )))
+    }
+
+    fn visit_always_false(&mut self) -> Result<Self::T> {
+        Ok(Box::new(ArrowPredicateFn::new(
+            self.projection_mask.clone(),
+            |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])),
+        )))
+    }
+
+    fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> 
Result<Self::T> {
+        let term_index = self.bound_reference(predicate.term())?;
+
+        match predicate.op() {
+            PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new(
+                self.projection_mask.clone(),
+                move |batch| {
+                    let column = batch.column(term_index);
+                    is_null(column)
+                },
+            ))),
+            PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new(
+                self.projection_mask.clone(),
+                move |batch| {
+                    let column = batch.column(term_index);

Review Comment:
   > This may be incorrect for nested columns. I think we should either 
return a projection_mask for each leaf column, or implement a general-purpose 
flatten method for struct arrays.
   
   I tried changing it to return a projection_mask for each leaf column; it is 
pretty straightforward to implement. Please let me know if it looks good to 
you. Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org
For additional commands, e-mail: issues-h...@iceberg.apache.org

Reply via email to