This is an automated email from the ASF dual-hosted git repository.
gangwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-cpp.git
The following commit(s) were added to refs/heads/main by this push:
new 428a1714 feat: add DataFile aggregate evaluation (#400)
428a1714 is described below
commit 428a17146d77d5db43b47d55cc4c7013a2c2e3e7
Author: Zhiyuan Liang <[email protected]>
AuthorDate: Sun Dec 7 07:12:33 2025 -0800
feat: add DataFile aggregate evaluation (#400)
---
src/iceberg/expression/aggregate.cc | 262 ++++++++++++++++++++++++++++++++--
src/iceberg/expression/aggregate.h | 39 ++++-
src/iceberg/expression/expressions.cc | 14 ++
src/iceberg/expression/expressions.h | 6 +
src/iceberg/expression/term.h | 11 +-
src/iceberg/row/struct_like.cc | 40 ++++++
src/iceberg/row/struct_like.h | 3 +
src/iceberg/test/aggregate_test.cc | 240 +++++++++++++++++++++++++++++++
8 files changed, 597 insertions(+), 18 deletions(-)
diff --git a/src/iceberg/expression/aggregate.cc
b/src/iceberg/expression/aggregate.cc
index a9c1a60b..12bb3f03 100644
--- a/src/iceberg/expression/aggregate.cc
+++ b/src/iceberg/expression/aggregate.cc
@@ -19,11 +19,16 @@
#include "iceberg/expression/aggregate.h"
+#include <algorithm>
#include <format>
+#include <map>
+#include <memory>
#include <optional>
+#include <string_view>
#include <vector>
#include "iceberg/expression/literal.h"
+#include "iceberg/manifest/manifest_entry.h"
#include "iceberg/row/struct_like.h"
#include "iceberg/type.h"
#include "iceberg/util/checked_cast.h"
@@ -38,6 +43,32 @@ std::shared_ptr<PrimitiveType> GetPrimitiveType(const
BoundTerm& term) {
return internal::checked_pointer_cast<PrimitiveType>(term.type());
}
+/// \brief StructLike adapter that exposes exactly one field holding a Literal.
+///
+/// GetField ignores the requested position and always yields the wrapped
+/// literal converted through LiteralToScalar. String/binary scalars produced
+/// that way view the literal's storage, so they are only usable while this
+/// object is alive.
+class SingleValueStructLike : public StructLike {
+ public:
+  explicit SingleValueStructLike(Literal value) : literal_(std::move(value)) {}
+
+  size_t num_fields() const override { return 1; }
+
+  Result<Scalar> GetField(size_t /*pos*/) const override {
+    return LiteralToScalar(literal_);
+  }
+
+ private:
+  Literal literal_;
+};
+
+/// \brief Evaluate a MIN/MAX term against one column bound from file metrics.
+///
+/// \param term Bound term (reference or transform) of the aggregate. It is taken
+///        by non-const reference so its bound reference can be queried without
+///        assuming BoundTerm::reference() is const.
+/// \param bound Serialized lower/upper bound of the referenced column, or
+///        std::nullopt when the file carries no bound for it.
+/// \return The term applied to the decoded bound, or to a null value when
+///         \p bound is absent.
+Result<Literal> EvaluateBoundTerm(BoundTerm& term,
+                                  const std::optional<std::vector<uint8_t>>& bound) {
+  // Column bounds are serialized with the *source* column type. The term's own
+  // type is the transform result type (e.g. day(timestamp) -> date,
+  // bucket(string) -> int), so decoding with it would misread the bytes. Always
+  // decode with the referenced column's type and let the term apply the transform.
+  auto ref = term.reference();
+  if (ref == nullptr) {
+    return InvalidArgument("Aggregate term has no bound reference");
+  }
+  auto source_type = ref->type();
+  if (source_type == nullptr || !source_type->is_primitive()) {
+    return InvalidArgument("Aggregate source column must be primitive");
+  }
+  auto ptype = internal::checked_pointer_cast<PrimitiveType>(source_type);
+
+  if (!bound.has_value()) {
+    SingleValueStructLike data(Literal::Null(ptype));
+    return term.Evaluate(data);
+  }
+
+  ICEBERG_ASSIGN_OR_RAISE(auto literal, Literal::Deserialize(*bound, ptype));
+  SingleValueStructLike data(std::move(literal));
+  return term.Evaluate(data);
+}
+
class CountAggregator : public BoundAggregate::Aggregator {
public:
explicit CountAggregator(const CountAggregate& aggregate) :
aggregate_(aggregate) {}
@@ -48,11 +79,32 @@ class CountAggregator : public BoundAggregate::Aggregator {
return {};
}
- Literal GetResult() const override { return Literal::Long(count_); }
+  /// \brief Fold one data file's metrics into the running count.
+  ///
+  /// The first file without the required metrics permanently invalidates the
+  /// aggregator (GetResult then yields null); later files are ignored.
+  Status Update(const DataFile& file) override {
+    if (valid_ && !aggregate_.HasValue(file)) {
+      valid_ = false;
+    }
+    if (!valid_) {
+      return {};
+    }
+    ICEBERG_ASSIGN_OR_RAISE(auto file_count, aggregate_.CountFor(file));
+    count_ += file_count;
+    return {};
+  }
+
+  /// \brief Accumulated count, or a null int64 literal once invalidated.
+  Literal GetResult() const override {
+    if (valid_) {
+      return Literal::Long(count_);
+    }
+    return Literal::Null(int64());
+  }
+
+  /// \brief False once any file lacked the metrics needed for counting.
+  bool IsValid() const override { return valid_; }
private:
const CountAggregate& aggregate_;
int64_t count_ = 0;
+ bool valid_ = true;
};
class MaxAggregator : public BoundAggregate::Aggregator {
@@ -73,6 +125,7 @@ class MaxAggregator : public BoundAggregate::Aggregator {
if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
+ valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::greater) {
@@ -82,11 +135,48 @@ class MaxAggregator : public BoundAggregate::Aggregator {
return {};
}
- Literal GetResult() const override { return current_; }
+  /// \brief Fold one data file's upper-bound-derived value into the maximum.
+  ///
+  /// A file lacking usable metrics invalidates the aggregator for good. An
+  /// incomparable value also invalidates it and is reported as InvalidArgument.
+  Status Update(const DataFile& file) override {
+    if (valid_ && !aggregate_.HasValue(file)) {
+      valid_ = false;
+    }
+    if (!valid_) {
+      return {};
+    }
+
+    ICEBERG_ASSIGN_OR_RAISE(auto candidate, aggregate_.Evaluate(file));
+    // A null candidate (e.g. an all-null column with no bound) leaves the
+    // current maximum unchanged; the first non-null candidate seeds it.
+    if (candidate.IsNull()) {
+      return {};
+    }
+    if (current_.IsNull()) {
+      current_ = std::move(candidate);
+      return {};
+    }
+
+    const auto cmp = candidate <=> current_;
+    if (cmp == std::partial_ordering::unordered) {
+      valid_ = false;
+      return InvalidArgument("Cannot compare literal {} with current value {}",
+                             candidate.ToString(), current_.ToString());
+    }
+    if (cmp == std::partial_ordering::greater) {
+      current_ = std::move(candidate);
+    }
+    return {};
+  }
+
+  /// \brief Current maximum, or a null of the term's type once invalidated.
+  Literal GetResult() const override {
+    if (valid_) {
+      return current_;
+    }
+    return Literal::Null(GetPrimitiveType(*aggregate_.term()));
+  }
+
+  /// \brief False once a file lacked metrics or produced an incomparable value.
+  bool IsValid() const override { return valid_; }
private:
const MaxAggregate& aggregate_;
Literal current_;
+ bool valid_ = true;
};
class MinAggregator : public BoundAggregate::Aggregator {
@@ -107,6 +197,7 @@ class MinAggregator : public BoundAggregate::Aggregator {
if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
+ valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::less) {
@@ -115,13 +206,66 @@ class MinAggregator : public BoundAggregate::Aggregator {
return {};
}
- Literal GetResult() const override { return current_; }
+  /// \brief Fold one data file's lower-bound-derived value into the minimum.
+  ///
+  /// A file lacking usable metrics invalidates the aggregator for good. An
+  /// incomparable value also invalidates it and is reported as InvalidArgument.
+  Status Update(const DataFile& file) override {
+    if (valid_ && !aggregate_.HasValue(file)) {
+      valid_ = false;
+    }
+    if (!valid_) {
+      return {};
+    }
+
+    ICEBERG_ASSIGN_OR_RAISE(auto candidate, aggregate_.Evaluate(file));
+    // A null candidate (e.g. an all-null column with no bound) leaves the
+    // current minimum unchanged; the first non-null candidate seeds it.
+    if (candidate.IsNull()) {
+      return {};
+    }
+    if (current_.IsNull()) {
+      current_ = std::move(candidate);
+      return {};
+    }
+
+    const auto cmp = candidate <=> current_;
+    if (cmp == std::partial_ordering::unordered) {
+      valid_ = false;
+      return InvalidArgument("Cannot compare literal {} with current value {}",
+                             candidate.ToString(), current_.ToString());
+    }
+    if (cmp == std::partial_ordering::less) {
+      current_ = std::move(candidate);
+    }
+    return {};
+  }
+
+  /// \brief Current minimum, or a null of the term's type once invalidated.
+  Literal GetResult() const override {
+    if (valid_) {
+      return current_;
+    }
+    return Literal::Null(GetPrimitiveType(*aggregate_.term()));
+  }
+
+  /// \brief False once a file lacked metrics or produced an incomparable value.
+  bool IsValid() const override { return valid_; }
private:
const MinAggregate& aggregate_;
Literal current_;
+ bool valid_ = true;
};
+/// \brief Look up \p key in \p map.
+/// \return A copy of the mapped value, or std::nullopt when \p key is absent.
+template <typename T>
+std::optional<T> GetMapValue(const std::map<int32_t, T>& map, int32_t key) {
+  if (auto found = map.find(key); found != map.end()) {
+    return found->second;
+  }
+  return std::nullopt;
+}
+
+/// \brief Field id of the column an aggregate term refers to.
+///
+/// Null term/reference are only guarded by ICEBERG_DCHECK (presumably a
+/// debug-only check — confirm), so callers must pass a term with a reference.
+/// COUNT(*) is constructed with a null term and must not call this.
+int32_t GetFieldId(const std::shared_ptr<BoundTerm>& term) {
+  ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null");
+  const auto& bound_ref = term->reference();
+  ICEBERG_DCHECK(bound_ref != nullptr, "Aggregate term reference should not be null");
+  return bound_ref->field().field_id();
+}
+
} // namespace
template <TermType T>
@@ -149,7 +293,11 @@ std::string Aggregate<T>::ToString() const {
// -------------------- CountAggregate --------------------
Result<Literal> CountAggregate::Evaluate(const StructLike& data) const {
- return CountFor(data).transform([](int64_t count) { return Literal::Long(count); });
+ return CountFor(data).transform(Literal::Long);  // wrap the per-row count from CountFor as a long literal
+}
+
+Result<Literal> CountAggregate::Evaluate(const DataFile& file) const {
+  // Metrics-based count for the whole file, surfaced as a long literal.
+  ICEBERG_ASSIGN_OR_RAISE(auto file_count, CountFor(file));
+  return Literal::Long(file_count);
}
std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator()
const {
@@ -173,6 +321,22 @@ Result<int64_t> CountNonNullAggregate::CountFor(const
StructLike& data) const {
[](const auto& val) { return val.IsNull() ? 0 : 1; });
}
+Result<int64_t> CountNonNullAggregate::CountFor(const DataFile& file) const {
+  const int32_t id = GetFieldId(term());
+  if (!HasValue(file)) {
+    return NotFound("Missing metrics for field id {}", id);
+  }
+  // HasValue guarantees both entries exist; non-null = total values - nulls.
+  const auto total = file.value_counts.at(id);
+  const auto nulls = file.null_value_counts.at(id);
+  return total - nulls;
+}
+
+bool CountNonNullAggregate::HasValue(const DataFile& file) const {
+  // Deriving the non-null count needs both the value count and the null count.
+  const int32_t id = GetFieldId(term());
+  const bool has_values = file.value_counts.contains(id);
+  return has_values && file.null_value_counts.contains(id);
+}
+
CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term)
: CountAggregate(Expression::Operation::kCountNull, std::move(term)) {}
@@ -189,6 +353,18 @@ Result<int64_t> CountNullAggregate::CountFor(const
StructLike& data) const {
[](const auto& val) { return val.IsNull() ? 1 : 0; });
}
+Result<int64_t> CountNullAggregate::CountFor(const DataFile& file) const {
+  const int32_t id = GetFieldId(term());
+  if (!HasValue(file)) {
+    return NotFound("Missing metrics for field id {}", id);
+  }
+  // HasValue guarantees the null-count entry exists.
+  return file.null_value_counts.at(id);
+}
+
+bool CountNullAggregate::HasValue(const DataFile& file) const {
+  // Only the column's null count is required.
+  const int32_t id = GetFieldId(term());
+  return file.null_value_counts.contains(id);
+}
+
CountStarAggregate::CountStarAggregate()
: CountAggregate(Expression::Operation::kCountStar, nullptr) {}
@@ -200,36 +376,93 @@ Result<int64_t> CountStarAggregate::CountFor(const
StructLike& /*data*/) const {
return 1;
}
+Result<int64_t> CountStarAggregate::CountFor(const DataFile& file) const {
+  // COUNT(*) over a file is its record count, when that count is known.
+  if (HasValue(file)) {
+    return file.record_count;
+  }
+  return NotFound("Record count is missing");
+}
+
+bool CountStarAggregate::HasValue(const DataFile& file) const {
+  // A negative record_count is treated as "not recorded".
+  return 0 <= file.record_count;
+}
+
MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term)
: BoundAggregate(Expression::Operation::kMax, std::move(term)) {}
-std::shared_ptr<MaxAggregate> MaxAggregate::Make(std::shared_ptr<BoundTerm>
term) {
- return std::shared_ptr<MaxAggregate>(new MaxAggregate(std::move(term)));
+/// \brief Create a bound MAX aggregate.
+/// \return InvalidExpression if \p term is null or not of primitive type.
+Result<std::unique_ptr<MaxAggregate>> MaxAggregate::Make(
+    std::shared_ptr<BoundTerm> term) {
+  if (term == nullptr) {
+    return InvalidExpression("Bound max aggregate requires non-null term");
+  }
+  if (const auto& term_type = term->type(); !term_type->is_primitive()) {
+    return InvalidExpression("Max aggregate term should be primitive");
+  }
+  // The constructor is private, so std::make_unique cannot be used here.
+  return std::unique_ptr<MaxAggregate>(new MaxAggregate(std::move(term)));
}
Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
return term()->Evaluate(data);
}
+Result<Literal> MaxAggregate::Evaluate(const DataFile& file) const {
+  // Apply the term to the column's recorded upper bound (null when absent).
+  const std::optional<std::vector<uint8_t>> upper_bound =
+      GetMapValue(file.upper_bounds, GetFieldId(term()));
+  return EvaluateBoundTerm(*term(), upper_bound);
+}
+
std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator()
const {
return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this));
}
+bool MaxAggregate::HasValue(const DataFile& file) const {
+  const int32_t id = GetFieldId(term());
+  if (file.upper_bounds.contains(id)) {
+    return true;
+  }
+  // Without a bound the answer is still known (null) when the file has at
+  // least one value and every value is null.
+  const auto values = GetMapValue(file.value_counts, id);
+  const auto nulls = GetMapValue(file.null_value_counts, id);
+  return values.has_value() && nulls.has_value() && *values > 0 && *nulls == *values;
+}
+
MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term)
: BoundAggregate(Expression::Operation::kMin, std::move(term)) {}
-std::shared_ptr<MinAggregate> MinAggregate::Make(std::shared_ptr<BoundTerm>
term) {
- return std::shared_ptr<MinAggregate>(new MinAggregate(std::move(term)));
+/// \brief Create a bound MIN aggregate.
+/// \return InvalidExpression if \p term is null or not of primitive type.
+Result<std::unique_ptr<MinAggregate>> MinAggregate::Make(
+    std::shared_ptr<BoundTerm> term) {
+  if (!term) {
+    return InvalidExpression("Bound min aggregate requires non-null term");
+  }
+  if (!term->type()->is_primitive()) {
+    // Previously a copy-paste from MaxAggregate::Make reported "Max" here.
+    return InvalidExpression("Min aggregate term should be primitive");
+  }
+  // The constructor is private, so std::make_unique cannot be used here.
+  return std::unique_ptr<MinAggregate>(new MinAggregate(std::move(term)));
}
Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
return term()->Evaluate(data);
}
+Result<Literal> MinAggregate::Evaluate(const DataFile& file) const {
+  // Apply the term to the column's recorded lower bound (null when absent).
+  const std::optional<std::vector<uint8_t>> lower_bound =
+      GetMapValue(file.lower_bounds, GetFieldId(term()));
+  return EvaluateBoundTerm(*term(), lower_bound);
+}
+
std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator()
const {
return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this));
}
+bool MinAggregate::HasValue(const DataFile& file) const {
+  const int32_t id = GetFieldId(term());
+  if (file.lower_bounds.contains(id)) {
+    return true;
+  }
+  // Without a bound the answer is still known (null) when the file has at
+  // least one value and every value is null.
+  const auto values = GetMapValue(file.value_counts, id);
+  const auto nulls = GetMapValue(file.null_value_counts, id);
+  return values.has_value() && nulls.has_value() && *values > 0 && *nulls == *values;
+}
+
// -------------------- Unbound binding --------------------
template <typename B>
@@ -275,8 +508,10 @@ Result<std::shared_ptr<UnboundAggregateImpl<B>>>
UnboundAggregateImpl<B>::Make(
}
template class Aggregate<UnboundTerm<BoundReference>>;
+template class Aggregate<UnboundTerm<BoundTransform>>;
template class Aggregate<BoundTerm>;
template class UnboundAggregateImpl<BoundReference>;
+template class UnboundAggregateImpl<BoundTransform>;
// -------------------- AggregateEvaluator --------------------
@@ -296,6 +531,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
return {};
}
+  /// \brief Feed one data file's metrics to every aggregator, in order.
+  /// Stops and returns the first error reported by an aggregator.
+  Status Update(const DataFile& file) override {
+    for (const auto& agg : aggregators_) {
+      ICEBERG_RETURN_UNEXPECTED(agg->Update(file));
+    }
+    return {};
+  }
+
Result<std::span<const Literal>> GetResults() const override {
results_.clear();
results_.reserve(aggregates_.size());
@@ -315,6 +557,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
return all.front();
}
+  /// \brief True when no aggregator has been invalidated by missing metrics.
+  bool AllAggregatorsValid() const override {
+    return std::ranges::none_of(aggregators_,
+                                [](const auto& agg) { return !agg->IsValid(); });
+  }
+
private:
std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;
diff --git a/src/iceberg/expression/aggregate.h
b/src/iceberg/expression/aggregate.h
index cde9e458..6cf659d6 100644
--- a/src/iceberg/expression/aggregate.h
+++ b/src/iceberg/expression/aggregate.h
@@ -109,14 +109,15 @@ class ICEBERG_EXPORT BoundAggregate : public
Aggregate<BoundTerm>, public Bound
virtual Status Update(const StructLike& data) = 0;
- virtual Status Update(const DataFile& file) {
- return NotImplemented("Update(DataFile) not implemented");
- }
+ virtual Status Update(const DataFile& file) = 0;
+
+ /// \brief Whether the aggregator is still valid.
+ virtual bool IsValid() const = 0;
/// \brief Get the result of the aggregation.
/// \return The result of the aggregation.
/// \note It is an undefined behavior to call this method if any previous
Update call
- /// has returned an error.
+ /// has returned an error or if IsValid() returns false.
virtual Literal GetResult() const = 0;
};
@@ -128,6 +129,11 @@ class ICEBERG_EXPORT BoundAggregate : public
Aggregate<BoundTerm>, public Bound
Result<Literal> Evaluate(const StructLike& data) const override = 0;
+ virtual Result<Literal> Evaluate(const DataFile& file) const = 0;
+
+ /// \brief Whether metrics in the data file are sufficient to evaluate.
+ virtual bool HasValue(const DataFile& file) const = 0;
+
bool is_bound_aggregate() const override { return true; }
/// \brief Create a new aggregator for this aggregate.
@@ -142,12 +148,15 @@ class ICEBERG_EXPORT BoundAggregate : public
Aggregate<BoundTerm>, public Bound
/// \brief Base class for COUNT aggregates.
class ICEBERG_EXPORT CountAggregate : public BoundAggregate {
public:
- Result<Literal> Evaluate(const StructLike& data) const final;
+ Result<Literal> Evaluate(const StructLike& data) const override;
+ Result<Literal> Evaluate(const DataFile& file) const override;
std::unique_ptr<Aggregator> NewAggregator() const override;
/// \brief Count for a single row. Subclasses implement this.
virtual Result<int64_t> CountFor(const StructLike& data) const = 0;
+ /// \brief Count using metrics from a data file.
+ virtual Result<int64_t> CountFor(const DataFile& file) const = 0;
protected:
CountAggregate(Expression::Operation op, std::shared_ptr<BoundTerm> term)
@@ -161,6 +170,8 @@ class ICEBERG_EXPORT CountNonNullAggregate : public
CountAggregate {
std::shared_ptr<BoundTerm> term);
Result<int64_t> CountFor(const StructLike& data) const override;
+ Result<int64_t> CountFor(const DataFile& file) const override;
+ bool HasValue(const DataFile& file) const override;
private:
explicit CountNonNullAggregate(std::shared_ptr<BoundTerm> term);
@@ -173,6 +184,8 @@ class ICEBERG_EXPORT CountNullAggregate : public
CountAggregate {
std::shared_ptr<BoundTerm> term);
Result<int64_t> CountFor(const StructLike& data) const override;
+ Result<int64_t> CountFor(const DataFile& file) const override;
+ bool HasValue(const DataFile& file) const override;
private:
explicit CountNullAggregate(std::shared_ptr<BoundTerm> term);
@@ -184,6 +197,8 @@ class ICEBERG_EXPORT CountStarAggregate : public
CountAggregate {
static Result<std::unique_ptr<CountStarAggregate>> Make();
Result<int64_t> CountFor(const StructLike& data) const override;
+ Result<int64_t> CountFor(const DataFile& file) const override;
+ bool HasValue(const DataFile& file) const override;
private:
CountStarAggregate();
@@ -192,9 +207,11 @@ class ICEBERG_EXPORT CountStarAggregate : public
CountAggregate {
/// \brief Bound MAX aggregate.
class ICEBERG_EXPORT MaxAggregate : public BoundAggregate {
public:
- static std::shared_ptr<MaxAggregate> Make(std::shared_ptr<BoundTerm> term);
+ static Result<std::unique_ptr<MaxAggregate>> Make(std::shared_ptr<BoundTerm>
term);
Result<Literal> Evaluate(const StructLike& data) const override;
+ Result<Literal> Evaluate(const DataFile& file) const override;
+ bool HasValue(const DataFile& file) const override;
std::unique_ptr<Aggregator> NewAggregator() const override;
@@ -205,9 +222,11 @@ class ICEBERG_EXPORT MaxAggregate : public BoundAggregate {
/// \brief Bound MIN aggregate.
class ICEBERG_EXPORT MinAggregate : public BoundAggregate {
public:
- static std::shared_ptr<MinAggregate> Make(std::shared_ptr<BoundTerm> term);
+ static Result<std::unique_ptr<MinAggregate>> Make(std::shared_ptr<BoundTerm>
term);
Result<Literal> Evaluate(const StructLike& data) const override;
+ Result<Literal> Evaluate(const DataFile& file) const override;
+ bool HasValue(const DataFile& file) const override;
std::unique_ptr<Aggregator> NewAggregator() const override;
@@ -234,11 +253,17 @@ class ICEBERG_EXPORT AggregateEvaluator {
/// \brief Update aggregates with a row.
virtual Status Update(const StructLike& data) = 0;
+ /// \brief Update aggregates using data file metrics.
+ virtual Status Update(const DataFile& file) = 0;
+
/// \brief Final aggregated value.
virtual Result<std::span<const Literal>> GetResults() const = 0;
/// \brief Convenience accessor when only one aggregate is evaluated.
virtual Result<Literal> GetResult() const = 0;
+
+ /// \brief Whether all aggregators are still valid (metrics present).
+ virtual bool AllAggregatorsValid() const = 0;
};
} // namespace iceberg
diff --git a/src/iceberg/expression/expressions.cc
b/src/iceberg/expression/expressions.cc
index 7eef6023..4b0e538a 100644
--- a/src/iceberg/expression/expressions.cc
+++ b/src/iceberg/expression/expressions.cc
@@ -138,6 +138,13 @@ std::shared_ptr<UnboundAggregateImpl<BoundReference>>
Expressions::Max(
return agg;
}
+/// Create an unbound MAX aggregate over a transform term; throws on failure.
+std::shared_ptr<UnboundAggregateImpl<BoundTransform>> Expressions::Max(
+    std::shared_ptr<UnboundTerm<BoundTransform>> expr) {
+  using Impl = UnboundAggregateImpl<BoundTransform>;
+  ICEBERG_ASSIGN_OR_THROW(auto max_agg,
+                          Impl::Make(Expression::Operation::kMax, std::move(expr)));
+  return max_agg;
+}
+
std::shared_ptr<UnboundAggregateImpl<BoundReference>>
Expressions::Min(std::string name) {
return Min(Ref(std::move(name)));
}
@@ -149,6 +156,13 @@ std::shared_ptr<UnboundAggregateImpl<BoundReference>>
Expressions::Min(
return agg;
}
+/// Create an unbound MIN aggregate over a transform term; throws on failure.
+std::shared_ptr<UnboundAggregateImpl<BoundTransform>> Expressions::Min(
+    std::shared_ptr<UnboundTerm<BoundTransform>> expr) {
+  using Impl = UnboundAggregateImpl<BoundTransform>;
+  ICEBERG_ASSIGN_OR_THROW(auto min_agg,
+                          Impl::Make(Expression::Operation::kMin, std::move(expr)));
+  return min_agg;
+}
+
// Template implementations for unary predicates
std::shared_ptr<UnboundPredicateImpl<BoundReference>> Expressions::IsNull(
diff --git a/src/iceberg/expression/expressions.h
b/src/iceberg/expression/expressions.h
index 92c523ca..4ef2e780 100644
--- a/src/iceberg/expression/expressions.h
+++ b/src/iceberg/expression/expressions.h
@@ -135,6 +135,9 @@ class ICEBERG_EXPORT Expressions {
/// \brief Create a MAX aggregate for an unbound term.
static std::shared_ptr<UnboundAggregateImpl<BoundReference>> Max(
std::shared_ptr<UnboundTerm<BoundReference>> expr);
+ /// \brief Create a MAX aggregate for an unbound transform term.
+ static std::shared_ptr<UnboundAggregateImpl<BoundTransform>> Max(
+ std::shared_ptr<UnboundTerm<BoundTransform>> expr);
/// \brief Create a MIN aggregate for a field name.
static std::shared_ptr<UnboundAggregateImpl<BoundReference>> Min(std::string
name);
@@ -142,6 +145,9 @@ class ICEBERG_EXPORT Expressions {
/// \brief Create a MIN aggregate for an unbound term.
static std::shared_ptr<UnboundAggregateImpl<BoundReference>> Min(
std::shared_ptr<UnboundTerm<BoundReference>> expr);
+ /// \brief Create a MIN aggregate for an unbound transform term.
+ static std::shared_ptr<UnboundAggregateImpl<BoundTransform>> Min(
+ std::shared_ptr<UnboundTerm<BoundTransform>> expr);
// Unary predicates
diff --git a/src/iceberg/expression/term.h b/src/iceberg/expression/term.h
index 8b9606e5..616f11da 100644
--- a/src/iceberg/expression/term.h
+++ b/src/iceberg/expression/term.h
@@ -37,10 +37,13 @@ namespace iceberg {
/// \brief A term is an expression node that produces a typed value when
evaluated.
class ICEBERG_EXPORT Term : public util::Formattable {
public:
- enum class Kind : uint8_t { kReference = 0, kTransform, kExtract };
+ enum class Kind : uint8_t { kReference, kTransform, kExtract };
/// \brief Returns the kind of this term.
virtual Kind kind() const = 0;
+
+ /// \brief Returns whether this term is unbound.
+ virtual bool is_unbound() const = 0;
};
template <typename T>
@@ -53,6 +56,8 @@ template <typename B>
class ICEBERG_EXPORT UnboundTerm : public Unbound<B>, public Term {
public:
using BoundType = B;
+
+ bool is_unbound() const override { return true; }
};
/// \brief Base class for bound terms.
@@ -66,8 +71,6 @@ class ICEBERG_EXPORT BoundTerm : public Bound, public Term {
/// \brief Returns whether this term may produce null values.
virtual bool MayProduceNull() const = 0;
- // TODO(gangwu): add a comparator function to Literal and BoundTerm.
-
/// \brief Returns whether this term is equivalent to another.
///
/// Two terms are equivalent if they produce the same values when evaluated.
@@ -79,6 +82,8 @@ class ICEBERG_EXPORT BoundTerm : public Bound, public Term {
friend bool operator==(const BoundTerm& lhs, const BoundTerm& rhs) {
return lhs.Equals(rhs);
}
+
+ bool is_unbound() const override { return false; }
};
/// \brief A reference represents a named field in an expression.
diff --git a/src/iceberg/row/struct_like.cc b/src/iceberg/row/struct_like.cc
index b0fb67fb..85bde1a6 100644
--- a/src/iceberg/row/struct_like.cc
+++ b/src/iceberg/row/struct_like.cc
@@ -19,7 +19,9 @@
#include "iceberg/row/struct_like.h"
+#include <string>
#include <utility>
+#include <vector>
#include "iceberg/result.h"
#include "iceberg/util/checked_cast.h"
@@ -28,6 +30,44 @@
namespace iceberg {
+/// Convert a primitive Literal into the Scalar variant used by StructLike.
+/// String, binary and fixed values are returned as views into the literal's
+/// storage and therefore must not outlive `literal`.
+Result<Scalar> LiteralToScalar(const Literal& literal) {
+  // A null literal of any type maps to the "null" alternative of Scalar.
+  if (literal.IsNull()) {
+    return Scalar{std::monostate{}};
+  }
+
+  const auto& value = literal.value();
+  switch (literal.type()->type_id()) {
+    case TypeId::kBoolean:
+      return Scalar{std::get<bool>(value)};
+    case TypeId::kInt:
+    case TypeId::kDate:
+      // Stored as 32-bit integers.
+      return Scalar{std::get<int32_t>(value)};
+    case TypeId::kLong:
+    case TypeId::kTime:
+    case TypeId::kTimestamp:
+    case TypeId::kTimestampTz:
+      // Stored as 64-bit integers.
+      return Scalar{std::get<int64_t>(value)};
+    case TypeId::kFloat:
+      return Scalar{std::get<float>(value)};
+    case TypeId::kDouble:
+      return Scalar{std::get<double>(value)};
+    case TypeId::kDecimal:
+      return Scalar{std::get<Decimal>(value)};
+    case TypeId::kString:
+      // Borrowed view: valid only while `literal` is alive.
+      return Scalar{std::string_view{std::get<std::string>(value)}};
+    case TypeId::kBinary:
+    case TypeId::kFixed: {
+      // Borrowed view over the raw bytes: valid only while `literal` is alive.
+      const auto& raw = std::get<std::vector<uint8_t>>(value);
+      return Scalar{
+          std::string_view{reinterpret_cast<const char*>(raw.data()), raw.size()}};
+    }
+    default:
+      return NotSupported("Cannot convert literal of type {} to Scalar",
+                          literal.type()->ToString());
+  }
+}
+
StructLikeAccessor::StructLikeAccessor(std::shared_ptr<Type> type,
std::span<const size_t> position_path)
: type_(std::move(type)) {
diff --git a/src/iceberg/row/struct_like.h b/src/iceberg/row/struct_like.h
index 4999da69..36ff5d86 100644
--- a/src/iceberg/row/struct_like.h
+++ b/src/iceberg/row/struct_like.h
@@ -55,6 +55,9 @@ using Scalar = std::variant<std::monostate, // for null
std::shared_ptr<ArrayLike>, // for list
std::shared_ptr<MapLike>>; // for map
+/// \brief Convert a Literal to a Scalar.
+///
+/// A null literal becomes std::monostate. For string, binary and fixed literals
+/// the returned Scalar holds a std::string_view into \p literal's storage, so it
+/// must not outlive \p literal.
+///
+/// \return NotSupported if the literal's type has no Scalar representation.
+ICEBERG_EXPORT Result<Scalar> LiteralToScalar(const Literal& literal);
+
/// \brief An immutable struct-like wrapper.
class ICEBERG_EXPORT StructLike {
public:
diff --git a/src/iceberg/test/aggregate_test.cc
b/src/iceberg/test/aggregate_test.cc
index 264e606f..9885c7a6 100644
--- a/src/iceberg/test/aggregate_test.cc
+++ b/src/iceberg/test/aggregate_test.cc
@@ -23,6 +23,7 @@
#include "iceberg/expression/binder.h"
#include "iceberg/expression/expressions.h"
+#include "iceberg/manifest/manifest_entry.h"
#include "iceberg/row/struct_like.h"
#include "iceberg/schema.h"
#include "iceberg/test/matchers.h"
@@ -236,4 +237,243 @@ TEST(AggregateTest, MultipleAggregatesInEvaluator) {
EXPECT_EQ(std::get<int64_t>(results[4].value()), 4); // count_star
}
+TEST(AggregateTest, AggregatesFromDataFileMetrics) {
+  // Two optional int columns; only "value" (field 2) carries bounds.
+  Schema schema({SchemaField::MakeOptional(1, "id", int32()),
+                 SchemaField::MakeOptional(2, "value", int32())});
+
+  std::vector<std::shared_ptr<BoundAggregate>> aggregates;
+  aggregates.push_back(BindAggregate(schema, Expressions::Count("id")));
+  aggregates.push_back(BindAggregate(schema, Expressions::CountNull("id")));
+  aggregates.push_back(BindAggregate(schema, Expressions::CountStar()));
+  aggregates.push_back(BindAggregate(schema, Expressions::Max("value")));
+  aggregates.push_back(BindAggregate(schema, Expressions::Min("value")));
+  ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates));
+
+  ICEBERG_UNWRAP_OR_FAIL(auto lower_bytes, Literal::Int(5).Serialize());
+  ICEBERG_UNWRAP_OR_FAIL(auto upper_bytes, Literal::Int(50).Serialize());
+  DataFile file{
+      .record_count = 10,
+      .value_counts = {{1, 10}, {2, 10}},
+      .null_value_counts = {{1, 2}, {2, 0}},
+      .lower_bounds = {{2, lower_bytes}},
+      .upper_bounds = {{2, upper_bytes}},
+  };
+  ASSERT_TRUE(evaluator->Update(file).has_value());
+
+  ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults());
+  ASSERT_EQ(results.size(), aggregates.size());
+  // count(id) = 10 - 2, count_null(id) = 2, count(*) = record_count.
+  const int64_t expected_counts[] = {8, 2, 10};
+  for (size_t i = 0; i < 3; ++i) {
+    EXPECT_EQ(std::get<int64_t>(results[i].value()), expected_counts[i]);
+  }
+  EXPECT_EQ(std::get<int32_t>(results[3].value()), 50);  // max(value) = upper bound
+  EXPECT_EQ(std::get<int32_t>(results[4].value()), 5);   // min(value) = lower bound
+}
+
+TEST(AggregateTest, AggregatesFromDataFileMissingMetricsReturnNull) {
+  Schema schema({SchemaField::MakeOptional(1, "id", int32()),
+                 SchemaField::MakeOptional(2, "value", int32())});
+
+  std::vector<std::shared_ptr<BoundAggregate>> aggregates{
+      BindAggregate(schema, Expressions::Count("id")),
+      BindAggregate(schema, Expressions::CountNull("id")),
+      BindAggregate(schema, Expressions::CountStar()),
+      BindAggregate(schema, Expressions::Max("value")),
+      BindAggregate(schema, Expressions::Min("value"))};
+  ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates));
+
+  // No per-column metrics and a negative (unknown) record count: every
+  // aggregate lacks the stats it needs, so Update succeeds but results are null.
+  DataFile file{.record_count = -1};
+  ASSERT_TRUE(evaluator->Update(file).has_value());
+
+  ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults());
+  ASSERT_EQ(results.size(), aggregates.size());
+  for (const auto& result : results) {
+    EXPECT_TRUE(result.IsNull());
+  }
+}
+
+TEST(AggregateTest, AggregatesFromDataFileWithTransform) {
+  Schema schema({SchemaField::MakeOptional(1, "id", int32())});
+
+  // MAX/MIN over truncate(id, 10) are computed by applying the transform to
+  // the column's bounds.
+  auto truncated = Expressions::Truncate("id", 10);
+  std::vector<std::shared_ptr<BoundAggregate>> aggregates{
+      BindAggregate(schema, Expressions::Max(truncated)),
+      BindAggregate(schema, Expressions::Min(truncated))};
+  ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates));
+
+  ICEBERG_UNWRAP_OR_FAIL(auto lower_bytes, Literal::Int(5).Serialize());
+  ICEBERG_UNWRAP_OR_FAIL(auto upper_bytes, Literal::Int(23).Serialize());
+  DataFile file{
+      .record_count = 5,
+      .value_counts = {{1, 5}},
+      .null_value_counts = {{1, 0}},
+      .lower_bounds = {{1, lower_bytes}},
+      .upper_bounds = {{1, upper_bytes}},
+  };
+  ASSERT_TRUE(evaluator->Update(file).has_value());
+
+  ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults());
+  ASSERT_EQ(results.size(), aggregates.size());
+  EXPECT_EQ(std::get<int32_t>(results[0].value()), 20);  // truncate(23, 10)
+  EXPECT_EQ(std::get<int32_t>(results[1].value()), 0);   // truncate(5, 10)
+  EXPECT_TRUE(evaluator->AllAggregatorsValid());
+}
+
+TEST(AggregateTest, DataFileAggregatorParity) {  // aggregates over several DataFiles with complete, partial, and missing column metrics
+ Schema schema({SchemaField::MakeRequired(1, "id", int32()),
+ SchemaField::MakeOptional(2, "no_stats", int32()),  // no metrics in any file below
+ SchemaField::MakeOptional(3, "all_nulls", string()),  // value count == null count wherever present
+ SchemaField::MakeOptional(4, "some_nulls", string())});  // null counts only in `file`
+
+ auto make_bounds = [](int field_id, int32_t lower, int32_t upper) {  // serialized int lower/upper bound maps for one field
+ std::map<int32_t, std::vector<uint8_t>> lower_bounds;
+ std::map<int32_t, std::vector<uint8_t>> upper_bounds;
+ auto lser = Literal::Int(lower).Serialize().value();
+ auto user = Literal::Int(upper).Serialize().value();
+ lower_bounds.emplace(field_id, std::move(lser));
+ upper_bounds.emplace(field_id, std::move(user));
+ return std::pair{std::move(lower_bounds), std::move(upper_bounds)};
+ };
+
+ auto [b1_lower, b1_upper] = make_bounds(1, 33, 2345);
+ DataFile file{  // id: 50 values / 10 nulls, bounds [33, 2345]; some_nulls: 50 values / 10 nulls, no bounds
+ .file_path = "file.avro",
+ .record_count = 50,
+ .value_counts = {{1, 50}, {3, 50}, {4, 50}},
+ .null_value_counts = {{1, 10}, {3, 50}, {4, 10}},
+ .lower_bounds = std::move(b1_lower),
+ .upper_bounds = std::move(b1_upper),
+ };
+
+ auto [b2_lower, b2_upper] = make_bounds(1, 33, 100);
+ DataFile missing_some_nulls_1{  // id: 20 values / 0 nulls, bounds [33, 100]; no metrics for some_nulls
+ .file_path = "file_2.avro",
+ .record_count = 20,
+ .value_counts = {{1, 20}, {3, 20}},
+ .null_value_counts = {{1, 0}, {3, 20}},
+ .lower_bounds = std::move(b2_lower),
+ .upper_bounds = std::move(b2_upper),
+ };
+
+ auto [b3_lower, b3_upper] = make_bounds(1, -33, 3333);
+ DataFile missing_some_nulls_2{  // id: 20 values / 20 nulls, bounds [-33, 3333]; no metrics for some_nulls
+ .file_path = "file_3.avro",
+ .record_count = 20,
+ .value_counts = {{1, 20}, {3, 20}},
+ .null_value_counts = {{1, 20}, {3, 20}},
+ .lower_bounds = std::move(b3_lower),
+ .upper_bounds = std::move(b3_upper),
+ };
+
+ DataFile missing_some_stats{  // value counts for id and some_nulls, but no null counts at all
+ .file_path = "file_missing_stats.avro",
+ .record_count = 20,
+ .value_counts = {{1, 20}, {4, 10}},
+ };
+ auto [b4_lower, b4_upper] = make_bounds(1, -3, 1333);
+ missing_some_stats.lower_bounds = std::move(b4_lower);
+ missing_some_stats.upper_bounds = std::move(b4_upper);
+
+ DataFile missing_all_optional_stats{  // only record_count is known
+ .file_path = "file_null_stats.avro",
+ .record_count = 20,
+ };
+
+ auto run_case = [&](const std::vector<std::shared_ptr<Expression>>& exprs,  // binds exprs, feeds every file, then checks validity and each result
+ const std::vector<DataFile>& files,
+ const std::vector<std::optional<Scalar>>& expected,  // std::nullopt = expect a null literal
+ bool expect_all_valid) {
+ std::vector<std::shared_ptr<BoundAggregate>> aggregates;
+ aggregates.reserve(exprs.size());
+ for (const auto& e : exprs) {
+ aggregates.emplace_back(BindAggregate(schema, e));
+ }
+ ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates));
+ for (const auto& f : files) {
+ ASSERT_TRUE(evaluator->Update(f).has_value());
+ }
+ ASSERT_EQ(evaluator->AllAggregatorsValid(), expect_all_valid);
+ ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults());
+ ASSERT_EQ(results.size(), expected.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].has_value()) {
+ EXPECT_TRUE(results[i].IsNull());
+ } else {
+ const auto& exp = *expected[i];
+ const auto& res = results[i].value();
+ if (std::holds_alternative<int32_t>(exp)) {
+ EXPECT_EQ(std::get<int32_t>(res), std::get<int32_t>(exp));
+ } else if (std::holds_alternative<int64_t>(exp)) {
+ EXPECT_EQ(std::get<int64_t>(res), std::get<int64_t>(exp));
+ } else {
+ FAIL() << "Unexpected expected type";
+ }
+ }
+ }
+ };
+
+ // testIntAggregate: count(id) = 40 + 20 + 0 = 60, count_null(id) = 10 + 0 + 20 = 30, max/min from the id bounds
+ run_case({Expressions::CountStar(), Expressions::Count("id"),
+ Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")},
+ {file, missing_some_nulls_1, missing_some_nulls_2},
+ {Scalar{int64_t{90}}, Scalar{int64_t{60}}, Scalar{int64_t{30}},
+ Scalar{int32_t{3333}}, Scalar{int32_t{-33}}},
+ /*expect_all_valid=*/true);
+
+ // testAllNulls: no bounds, but value count == null count in every file, so max/min are valid nulls
+ run_case({Expressions::CountStar(), Expressions::Count("all_nulls"),
+ Expressions::CountNull("all_nulls"), Expressions::Max("all_nulls"),
+ Expressions::Min("all_nulls")},
+ {file, missing_some_nulls_1, missing_some_nulls_2},
+ {Scalar{int64_t{90}}, Scalar{int64_t{0}}, Scalar{int64_t{90}}, std::nullopt,
+ std::nullopt},
+ /*expect_all_valid=*/true);
+
+ // testSomeNulls -> missing null counts for field 4 (two of three files), so only count(*) stays valid
+ run_case({Expressions::CountStar(), Expressions::Count("some_nulls"),
+ Expressions::CountNull("some_nulls"), Expressions::Max("some_nulls"),
+ Expressions::Min("some_nulls")},
+ {file, missing_some_nulls_1, missing_some_nulls_2},
+ {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt},
+ /*expect_all_valid=*/false);
+
+ // testNoStats -> field 2 has no metrics
+ run_case({Expressions::CountStar(), Expressions::Count("no_stats"),
+ Expressions::CountNull("no_stats"), Expressions::Max("no_stats"),
+ Expressions::Min("no_stats")},
+ {file, missing_some_nulls_1, missing_some_nulls_2},
+ {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt},
+ /*expect_all_valid=*/false);
+
+ // testIntAggregateAllMissingStats -> id missing optional stats; only record_count is available
+ run_case({Expressions::CountStar(), Expressions::Count("id"),
+ Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")},
+ {missing_all_optional_stats},
+ {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt},
+ /*expect_all_valid=*/false);
+
+ // testOptionalColAllMissingStats -> field 2 missing everything
+ run_case({Expressions::CountStar(), Expressions::Count("no_stats"),
+ Expressions::CountNull("no_stats"), Expressions::Max("no_stats"),
+ Expressions::Min("no_stats")},
+ {missing_all_optional_stats},
+ {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt},
+ /*expect_all_valid=*/false);
+
+ // testMissingSomeStats -> some_nulls missing null stats entirely; a value count alone is insufficient
+ run_case({Expressions::CountStar(), Expressions::Count("some_nulls"),
+ Expressions::Max("some_nulls"), Expressions::Min("some_nulls")},
+ {missing_some_stats},
+ {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt},
+ /*expect_all_valid=*/false);
+}
+
} // namespace iceberg