ZENOTME commented on code in PR #77:
URL: https://github.com/apache/datasketches-rust/pull/77#discussion_r2808919252


##########
datasketches/src/theta/sketch.rs:
##########
@@ -250,6 +302,538 @@ impl ThetaSketch {
     }
 }
 
+/// Compact (immutable) theta sketch.
+///
+/// This is the serialized-friendly form of a theta sketch: a compact array of 
retained hash values
+/// plus theta and a 16-bit seed hash. It can be ordered (sorted ascending) or 
unordered.
+#[derive(Clone, Debug)]
+pub struct CompactThetaSketch {
+    entries: Vec<u64>,
+    theta: u64,
+    seed_hash: u16,
+    ordered: bool,
+    empty: bool,
+}
+
+impl CompactThetaSketch {
+    /// Returns the cardinality estimate.
+    pub fn estimate(&self) -> f64 {
+        if self.is_empty() {
+            return 0.0;
+        }
+        let num_retained = self.num_retained() as f64;
+        if self.theta == MAX_THETA {
+            return num_retained;
+        }
+        let theta = self.theta as f64 / MAX_THETA as f64;
+        num_retained / theta
+    }
+
+    /// Returns theta as a fraction (0.0 to 1.0).
+    pub fn theta(&self) -> f64 {
+        self.theta as f64 / MAX_THETA as f64
+    }
+
+    /// Returns theta as u64.
+    pub fn theta64(&self) -> u64 {
+        self.theta
+    }
+
+    /// Returns true if this sketch is empty.
+    pub fn is_empty(&self) -> bool {
+        self.empty
+    }
+
+    /// Returns true if this sketch is in estimation mode.
+    pub fn is_estimation_mode(&self) -> bool {
+        self.theta < MAX_THETA
+    }
+
+    /// Returns the number of retained entries.
+    pub fn num_retained(&self) -> usize {
+        self.entries.len()
+    }
+
+    /// Returns true if retained entries are ordered (sorted ascending).
+    pub fn is_ordered(&self) -> bool {
+        self.ordered
+    }
+
+    /// Returns the 16-bit seed hash.
+    pub fn seed_hash(&self) -> u16 {
+        self.seed_hash
+    }
+
+    /// Return iterator over retained hash values.
+    pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
+        self.entries.iter().copied()
+    }
+
+    /// Returns the approximate lower error bound given the specified number 
of Standard Deviations.
+    pub fn lower_bound(&self, num_std_dev: NumStdDev) -> f64 {
+        if !self.is_estimation_mode() {
+            return self.num_retained() as f64;
+        }
+        binomial_bounds::lower_bound(self.num_retained() as u64, self.theta(), 
num_std_dev)
+            .expect("theta should always be valid")
+    }
+
+    /// Returns the approximate upper error bound given the specified number 
of Standard Deviations.
+    pub fn upper_bound(&self, num_std_dev: NumStdDev) -> f64 {
+        if !self.is_estimation_mode() {
+            return self.num_retained() as f64;
+        }
+        binomial_bounds::upper_bound(
+            self.num_retained() as u64,
+            self.theta(),
+            num_std_dev,
+            self.is_empty(),
+        )
+        .expect("theta should always be valid")
+    }
+
+    fn preamble_longs(&self, compressed: bool) -> u8 {
+        if compressed {
+            if self.is_estimation_mode() { 2 } else { 1 }
+        } else {
+            if self.is_estimation_mode() {
+                3
+            } else {
+                if self.is_empty() || self.entries.len() == 1 {
+                    1
+                } else {
+                    2
+                }
+            }
+        }
+    }
+
+    /// Serializes this sketch in compressed form if applicable.
+    ///
+    /// This uses `serVer = 4` when the sketch is ordered and suitable for 
compression, and falls
+    /// back to uncompressed `serVer = 3` otherwise.
+    pub fn serialize_compressed(&self) -> Vec<u8> {
+        if self.is_suitable_for_compression() {
+            self.serialize_v4()
+        } else {
+            self.serialize()
+        }
+    }
+
+    fn is_suitable_for_compression(&self) -> bool {
+        self.ordered
+            && !self.entries.is_empty()
+            && (self.entries.len() != 1 || self.is_estimation_mode())
+    }
+
+    /// Serializes this sketch into the uncompressed compact theta format.
+    pub fn serialize(&self) -> Vec<u8> {
+        let mut bytes = SketchBytes::with_capacity(64 + self.entries.len() * 
8);
+
+        let pre_longs = self.preamble_longs(false);
+        bytes.write_u8(pre_longs);
+        bytes.write_u8(serialization::UNCOMPRESSED_SERIAL_VERSION);
+        bytes.write_u8(Family::THETA.id);
+        bytes.write_u16_be(0); // unused for compact
+
+        let mut flags = 0u8;
+        flags |= serialization::FLAGS_IS_READ_ONLY;
+        flags |= serialization::FLAGS_IS_COMPACT;
+        if self.is_empty() {
+            flags |= serialization::FLAGS_IS_EMPTY;
+        }
+        if self.is_ordered() {
+            flags |= serialization::FLAGS_IS_ORDERED;
+        }
+        bytes.write_u8(flags);
+
+        bytes.write_u16_le(self.seed_hash);
+
+        if pre_longs > 1 {
+            bytes.write_u32_le(self.entries.len() as u32);
+            bytes.write_u32_be(0); // not used by compact sketches; match 
Java/C++
+        }
+        if self.is_estimation_mode() {
+            bytes.write_u64_le(self.theta64());
+        }
+        for hash in self.entries.iter() {
+            bytes.write_u64_le(*hash);
+        }
+        bytes.into_bytes()
+    }
+
+    fn serialize_v4(&self) -> Vec<u8> {
+        let pre_longs = self.preamble_longs(true);
+        let entry_bits = Self::compute_entry_bits(&self.entries);
+        let num_entries_bytes = Self::num_entries_bytes(self.entries.len());
+
+        // Pre-size exactly: preamble longs (8 bytes each) + num_entries_bytes 
+ packed bits.
+        let compressed_bits = entry_bits as usize * self.entries.len();
+        let compressed_bytes = compressed_bits.div_ceil(8);
+        let out_bytes = (pre_longs as usize * 8) + (num_entries_bytes as 
usize) + compressed_bytes;
+        let mut bytes = SketchBytes::with_capacity(out_bytes);
+
+        bytes.write_u8(pre_longs);
+        bytes.write_u8(serialization::COMPRESSED_SERIAL_VERSION);
+        bytes.write_u8(Family::THETA.id);
+        bytes.write_u8(entry_bits);
+        bytes.write_u8(num_entries_bytes);
+
+        let mut flags = 0u8;
+        flags |= serialization::FLAGS_IS_READ_ONLY;
+        flags |= serialization::FLAGS_IS_COMPACT;
+        flags |= serialization::FLAGS_IS_ORDERED;
+        bytes.write_u8(flags);
+
+        bytes.write_u16_le(self.seed_hash);
+        if self.is_estimation_mode() {
+            bytes.write_u64_le(self.theta);
+        }
+
+        let mut n = self.entries.len() as u32;
+        for _ in 0..num_entries_bytes {
+            bytes.write_u8((n & 0xff) as u8);
+            n >>= 8;
+        }
+
+        // pack deltas
+        let mut previous = 0u64;
+        let mut i = 0usize;
+        let mut block = vec![0u8; entry_bits as usize];
+        while i + 7 < self.entries.len() {
+            let mut deltas = [0u64; 8];
+            for j in 0..8 {
+                let entry = self.entries[i + j];
+                deltas[j] = entry - previous;
+                previous = entry;
+            }
+            block.fill(0);
+            pack_bits_block8(&deltas, &mut block, entry_bits);
+            bytes.write(&block);
+            i += 8;
+        }
+
+        // pack extra deltas if fewer than 8 of them left
+        if i < self.entries.len() {
+            let mut block = vec![0u8; entry_bits as usize];
+            let mut byte_index = 0;
+            let mut bit_index = 0;
+            let mut pack_delta_to_block = |delta: u64| {
+                (byte_index, bit_index) =
+                    pack_bits(delta, entry_bits, &mut block, byte_index, 
bit_index);
+            };
+            while i < self.entries.len() {
+                let delta = self.entries[i] - previous;
+                previous = self.entries[i];
+                pack_delta_to_block(delta);
+                i += 1;
+            }
+            if bit_index == 0 {
+                bytes.write(&block[0..byte_index]);
+            } else {
+                bytes.write(&block[0..byte_index + 1]);
+            }
+        }
+
+        bytes.into_bytes()
+    }
+
+    fn compute_entry_bits(entries: &[u64]) -> u8 {
+        let mut previous = 0u64;
+        let mut ored = 0u64;
+        for &entry in entries {
+            let delta = entry - previous;
+            ored |= delta;
+            previous = entry;
+        }
+        (64 - ored.leading_zeros()) as u8
+    }
+
+    fn num_entries_bytes(num_entries: usize) -> u8 {
+        let n = num_entries as u32;
+        let bits = u32::BITS - n.leading_zeros();
+        bits.div_ceil(8) as u8
+    }
+
+    /// Deserializes a compact theta sketch from bytes.
+    pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
+        Self::deserialize_with_seed(bytes, DEFAULT_UPDATE_SEED)
+    }
+
+    /// Deserializes a compact theta sketch from bytes using the provided 
expected seed.
+    pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self, 
Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
Error {
+            move |_| Error::insufficient_data(tag)
+        }
+
+        let mut cursor = SketchSlice::new(bytes);
+        let pre_longs = 
cursor.read_u8().map_err(make_error("preamble_longs"))?;
+        let ser_ver = cursor.read_u8().map_err(make_error("serial_version"))?;
+        let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
+
+        Family::THETA.validate_id(family_id)?;
+
+        // Validate pre_longs is within valid range for Theta sketch
+        ensure_preamble_longs_in_range(
+            Family::THETA.min_pre_longs..=Family::THETA.max_pre_longs,
+            pre_longs,
+        )?;
+
+        match ser_ver {
+            1 => Self::deserialize_v1(cursor, seed),
+            2 => Self::deserialize_v2(pre_longs, cursor, seed),
+            3 => Self::deserialize_v3(pre_longs, cursor, seed),
+            4 => Self::deserialize_v4(pre_longs, cursor, seed),
+            _ => Err(Error::deserial(format!(
+                "unsupported serial version: expected 1, 2, 3, or 4, got 
{ser_ver}",
+            ))),
+        }
+    }
+
+    fn read_entries(
+        cursor: &mut SketchSlice<'_>,
+        num_entries: usize,
+        theta: u64,
+    ) -> Result<Vec<u64>, Error> {
+        let mut entries = Vec::with_capacity(num_entries);
+        for _ in 0..num_entries {
+            let hash = cursor
+                .read_u64_le()
+                .map_err(|_| Error::insufficient_data("entries"))?;
+            if hash == 0 || hash >= theta {
+                return Err(Error::deserial("corrupted: invalid retained hash 
value"));
+            }
+            entries.push(hash);
+        }
+        Ok(entries)
+    }
+
+    fn deserialize_v1(mut cursor: SketchSlice<'_>, expected_seed: u64) -> 
Result<Self, Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
Error {
+            move |_| Error::insufficient_data(tag)
+        }
+
+        let seed_hash = compute_seed_hash(expected_seed);
+        cursor.read_u8().map_err(make_error("<unused>"))?;
+        cursor.read_u32_le().map_err(make_error("<unused_u32_0>"))?;
+        let num_entries = 
cursor.read_u32_le().map_err(make_error("num_entries"))? as usize;
+        cursor.read_u32_le().map_err(make_error("<unused_u32_1>"))?;
+        let theta = cursor.read_u64_le().map_err(make_error("theta_long"))?;
+
+        let empty = num_entries == 0 && theta == MAX_THETA;
+        if empty {
+            return Ok(Self {
+                entries: vec![],
+                theta,
+                seed_hash,
+                ordered: true,
+                empty: true,
+            });
+        }
+
+        let entries = Self::read_entries(&mut cursor, num_entries, theta)?;
+
+        Ok(Self {
+            entries,
+            theta,
+            seed_hash,
+            ordered: true,
+            empty: false,
+        })
+    }
+
+    fn deserialize_v2(
+        pre_longs: u8,
+        mut cursor: SketchSlice<'_>,
+        expected_seed: u64,
+    ) -> Result<Self, Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
Error {
+            move |_| Error::insufficient_data(tag)
+        }
+
+        cursor.read_u8().map_err(make_error("<unused>"))?;
+        cursor.read_u16_le().map_err(make_error("<unused_u16>"))?;
+        let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?;
+        let expected_seed_hash = compute_seed_hash(expected_seed);
+        if seed_hash != expected_seed_hash {
+            return Err(Error::deserial(format!(
+                "incompatible seed hash: expected {expected_seed_hash}, got 
{seed_hash}",
+            )));
+        }
+
+        match pre_longs {
+            1 => Ok(Self {
+                entries: vec![],
+                theta: MAX_THETA,
+                seed_hash,
+                ordered: true,
+                empty: true,
+            }),
+            2 => {
+                let num_entries = 
cursor.read_u32_le().map_err(make_error("num_entries"))? as usize;
+                cursor.read_u32_le().map_err(make_error("<unused_u32>"))?;
+                let entries = Self::read_entries(&mut cursor, num_entries, 
MAX_THETA)?;
+                Ok(Self {
+                    entries,
+                    theta: MAX_THETA,
+                    seed_hash,
+                    ordered: true,
+                    empty: true,
+                })
+            }
+            3 => {

Review Comment:
   I noticed that serialization and deserialization use different conditions
   to decide whether the same field is present. For example, for theta:
   ```
   // serialize
   if (this->is_estimation_mode()) write(os, this->theta_); // means that preamble = 3
   
   // deserialize
   if (preamble_longs > 2) theta = read<uint64_t>(is);
   ```
   Changing it as follows may be better and more consistent:
   ```
   // serialize
   if (this->is_estimation_mode()) write(os, this->theta_); // means that preamble = 3
   
   // deserialize
   if (preamble_longs.is_estimate()) theta = read<uint64_t>(is);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to