LindaSummer commented on code in PR #3312:
URL: https://github.com/apache/kvrocks/pull/3312#discussion_r2912698768
##########
src/commands/cmd_tdigest.cc:
##########
@@ -492,6 +492,67 @@ class CommandTDigestMerge : public Commander {
TDigestMergeOptions options_;
};
+class CommandTDigestTrimmedMean : public Commander {
+ public:
+ Status Parse(const std::vector<std::string> &args) override {
+ if (args.size() != 4) {
+ return {Status::RedisParseErr, errWrongNumOfArguments};
+ }
+
+ key_name_ = args[1];
+
+ auto low_cut_quantile = ParseFloat(args[2]);
+ if (!low_cut_quantile) {
+ return {Status::RedisParseErr, errValueIsNotFloat};
+ }
+ low_cut_quantile_ = *low_cut_quantile;
+
+ auto high_cut_quantile = ParseFloat(args[3]);
+ if (!high_cut_quantile) {
+ return {Status::RedisParseErr, errValueIsNotFloat};
+ }
+ high_cut_quantile_ = *high_cut_quantile;
+
+ if (!std::isfinite(low_cut_quantile_) || low_cut_quantile_ < 0.0 ||
low_cut_quantile_ > 1.0) {
Review Comment:
Using string validation before numeric validation may be a better way to
avoid the unstable comparison of floating-point numbers.
The string must match `^(0(?:\.\d*)?)|(1(?:\.0*))$`. Please double-check my
regex.
We could also compare with a delta to do this, but validating the pure
literal text would be more stable.
##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +720,98 @@ func tdigestTests(t *testing.T, configs
util.KvrocksServerConfigs) {
require.EqualValues(t, expected[i], rank, "REVRANK
mismatch at index %d", i)
}
})
+
+ t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+ emptyKey := "tdigest_empty"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey,
"compression", "100").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1",
"0.9")
+ require.NoError(t, result.Err())
+ require.Equal(t, "nan", result.Val())
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+ key := "tdigest_basic"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 5.5, mean, 0.01)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+ key := "tdigest_no_trim"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 5.5, mean, 0.01)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+ key := "tdigest_skewed"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1",
"1", "1", "1", "10", "100").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 2.8, mean, 0.01)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t
*testing.T) {
+ require.ErrorContains(t, rdb.Do(ctx,
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key", "0.1").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t
*testing.T) {
+ key := "tdigest_invalid"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5").Err())
+
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "-0.1", "0.9").Err(), errMsgLowCutQuantileRange)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.1", "1.1").Err(), errMsgHighCutQuantileRange)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.9", "0.1").Err(), errMsgLowCutQuantileLess)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.5", "0.5").Err(), errMsgLowCutQuantileLess)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+ key := "tdigest_single"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42",
"42").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 42.0, mean, 0.01)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
+ key := "tdigest_extreme"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.False(t, math.IsNaN(mean))
Review Comment:
Maybe this check is redundant since we have the next line.
##########
src/types/tdigest.h:
##########
@@ -309,3 +309,45 @@ inline Status TDigestRank(TD&& td, const
std::vector<double>& inputs, std::vecto
}
return Status::OK();
}
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile,
double high_cut_quantile) {
+ if (td.Size() == 0) {
+ return std::numeric_limits<double>::quiet_NaN();
+ }
+
+ const double total_weight = td.TotalWeight();
+ const double leftmost_weight = std::floor(total_weight * low_cut_quantile);
+ const double rightmost_weight = std::ceil(total_weight * high_cut_quantile);
+
+ double count_done = 0.0;
+ double trimmed_sum = 0.0;
+ double trimmed_count = 0.0;
+
+ auto iter = td.Begin();
+ while (iter->Valid()) {
+ auto centroid = GET_OR_RET(iter->GetCentroid());
+ const double n_weight = centroid.weight;
+ double count_add = n_weight;
+
+ count_add -= std::min(std::max(0.0, leftmost_weight - count_done),
count_add);
+ count_add = std::min(std::max(0.0, rightmost_weight - count_done),
count_add);
Review Comment:
Could we add a comment for this? This logic will be difficult to understand
when seeing it for the first time.
##########
src/commands/error_constants.h:
##########
@@ -54,4 +54,7 @@ inline constexpr const char *errParsingNumkeys = "error
parsing numkeys";
inline constexpr const char *errNumkeysMustBePositive = "numkeys need to be a
positive integer";
inline constexpr const char *errWrongKeyword = "wrong keyword";
inline constexpr const char *errInvalidRankValue = "rank needs to be
non-negative";
+inline constexpr const char *errLowCutQuantileRange = "low cut quantile must
be between 0 and 1";
Review Comment:
I tested this in Redis and got the error message below.
We'd better align with Redis's behavior.
```
localhost:6379> TDIGEST.TRIMMED_MEAN t -0.1 0.2
(error) ERR T-Digest: low_cut_percentile and high_cut_percentile should be
in [0,1]
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]