This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datasketches-go.git

commit 5aa13f2bbdb1cfa34b028738804756ba9b8e4677
Author: Pierre Lacave <[email protected]>
AuthorDate: Wed Dec 20 17:43:44 2023 +0100

    Add LongFrequency merge
---
 common/utils.go                            |  4 ++
 frequencies/long_sketch.go                 | 30 +++++++++-
 frequencies/long_sketch_test.go            | 48 +++++++++++++++-
 frequencies/reverse_purge_long_hash_map.go | 92 ++++++++++++++++++++++++++----
 4 files changed, 159 insertions(+), 15 deletions(-)

diff --git a/common/utils.go b/common/utils.go
index 0919843..305fa00 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -24,6 +24,10 @@ import (
        "strconv"
 )
 
+// InverseGolden is the reciprocal of the golden ratio, (sqrt(5)-1)/2,
+// declared as a typed float64 constant.
+const InverseGolden float64 = 0.6180339887498949025
+
 // InvPow2 returns 2^(-e).
 func InvPow2(e int) (float64, error) {
        if (e | 1024 - e - 1) < 0 {
diff --git a/frequencies/long_sketch.go b/frequencies/long_sketch.go
index 837d956..9c3d730 100644
--- a/frequencies/long_sketch.go
+++ b/frequencies/long_sketch.go
@@ -190,6 +190,14 @@ func (s *LongSketch) getCurrentMapCapacity() int {
        return s.curMapCap
 }
 
+// getStreamLength returns the sketch's accumulated streamWeight.
+func (s *LongSketch) getStreamLength() int64 {
+       total := s.streamWeight
+       return total
+}
+
+// isEmpty reports whether the sketch currently has zero active items.
+func (s *LongSketch) isEmpty() bool {
+       active := s.getNumActiveItems()
+       return active == 0
+}
+
 func (s *LongSketch) Update(item int64, count int64) error {
        if count == 0 {
                return nil
@@ -198,7 +206,10 @@ func (s *LongSketch) Update(item int64, count int64) error 
{
                return fmt.Errorf("count may not be negative")
        }
        s.streamWeight += count
-       s.hashMap.adjustOrPutValue(item, count)
+       err := s.hashMap.adjustOrPutValue(item, count)
+       if err != nil {
+               return err
+       }
 
        if s.hashMap.numActive > s.curMapCap {
                // Over the threshold, we need to do something
@@ -217,6 +228,23 @@ func (s *LongSketch) Update(item int64, count int64) error 
{
        return nil
 }
 
+// merge folds the active entries of other into s and returns s.
+// A nil or empty other is a no-op. If an Update fails, the error is returned
+// with a nil sketch; entries merged before the failure stay applied to s.
+func (s *LongSketch) merge(other *LongSketch) (*LongSketch, error) {
+       if other == nil || other.isEmpty() {
+               return s, nil
+       }
+       // Update adds each count to s.streamWeight as well, so the combined
+       // weight is computed up front and written back after the loop.
+       combinedWeight := s.streamWeight + other.streamWeight
+       for it := other.hashMap.iterator(); it.next(); {
+               if err := s.Update(it.getKey(), it.getValue()); err != nil {
+                       return nil, err
+               }
+       }
+       s.offset += other.offset
+       s.streamWeight = combinedWeight
+       return s, nil
+}
+
 func (s *LongSketch) serializeToString() (string, error) {
        var sb strings.Builder
        //start the string with parameters of the sketch
diff --git a/frequencies/long_sketch_test.go b/frequencies/long_sketch_test.go
index cc1c115..9d02fc3 100644
--- a/frequencies/long_sketch_test.go
+++ b/frequencies/long_sketch_test.go
@@ -25,8 +25,8 @@ import (
 func TestFrequentItsemsStringSerialTest(t *testing.T) {
        sketch, err := NewLongSketchWithDefault(8)
        assert.NoError(t, err)
-       //sketch2, err := NewLongSketchWithDefault(128)
-       //assert.NoError(t, err)
+       sketch2, err := NewLongSketchWithDefault(128)
+       assert.NoError(t, err)
        sketch.Update(10, 100)
        sketch.Update(10, 100)
        sketch.Update(15, 3443)
@@ -42,4 +42,48 @@ func TestFrequentItsemsStringSerialTest(t *testing.T) {
        assert.Equal(t, ser, newSer0)
        assert.Equal(t, sketch.getMaximumMapCapacity(), 
newSk0.getMaximumMapCapacity())
        assert.Equal(t, sketch.getCurrentMapCapacity(), 
newSk0.getCurrentMapCapacity())
+
+       sketch2.Update(190, 12902390)
+       sketch2.Update(191, 12902390)
+       sketch2.Update(192, 12902390)
+       sketch2.Update(193, 12902390)
+       sketch2.Update(194, 12902390)
+       sketch2.Update(195, 12902390)
+       sketch2.Update(196, 12902390)
+       sketch2.Update(197, 12902390)
+       sketch2.Update(198, 12902390)
+       sketch2.Update(199, 12902390)
+       sketch2.Update(200, 12902390)
+       sketch2.Update(201, 12902390)
+       sketch2.Update(202, 12902390)
+       sketch2.Update(203, 12902390)
+       sketch2.Update(204, 12902390)
+       sketch2.Update(205, 12902390)
+       sketch2.Update(206, 12902390)
+       sketch2.Update(207, 12902390)
+       sketch2.Update(208, 12902390)
+
+       s2, err := sketch2.serializeToString()
+       assert.NoError(t, err)
+       newSk2, err := NewLongSketchFromString(s2)
+       assert.NoError(t, err)
+       newS2, err := newSk2.serializeToString()
+       assert.NoError(t, err)
+       assert.Equal(t, s2, newS2)
+       assert.Equal(t, sketch2.getMaximumMapCapacity(), 
newSk2.getMaximumMapCapacity())
+       assert.Equal(t, sketch2.getCurrentMapCapacity(), 
newSk2.getCurrentMapCapacity())
+       assert.Equal(t, sketch2.getStreamLength(), newSk2.getStreamLength())
+
+       mergedSketch, err := sketch.merge(sketch2)
+       assert.NoError(t, err)
+       ms, err := mergedSketch.serializeToString()
+       assert.NoError(t, err)
+       newMs, err := NewLongSketchFromString(ms)
+       assert.NoError(t, err)
+       newSMs, err := newMs.serializeToString()
+       assert.NoError(t, err)
+       assert.Equal(t, ms, newSMs)
+       assert.Equal(t, mergedSketch.getMaximumMapCapacity(), 
newMs.getMaximumMapCapacity())
+       assert.Equal(t, mergedSketch.getCurrentMapCapacity(), 
newMs.getCurrentMapCapacity())
+       assert.Equal(t, mergedSketch.getStreamLength(), newMs.getStreamLength())
 }
diff --git a/frequencies/reverse_purge_long_hash_map.go 
b/frequencies/reverse_purge_long_hash_map.go
index 82417ea..66c18c5 100644
--- a/frequencies/reverse_purge_long_hash_map.go
+++ b/frequencies/reverse_purge_long_hash_map.go
@@ -35,6 +35,17 @@ type reversePurgeLongHashMap struct {
        numActive     int
 }
 
+// iteratorHashMap is a cursor over the occupied slots of a
+// reversePurgeLongHashMap's key, value and state slices.
+type iteratorHashMap struct {
+       keys_, values_ []int64 // parallel key and value slots
+       states_        []int16 // a slot holds an entry when its state is > 0
+       numActive_     int     // entries next() yields before returning false
+       stride_        int     // step added to i_ on every advance
+       mask_          int     // ANDed with i_ to wrap it around
+       i_             int     // index of the current slot
+       count_         int     // entries yielded so far
+}
+
 const (
        loadFactor = float64(0.75)
        driftLimit = 1024 //used only in stress testing
@@ -74,7 +85,7 @@ func (r *reversePurgeLongHashMap) getCapacity() int {
 //
 // key the key of the value to increment
 // adjustAmount the amount by which to increment the value
-func (r *reversePurgeLongHashMap) adjustOrPutValue(key int64, adjustAmount 
int64) {
+func (r *reversePurgeLongHashMap) adjustOrPutValue(key int64, adjustAmount 
int64) error {
        var (
                arrayMask = len(r.keys) - 1
                probe     = hash(key) & int64(arrayMask)
@@ -90,8 +101,8 @@ func (r *reversePurgeLongHashMap) adjustOrPutValue(key 
int64, adjustAmount int64
        //found either an empty slot or the key
        if r.states[probe] == 0 { //found empty slot
                // adding the key and value to the table
-               if r.numActive >= r.loadThreshold {
-                       panic("numActive >= loadThreshold")
+               if r.numActive > r.loadThreshold {
+                       return fmt.Errorf("numActive >= loadThreshold")
                }
                r.keys[probe] = key
                r.values[probe] = adjustAmount
@@ -103,9 +114,10 @@ func (r *reversePurgeLongHashMap) adjustOrPutValue(key 
int64, adjustAmount int64
                }
                r.values[probe] += adjustAmount
        }
+       return nil
 }
 
-func (r *reversePurgeLongHashMap) resize(newSize int) {
+func (r *reversePurgeLongHashMap) resize(newSize int) error {
        oldKeys := r.keys
        oldValues := r.values
        oldStates := r.states
@@ -115,12 +127,13 @@ func (r *reversePurgeLongHashMap) resize(newSize int) {
        r.loadThreshold = int(float64(newSize) * loadFactor)
        r.lgLength = bits.TrailingZeros(uint(newSize))
        r.numActive = 0
-       for i := 0; i < len(oldKeys); i++ {
+       err := error(nil)
+       for i := 0; i < len(oldKeys) && err == nil; i++ {
                if oldStates[i] > 0 {
-                       r.adjustOrPutValue(oldKeys[i], oldValues[i])
+                       err = r.adjustOrPutValue(oldKeys[i], oldValues[i])
                }
        }
-
+       return err
 }
 
 func (r *reversePurgeLongHashMap) purge(sampleSize int) int64 {
@@ -232,7 +245,7 @@ func deserializeReversePurgeLongHashMapFromString(string 
string) (*reversePurgeL
                return nil, err
        }
        j := 2
-       for i := 0; i < numActive; i++ {
+       for i := 0; i < numActive && err == nil; i++ {
                key, err := strconv.Atoi(tokens[j])
                if err != nil {
                        return nil, err
@@ -241,7 +254,10 @@ func deserializeReversePurgeLongHashMapFromString(string 
string) (*reversePurgeL
                if err != nil {
                        return nil, err
                }
-               table.adjustOrPutValue(int64(key), int64(value))
+               err = table.adjustOrPutValue(int64(key), int64(value))
+               if err != nil {
+                       return nil, err
+               }
                j += 2
        }
        return table, nil
@@ -257,10 +273,62 @@ func deserializeFromStringArray(tokens []string) 
(*reversePurgeLongHashMap, erro
        }
        j := 2 + ignore
        for i := 0; i < int(numActive); i++ {
-               key, _ := strconv.ParseUint(tokens[j], 10, 64)
-               value, _ := strconv.ParseUint(tokens[j+1], 10, 64)
-               hashMap.adjustOrPutValue(int64(key), int64(value))
+               key, err := strconv.ParseUint(tokens[j], 10, 64)
+               if err != nil {
+                       return nil, err
+               }
+               value, err := strconv.ParseUint(tokens[j+1], 10, 64)
+               if err != nil {
+                       return nil, err
+               }
+               err = hashMap.adjustOrPutValue(int64(key), int64(value))
+               if err != nil {
+                       return nil, err
+               }
                j += 2
        }
        return hashMap, nil
 }
+
+// iterator returns an iterator over the map's active entries.
+//
+// It must be built by newIterator: a literal that leaves stride_, mask_ and
+// i_ at zero makes next() evaluate (0+0)&0 on every step, so the iterator
+// never leaves slot 0 — it either repeats that entry numActive times or
+// spins forever when the slot is empty.
+func (r *reversePurgeLongHashMap) iterator() *iteratorHashMap {
+       return newIterator(r.keys, r.values, r.states, r.numActive)
+}
+
+// newIterator builds an iterator over the given parallel key, value and
+// state slices that will yield numActive entries.
+func newIterator(keys, values []int64, states []int16, numActive int) *iteratorHashMap {
+       size := len(keys)
+       // len(keys) scaled by InverseGolden, with the low bit forced on so the
+       // stride is always odd.
+       stride := int(uint64(float64(size)*common.InverseGolden) | 1)
+       it := new(iteratorHashMap)
+       it.keys_, it.values_, it.states_ = keys, values, states
+       it.numActive_ = numActive
+       it.stride_, it.mask_ = stride, size-1
+       it.i_ = -stride // the first next() adds stride back, landing on slot 0
+       return it
+}
+
+// next moves the cursor to the following occupied slot and reports whether
+// one was found. It returns false once numActive_ entries have been yielded.
+func (it *iteratorHashMap) next() bool {
+       // Always step off the current slot first, even if nothing is left.
+       it.i_ = (it.i_ + it.stride_) & it.mask_
+       for ; it.count_ < it.numActive_; it.i_ = (it.i_ + it.stride_) & it.mask_ {
+               if it.states_[it.i_] > 0 {
+                       it.count_++
+                       return true
+               }
+       }
+       return false
+}
+
+// getKey returns the key stored at the iterator's current slot.
+func (it *iteratorHashMap) getKey() int64 {
+       slot := it.i_
+       return it.keys_[slot]
+}
+
+// getValue returns the value stored at the iterator's current slot.
+func (it *iteratorHashMap) getValue() int64 {
+       slot := it.i_
+       return it.values_[slot]
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to