This is an automated email from the ASF dual-hosted git repository.

gangwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-cpp.git


The following commit(s) were added to refs/heads/main by this push:
     new f842411c feat: implement reference visitor (#491)
f842411c is described below

commit f842411c25a7c5b32488998822241841e64558b5
Author: Gang Wu <[email protected]>
AuthorDate: Tue Jan 13 13:46:52 2026 +0800

    feat: implement reference visitor (#491)
---
 src/iceberg/expression/binder.cc            |  63 ++++++++-
 src/iceberg/expression/binder.h             |  30 +++-
 src/iceberg/test/expression_visitor_test.cc | 204 ++++++++++++++++++++++++++++
 3 files changed, 291 insertions(+), 6 deletions(-)

diff --git a/src/iceberg/expression/binder.cc b/src/iceberg/expression/binder.cc
index 43c3ebcd..650dc730 100644
--- a/src/iceberg/expression/binder.cc
+++ b/src/iceberg/expression/binder.cc
@@ -19,6 +19,9 @@
 
 #include "iceberg/expression/binder.h"
 
+#include "iceberg/result.h"
+#include "iceberg/util/macros.h"
+
 namespace iceberg {
 
 Binder::Binder(const Schema& schema, bool case_sensitive)
@@ -54,30 +57,30 @@ Result<std::shared_ptr<Expression>> Binder::Or(
 
 Result<std::shared_ptr<Expression>> Binder::Predicate(
     const std::shared_ptr<UnboundPredicate>& pred) {
-  ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
+  ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
   return pred->Bind(schema_, case_sensitive_);
 }
 
 Result<std::shared_ptr<Expression>> Binder::Predicate(
     const std::shared_ptr<BoundPredicate>& pred) {
-  ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
+  ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
   return InvalidExpression("Found already bound predicate: {}", pred->ToString());
 }
 
 Result<std::shared_ptr<Expression>> Binder::Aggregate(
     const std::shared_ptr<BoundAggregate>& aggregate) {
-  ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
+  ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
   return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString());
 }
 
 Result<std::shared_ptr<Expression>> Binder::Aggregate(
     const std::shared_ptr<UnboundAggregate>& aggregate) {
-  ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
+  ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
   return aggregate->Bind(schema_, case_sensitive_);
 }
 
 Result<bool> IsBoundVisitor::IsBound(const std::shared_ptr<Expression>& expr) {
-  ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null");
+  ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
   IsBoundVisitor visitor;
   return Visit<bool, IsBoundVisitor>(expr, visitor);
 }
@@ -113,4 +116,54 @@ Result<bool> IsBoundVisitor::Aggregate(
   return false;
 }
 
+Result<std::unordered_set<int32_t>> ReferenceVisitor::GetReferencedFieldIds(
+    const std::shared_ptr<Expression>& expr) {
+  ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
+  ReferenceVisitor visitor;
+  return Visit<FieldIdsSetRef, ReferenceVisitor>(expr, visitor);
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::AlwaysTrue() { return referenced_field_ids_; }
+
+Result<FieldIdsSetRef> ReferenceVisitor::AlwaysFalse() { return referenced_field_ids_; }
+
+Result<FieldIdsSetRef> ReferenceVisitor::Not(
+    [[maybe_unused]] const FieldIdsSetRef& child_result) {
+  return referenced_field_ids_;
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::And(
+    [[maybe_unused]] const FieldIdsSetRef& left_result,
+    [[maybe_unused]] const FieldIdsSetRef& right_result) {
+  return referenced_field_ids_;
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::Or(
+    [[maybe_unused]] const FieldIdsSetRef& left_result,
+    [[maybe_unused]] const FieldIdsSetRef& right_result) {
+  return referenced_field_ids_;
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
+    const std::shared_ptr<BoundPredicate>& pred) {
+  referenced_field_ids_.insert(pred->reference()->field_id());
+  return referenced_field_ids_;
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
+    [[maybe_unused]] const std::shared_ptr<UnboundPredicate>& pred) {
+  return InvalidExpression("Cannot get referenced field IDs from unbound predicate");
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
+    const std::shared_ptr<BoundAggregate>& aggregate) {
+  referenced_field_ids_.insert(aggregate->reference()->field_id());
+  return referenced_field_ids_;
+}
+
+Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
+    [[maybe_unused]] const std::shared_ptr<UnboundAggregate>& aggregate) {
+  return InvalidExpression("Cannot get referenced field IDs from unbound aggregate");
+}
+
 }  // namespace iceberg
diff --git a/src/iceberg/expression/binder.h b/src/iceberg/expression/binder.h
index a78b7a4b..276ab076 100644
--- a/src/iceberg/expression/binder.h
+++ b/src/iceberg/expression/binder.h
@@ -22,6 +22,9 @@
 /// \file iceberg/expression/binder.h
 /// Bind an expression to a schema.
 
+#include <functional>
+#include <unordered_set>
+
 #include "iceberg/expression/expression_visitor.h"
 
 namespace iceberg {
@@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor<bool> {
   Result<bool> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) override;
 };
 
-// TODO(gangwu): add the Java parity `ReferenceVisitor`
+using FieldIdsSetRef = std::reference_wrapper<std::unordered_set<int32_t>>;
+
+/// \brief Visitor to collect referenced field IDs from an expression.
+class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor<FieldIdsSetRef> {
+ public:
+  static Result<std::unordered_set<int32_t>> GetReferencedFieldIds(
+      const std::shared_ptr<Expression>& expr);
+
+  Result<FieldIdsSetRef> AlwaysTrue() override;
+  Result<FieldIdsSetRef> AlwaysFalse() override;
+  Result<FieldIdsSetRef> Not(const FieldIdsSetRef& child_result) override;
+  Result<FieldIdsSetRef> And(const FieldIdsSetRef& left_result,
+                             const FieldIdsSetRef& right_result) override;
+  Result<FieldIdsSetRef> Or(const FieldIdsSetRef& left_result,
+                            const FieldIdsSetRef& right_result) override;
+  Result<FieldIdsSetRef> Predicate(const std::shared_ptr<BoundPredicate>& pred) override;
+  Result<FieldIdsSetRef> Predicate(
+      const std::shared_ptr<UnboundPredicate>& pred) override;
+  Result<FieldIdsSetRef> Aggregate(
+      const std::shared_ptr<BoundAggregate>& aggregate) override;
+  Result<FieldIdsSetRef> Aggregate(
+      const std::shared_ptr<UnboundAggregate>& aggregate) override;
+
+ private:
+  std::unordered_set<int32_t> referenced_field_ids_;
+};
 
 }  // namespace iceberg
diff --git a/src/iceberg/test/expression_visitor_test.cc b/src/iceberg/test/expression_visitor_test.cc
index f2bbe70e..697c0096 100644
--- a/src/iceberg/test/expression_visitor_test.cc
+++ b/src/iceberg/test/expression_visitor_test.cc
@@ -22,6 +22,7 @@
 #include "iceberg/expression/binder.h"
 #include "iceberg/expression/expressions.h"
 #include "iceberg/expression/rewrite_not.h"
+#include "iceberg/result.h"
 #include "iceberg/schema.h"
 #include "iceberg/test/matchers.h"
 #include "iceberg/type.h"
@@ -505,4 +506,207 @@ TEST_F(RewriteNotTest, ComplexExpression) {
   EXPECT_EQ(rewritten->op(), Expression::Operation::kOr);
 }
 
+class ReferenceVisitorTest : public ExpressionVisitorTest {};
+
+TEST_F(ReferenceVisitorTest, Constants) {
+  // Constants should have no referenced fields
+  auto true_expr = Expressions::AlwaysTrue();
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_true,
+                         ReferenceVisitor::GetReferencedFieldIds(true_expr));
+  EXPECT_TRUE(refs_true.empty());
+
+  auto false_expr = Expressions::AlwaysFalse();
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_false,
+                         ReferenceVisitor::GetReferencedFieldIds(false_expr));
+  EXPECT_TRUE(refs_false.empty());
+}
+
+TEST_F(ReferenceVisitorTest, UnboundPredicate) {
+  auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
+  auto result = ReferenceVisitor::GetReferencedFieldIds(unbound_pred);
+  EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression));
+  EXPECT_THAT(result,
+              HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
+}
+
+TEST_F(ReferenceVisitorTest, BoundPredicate) {
+  // Bound predicate should return the field ID
+  auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_pred, Bind(unbound_pred));
+
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_pred));
+  EXPECT_EQ(refs.size(), 1);
+  EXPECT_EQ(refs.count(2), 1);  // name field has id=2
+}
+
+TEST_F(ReferenceVisitorTest, MultiplePredicates) {
+  // Test various predicates with different fields
+  auto pred_age = Expressions::GreaterThan("age", Literal::Int(25));
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_age, Bind(pred_age));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_age,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_age));
+  EXPECT_EQ(refs_age.size(), 1);
+  EXPECT_EQ(refs_age.count(3), 1);  // age field has id=3
+
+  auto pred_salary = Expressions::LessThan("salary", Literal::Double(50000.0));
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_salary, Bind(pred_salary));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_salary,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_salary));
+  EXPECT_EQ(refs_salary.size(), 1);
+  EXPECT_EQ(refs_salary.count(4), 1);  // salary field has id=4
+}
+
+TEST_F(ReferenceVisitorTest, UnaryPredicates) {
+  // Test unary predicates
+  auto pred_is_null = Expressions::IsNull("name");
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_is_null, Bind(pred_is_null));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_is_null));
+  EXPECT_EQ(refs.size(), 1);
+  EXPECT_EQ(refs.count(2), 1);
+
+  auto pred_is_nan = Expressions::IsNaN("salary");
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_is_nan, Bind(pred_is_nan));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_nan,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_is_nan));
+  EXPECT_EQ(refs_nan.size(), 1);
+  EXPECT_EQ(refs_nan.count(4), 1);
+}
+
+TEST_F(ReferenceVisitorTest, AndExpression) {
+  // AND expression should return union of field IDs from both sides
+  auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
+  auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
+  auto and_expr = Expressions::And(pred1, pred2);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
+
+  EXPECT_EQ(refs.size(), 2);
+  EXPECT_EQ(refs.count(2), 1);  // name field
+  EXPECT_EQ(refs.count(3), 1);  // age field
+}
+
+TEST_F(ReferenceVisitorTest, OrExpression) {
+  // OR expression should return union of field IDs from both sides
+  auto pred1 = Expressions::IsNull("salary");
+  auto pred2 = Expressions::Equal("active", Literal::Boolean(true));
+  auto or_expr = Expressions::Or(pred1, pred2);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_or, Bind(or_expr));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_or));
+
+  EXPECT_EQ(refs.size(), 2);
+  EXPECT_EQ(refs.count(4), 1);  // salary field
+  EXPECT_EQ(refs.count(5), 1);  // active field
+}
+
+TEST_F(ReferenceVisitorTest, NotExpression) {
+  // NOT expression should return field IDs from its child
+  auto pred = Expressions::Equal("name", Literal::String("Alice"));
+  auto not_expr = Expressions::Not(pred);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_not, Bind(not_expr));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_not));
+
+  EXPECT_EQ(refs.size(), 1);
+  EXPECT_EQ(refs.count(2), 1);  // name field
+}
+
+TEST_F(ReferenceVisitorTest, ComplexNestedExpression) {
+  // (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true)
+  // Should reference fields: name(2), age(3), salary(4), active(5)
+  auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
+  auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
+  auto pred3 = Expressions::LessThan("salary", Literal::Double(30000.0));
+  auto pred4 = Expressions::Equal("active", Literal::Boolean(true));
+
+  auto and1 = Expressions::And(pred1, pred2);
+  auto and2 = Expressions::And(pred3, pred4);
+  auto complex_or = Expressions::Or(and1, and2);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_complex, Bind(complex_or));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_complex));
+
+  EXPECT_EQ(refs.size(), 4);
+  EXPECT_EQ(refs.count(2), 1);  // name field
+  EXPECT_EQ(refs.count(3), 1);  // age field
+  EXPECT_EQ(refs.count(4), 1);  // salary field
+  EXPECT_EQ(refs.count(5), 1);  // active field
+}
+
+TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) {
+  // Multiple predicates referencing the same field
+  // age > 25 AND age < 50
+  auto pred1 = Expressions::GreaterThan("age", Literal::Int(25));
+  auto pred2 = Expressions::LessThan("age", Literal::Int(50));
+  auto and_expr = Expressions::And(pred1, pred2);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
+
+  // Should only contain the field ID once (set semantics)
+  EXPECT_EQ(refs.size(), 1);
+  EXPECT_EQ(refs.count(3), 1);  // age field
+}
+
+TEST_F(ReferenceVisitorTest, SetPredicates) {
+  // Test In predicate
+  auto pred_in =
+      Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)});
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_in, Bind(pred_in));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_in, ReferenceVisitor::GetReferencedFieldIds(bound_in));
+
+  EXPECT_EQ(refs_in.size(), 1);
+  EXPECT_EQ(refs_in.count(3), 1);  // age field
+
+  // Test NotIn predicate
+  auto pred_not_in =
+      Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")});
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, Bind(pred_not_in));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs_not_in,
+                         ReferenceVisitor::GetReferencedFieldIds(bound_not_in));
+
+  EXPECT_EQ(refs_not_in.size(), 1);
+  EXPECT_EQ(refs_not_in.count(2), 1);  // name field
+}
+
+TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) {
+  auto bound_pred = Expressions::Equal("name", Literal::String("Alice"));
+  ICEBERG_UNWRAP_OR_FAIL(auto pred1, Bind(bound_pred));
+  auto unbound_pred = Expressions::GreaterThan("age", Literal::Int(25));
+  auto mixed_and = Expressions::And(pred1, unbound_pred);
+
+  auto result = ReferenceVisitor::GetReferencedFieldIds(mixed_and);
+  EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression));
+  EXPECT_THAT(result,
+              HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
+}
+
+TEST_F(ReferenceVisitorTest, AllFields) {
+  // Create expression referencing all fields in the schema
+  auto pred1 = Expressions::NotNull("id");
+  auto pred2 = Expressions::Equal("name", Literal::String("Test"));
+  auto pred3 = Expressions::GreaterThan("age", Literal::Int(0));
+  auto pred4 = Expressions::LessThan("salary", Literal::Double(100000.0));
+  auto pred5 = Expressions::Equal("active", Literal::Boolean(true));
+
+  auto and1 = Expressions::And(pred1, pred2);
+  auto and2 = Expressions::And(pred3, pred4);
+  auto and3 = Expressions::And(and1, and2);
+  auto all_fields = Expressions::And(and3, pred5);
+
+  ICEBERG_UNWRAP_OR_FAIL(auto bound_all, Bind(all_fields));
+  ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_all));
+
+  // Should reference 4 of the 5 fields (the id predicate is optimized out)
+  EXPECT_EQ(refs.size(), 4);
+  EXPECT_EQ(refs.count(1), 0);  // id field is optimized out
+  EXPECT_EQ(refs.count(2), 1);  // name field
+  EXPECT_EQ(refs.count(3), 1);  // age field
+  EXPECT_EQ(refs.count(4), 1);  // salary field
+  EXPECT_EQ(refs.count(5), 1);  // active field
+}
+
 }  // namespace iceberg

Reply via email to