This is an automated email from the ASF dual-hosted git repository.

mkataria pushed a commit to branch OAK-11716
in repository https://gitbox.apache.org/repos/asf/jackrabbit-oak.git

commit a31ddf71797265ec2d574bcc5c08111b671e7a75
Author: Mohit Kataria <[email protected]>
AuthorDate: Sat May 10 17:26:36 2025 +0530

    OAK-11716: Capture inference service stats
---
 oak-search-elastic/pom.xml                         |   6 +
 .../query/inference/InferenceServiceManager.java   |  10 +-
 .../query/inference/InferenceServiceMetrics.java   | 404 +++++++++++++++++++++
 .../inference/InferenceServiceUsingConfig.java     |  24 +-
 .../InferenceServiceUsingIndexConfig.java          |  62 +++-
 .../inference/InferenceServiceMetricsTest.java     | 259 +++++++++++++
 6 files changed, 745 insertions(+), 20 deletions(-)

diff --git a/oak-search-elastic/pom.xml b/oak-search-elastic/pom.xml
index 4f951259c6..7c49696898 100644
--- a/oak-search-elastic/pom.xml
+++ b/oak-search-elastic/pom.xml
@@ -221,6 +221,12 @@
       <scope>provided</scope>
     </dependency>
 
+    <!-- Metrics -->
+    <dependency>
+      <groupId>io.dropwizard.metrics</groupId>
+      <artifactId>metrics-core</artifactId>
+    </dependency>
+
     <!-- Nullability annotations -->
     <dependency>
       <groupId>org.jetbrains</groupId>
diff --git 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java
 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java
index a5b2fe389f..4411e0a572 100644
--- 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java
+++ 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java
@@ -33,8 +33,10 @@ public class InferenceServiceManager {
 
     private static final String CACHE_SIZE_PROPERTY = 
"oak.inference.cache.size";
     private static final int CACHE_SIZE = 
SystemPropertySupplier.create(CACHE_SIZE_PROPERTY, 100).get();
+    private static final int UNCACHED_SERVICE_CACHE_SIZE = 0;
 
     private static final ConcurrentHashMap<String, InferenceService> SERVICES 
= new ConcurrentHashMap<>();
+    private  static  final InferenceServiceMetrics uncachedServiceMetrics = 
new InferenceServiceMetrics("UNCACHED_SERVICE", UNCACHED_SERVICE_CACHE_SIZE);
 
     @Deprecated
     public static InferenceService getInstance(@NotNull String url, String 
model) {
@@ -43,10 +45,10 @@ public class InferenceServiceManager {
         if (SERVICES.size() >= MAX_CACHED_SERVICES) {
             LOGGER.warning("InferenceServiceManager maximum cached services 
reached: " + MAX_CACHED_SERVICES);
             LOGGER.warning("Returning a new InferenceService instance with no 
cache");
-            return new InferenceServiceUsingIndexConfig(url, 0);
+            return new InferenceServiceUsingIndexConfig(url, 
UNCACHED_SERVICE_CACHE_SIZE, uncachedServiceMetrics);
         }
 
-        return SERVICES.computeIfAbsent(k, key -> new 
InferenceServiceUsingIndexConfig(url, CACHE_SIZE));
+        return SERVICES.computeIfAbsent(k, key -> new 
InferenceServiceUsingIndexConfig(url, CACHE_SIZE, new 
InferenceServiceMetrics(k, CACHE_SIZE)));
     }
 
     public static InferenceService getInstance(InferenceModelConfig 
inferenceModelConfig) {
@@ -58,8 +60,8 @@ public class InferenceServiceManager {
         if (SERVICES.size() >= MAX_CACHED_SERVICES) {
             LOGGER.warning("InferenceServiceManager maximum cached services 
reached: " + MAX_CACHED_SERVICES);
             LOGGER.warning("Returning a new InferenceService instance with no 
cache");
-            return new InferenceServiceUsingConfig(inferenceModelConfig);
+            return new InferenceServiceUsingConfig(inferenceModelConfig, 
uncachedServiceMetrics);
         }
-        return SERVICES.computeIfAbsent(key, k -> new 
InferenceServiceUsingConfig(inferenceModelConfig));
+        return SERVICES.computeIfAbsent(key, k -> new 
InferenceServiceUsingConfig(inferenceModelConfig, new 
InferenceServiceMetrics(k, inferenceModelConfig.getCacheSize())));
     }
 }
diff --git 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java
 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java
new file mode 100644
index 0000000000..a8ce559773
--- /dev/null
+++ 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference;
+
+import com.codahale.metrics.Counter;
+import com.codahale.metrics.Histogram;
+import com.codahale.metrics.Meter;
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.Snapshot;
+import com.codahale.metrics.Timer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * Collects and reports metrics for the inference service.
+ */
+public class InferenceServiceMetrics {
+    final static Logger LOG = 
LoggerFactory.getLogger(InferenceServiceMetrics.class);
+
+    // Tracks the last time metrics were logged
+    private long lastLogTimeMillis;
+    private String metricsServiceKey;
+    private int cacheSize;
+
+    // Metric constants for both output property names and registry base names
+    public static final String TOTAL_REQUESTS = "totalRequests";
+    public static final String CACHE_HITS = "cacheHits";
+    public static final String CACHE_MISSES = "cacheMisses";
+    public static final String CACHE_HIT_RATE = "cacheHitRate";
+    public static final String CACHE_SIZE = "cacheSize";
+    public static final String REQUEST_ERRORS = "requestErrors";
+    public static final String ERROR_RATE = "errorRate";
+    public static final String REQUEST_TIMEOUTS = "requestTimeouts";
+    public static final String TOTAL_REQUEST_TIME = "totalRequestTime";
+    public static final String AVG_REQUEST_TIME = "avgRequestTime";
+    public static final String MAX_REQUEST_TIME = "maxRequestTime";
+    public static final String TIME_HISTOGRAM = "timeHistogram";
+    public static final String REQUESTS = "requests";
+    public static final String ERRORS = "errors";
+    public static final String TIMEOUTS = "timeouts";
+
+    // Registry specific names (using camelCase where possible, hyphens only 
when needed for convention)
+    public static final String REQUEST_TIMER = "requestTimer";
+    public static final String REQUEST_TIMES = "requestTimes";
+    public static final String CACHE_HITS_METRIC = CACHE_HITS;
+    public static final String CACHE_MISSES_METRIC = CACHE_MISSES;
+    public static final String TOTAL_REQUESTS_METRIC = TOTAL_REQUESTS;
+    public static final String TOTAL_CACHE_HITS = "totalCacheHits";
+    public static final String TOTAL_CACHE_MISSES = "totalCacheMisses";
+    public static final String TOTAL_ERRORS = "totalErrors";
+
+    // Metric constants for histogram and percentile keys
+    private static final String KEY_COUNT = "count";
+    private static final String KEY_MIN = "min";
+    private static final String KEY_MAX = "max";
+    private static final String KEY_MEAN = "mean";
+    private static final String KEY_STD_DEV = "stdDev";
+    private static final String KEY_PERCENTILES = "percentiles";
+    private static final String KEY_REQUEST_PERCENTILES = "requestPercentiles";
+    public static final String KEY_ERROR_TIME_DATA = "errorTimeData";
+    private static final String ERROR_TIMES = "errorTimes";
+
+    // Metric constants for percentile keys
+    private static final String KEY_50TH = "50th";
+    private static final String KEY_75TH = "75th";
+    private static final String KEY_95TH = "95th";
+    private static final String KEY_98TH = "98th";
+    private static final String KEY_99TH = "99th";
+    private static final String KEY_999TH = "999th";
+
+    // Rate keys
+    private static final String KEY_REQUEST_RATE_1M = "requestRate1m";
+    private static final String KEY_REQUEST_RATE_5M = "requestRate5m";
+    private static final String KEY_ERROR_RATE_1M = "errorRate1m";
+    private static final String KEY_ERROR_RATE_5M = "errorRate5m";
+    private static final String KEY_TIMEOUT_RATE_1M = "timeoutRate1m";
+    private static final String KEY_TIMEOUT_RATE_5M = "timeoutRate5m";
+
+    private final MetricRegistry metrics = new MetricRegistry();
+
+    // Meters for rate measurements
+    private final Meter requests;
+    private final Meter hits;
+    private final Meter misses;
+    private final Meter errors;
+    private final Meter timeouts;
+
+    // Counters for absolute counts
+    private final Counter totalRequestCounter;
+    private final Counter cacheHitCounter;
+    private final Counter cacheMissCounter;
+    private final Counter errorCounter;
+
+    // Timers and histograms for timing statistics
+    private final Timer requestTimer;
+    private final Histogram requestTimes;
+    private final Histogram errorTimes;
+
+    public InferenceServiceMetrics(String metricsServiceKey, int cacheSize) {
+        this.lastLogTimeMillis = System.currentTimeMillis();
+        this.metricsServiceKey = metricsServiceKey;
+        this.cacheSize = cacheSize;
+
+        // Initialize meters
+        this.requests = 
metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, REQUESTS));
+        this.hits = 
metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, 
CACHE_HITS_METRIC));
+        this.misses = 
metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, 
CACHE_MISSES_METRIC));
+        this.errors = 
metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, ERRORS));
+        this.timeouts = 
metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, TIMEOUTS));
+
+        // Initialize counters
+        this.totalRequestCounter = 
metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, 
TOTAL_REQUESTS_METRIC));
+        this.cacheHitCounter = 
metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, 
TOTAL_CACHE_HITS));
+        this.cacheMissCounter = 
metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, 
TOTAL_CACHE_MISSES));
+        this.errorCounter = 
metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, 
TOTAL_ERRORS));
+
+        // Initialize timers and histograms
+        this.requestTimer = 
metrics.timer(MetricRegistry.name(InferenceServiceMetrics.class, 
REQUEST_TIMER));
+        this.requestTimes = 
metrics.histogram(MetricRegistry.name(InferenceServiceMetrics.class, 
REQUEST_TIMES));
+        this.errorTimes = 
metrics.histogram(MetricRegistry.name(InferenceServiceMetrics.class, 
ERROR_TIMES));
+    }
+
+    /**
+     * Records a request start
+     *
+     * @return Timer.Context that should be stopped when the request completes
+     */
+    public Timer.Context requestStarted() {
+        totalRequestCounter.inc();
+        requests.mark();
+        return requestTimer.time();
+    }
+
+    /**
+     * Records a cache hit
+     */
+    public void cacheHit() {
+        cacheHitCounter.inc();
+        hits.mark();
+    }
+
+    /**
+     * Records a cache miss
+     */
+    public void cacheMiss() {
+        cacheMissCounter.inc();
+        misses.mark();
+    }
+
+    /**
+     * Records a request error
+     *
+     * @param timeMillis   Time taken before the error occurred in milliseconds
+     * @param timerContext Timer context to stop, if available (can be null)
+     */
+    public void requestError(long timeMillis, Timer.Context timerContext) {
+        errorCounter.inc();
+        errors.mark();
+
+        // Record time in the error timer
+        errorTimes.update(timeMillis);
+
+        // Stop the request timer context if provided (this marks the end of 
the entire operation, even if it's an error)
+        if (timerContext != null) {
+            timerContext.stop();
+        }
+
+        LOG.debug("Request error occurred after {} ms", timeMillis);
+    }
+
+    /**
+     * Records a request error for which no timing information is available.
+     * <p>
+     * Increments the error counter and marks the error meter. Unlike
+     * {@link #requestError(long, Timer.Context)} this overload does not stop any timer context.
+     */
+    public void requestError() {
+        errorCounter.inc();
+        errors.mark();
+
+        // Without timing information, -1 is recorded as a placeholder sample.
+        // NOTE: the placeholder is a regular histogram sample, so it is included in the error
+        // histogram's count and also feeds its min, mean and percentile values.
+        errorTimes.update(-1);
+
+        LOG.debug("Request error occurred (timing unknown)");
+    }
+
+    /**
+     * Records the request completion time
+     *
+     * @param timeMillis   Time taken to complete the request in milliseconds
+     * @param timerContext Timer context to stop, if available (can be null)
+     */
+    public void requestCompleted(long timeMillis, Timer.Context timerContext) {
+        // Update histogram
+        requestTimes.update(timeMillis);
+
+        // Stop timer context if provided
+        if (timerContext != null) {
+            timerContext.stop();
+        }
+
+        LOG.debug("Request completed in {} ms", timeMillis);
+    }
+
+    /**
+     * Records the request completion time
+     *
+     * @param timeMillis Time taken to complete the request in milliseconds
+     */
+    public void requestCompleted(long timeMillis) {
+        requestCompleted(timeMillis, null);
+    }
+
+    /**
+     * Returns the cache hit rate percentage (0-100).
+     */
+    public double getCacheHitRate() {
+        long hits = cacheHitCounter.getCount();
+        long misses = cacheMissCounter.getCount();
+        long total = hits + misses;
+        return total > 0 ? (hits * 100.0 / total) : 0.0;
+    }
+
+    /**
+     * Returns the collected metrics as an insertion-ordered map for monitoring.
+     * <p>
+     * The map contains the absolute counters, the cache hit rate and error rate as percentages
+     * (0-100), the configured cache size, request-time statistics in milliseconds taken from the
+     * request-time histogram, a nested {@link #TIME_HISTOGRAM} map, a nested
+     * {@link #KEY_ERROR_TIME_DATA} map (present only once at least one error was counted) and the
+     * one/five minute rates of the request, error and timeout meters.
+     *
+     * @return a newly created map; histogram details are stored as nested maps
+     */
+    public Map<String, Object> getMetrics() {
+        Map<String, Object> metricsMap = new LinkedHashMap<>();
+        long total = totalRequestCounter.getCount();
+        long hits = cacheHitCounter.getCount();
+        long missesCount = cacheMissCounter.getCount();
+        long errorsCount = errorCounter.getCount();
+
+        metricsMap.put(TOTAL_REQUESTS, total);
+        metricsMap.put(CACHE_HITS, hits);
+        metricsMap.put(CACHE_MISSES, missesCount);
+        metricsMap.put(CACHE_HIT_RATE, getCacheHitRate());
+        metricsMap.put(CACHE_SIZE, cacheSize);
+        metricsMap.put(REQUEST_ERRORS, errorsCount);
+        metricsMap.put(ERROR_RATE, total > 0 ? (errorsCount * 100.0 / total) : 0.0);
+
+        // Request-time statistics. The histogram is fed with millisecond values by
+        // requestCompleted(), so no ns -> ms conversion must be applied here. The total is an
+        // estimate (sample count x snapshot mean) and uses the histogram's own count so that it
+        // stays consistent with the average reported next to it.
+        Snapshot histSnapshot = requestTimes.getSnapshot();
+        metricsMap.put(TOTAL_REQUEST_TIME, requestTimes.getCount() * histSnapshot.getMean());
+        metricsMap.put(AVG_REQUEST_TIME, histSnapshot.getMean());
+        metricsMap.put(MAX_REQUEST_TIME, histSnapshot.getMax());
+
+        // Add histogram data
+        metricsMap.put(TIME_HISTOGRAM,
+            histogramDetails(requestTimes.getCount(), histSnapshot, KEY_REQUEST_PERCENTILES));
+
+        // Add error histogram data, only once an error has been recorded
+        if (errorsCount > 0) {
+            metricsMap.put(KEY_ERROR_TIME_DATA,
+                histogramDetails(errorTimes.getCount(), errorTimes.getSnapshot(), KEY_PERCENTILES));
+        }
+
+        // Add rates
+        metricsMap.put(KEY_REQUEST_RATE_1M, requests.getOneMinuteRate());
+        metricsMap.put(KEY_REQUEST_RATE_5M, requests.getFiveMinuteRate());
+        metricsMap.put(KEY_ERROR_RATE_1M, errors.getOneMinuteRate());
+        metricsMap.put(KEY_ERROR_RATE_5M, errors.getFiveMinuteRate());
+        metricsMap.put(KEY_TIMEOUT_RATE_1M, timeouts.getOneMinuteRate());
+        metricsMap.put(KEY_TIMEOUT_RATE_5M, timeouts.getFiveMinuteRate());
+        return metricsMap;
+    }
+
+    /**
+     * Converts a histogram snapshot into a map with count, min, max, mean, standard deviation
+     * and a nested percentile map stored under {@code percentilesKey}.
+     */
+    private static Map<String, Object> histogramDetails(long count, Snapshot snapshot, String percentilesKey) {
+        Map<String, Object> details = new LinkedHashMap<>();
+        details.put(KEY_COUNT, count);
+        details.put(KEY_MIN, snapshot.getMin());
+        details.put(KEY_MAX, snapshot.getMax());
+        details.put(KEY_MEAN, snapshot.getMean());
+        details.put(KEY_STD_DEV, snapshot.getStdDev());
+
+        Map<String, Object> percentiles = new LinkedHashMap<>();
+        percentiles.put(KEY_50TH, snapshot.getMedian());
+        percentiles.put(KEY_75TH, snapshot.get75thPercentile());
+        percentiles.put(KEY_95TH, snapshot.get95thPercentile());
+        percentiles.put(KEY_98TH, snapshot.get98thPercentile());
+        percentiles.put(KEY_99TH, snapshot.get99thPercentile());
+        percentiles.put(KEY_999TH, snapshot.get999thPercentile());
+
+        details.put(percentilesKey, percentiles);
+        return details;
+    }
+
+    /**
+     * Returns the total number of requests processed
+     */
+    public long getTotalRequests() {
+        return totalRequestCounter.getCount();
+    }
+
+    /**
+     * Returns the Dropwizard Metrics registry used by this class.
+     * This can be used to add additional metrics or to register reporters.
+     */
+    public MetricRegistry getMetricRegistry() {
+        return metrics;
+    }
+
+    /**
+     * Logs the metrics summary using an interval of 0 ms and a request-count threshold of 0.
+     * <p>
+     * NOTE(review): {@link #logMetricsSummary(int, int)} returns without logging when the total
+     * request count is greater than the threshold, so with a threshold of 0 this call only logs
+     * while no request has been counted yet - confirm that this is the intended behaviour.
+     */
+    public void logMetricsSummary() {
+        logMetricsSummary(0, 0);
+    }
+
+    /**
+     * Logs a one-line summary of the current metrics at INFO level, rate-limited by
+     * {@code intervalMillis}.
+     * <p>
+     * Nothing is logged when less than {@code intervalMillis} milliseconds have passed since the
+     * previous summary (or since this object was created), or when the total request count is
+     * greater than {@code requestCountThreshold}. After a summary has been written the log time
+     * is reset, so the next summary is emitted one interval later at the earliest instead of on
+     * every subsequent call.
+     * <p>
+     * The log time is a plain field without synchronization; concurrent callers may therefore
+     * occasionally emit the summary more than once per interval.
+     *
+     * @param intervalMillis        minimum time between two summaries, in milliseconds
+     * @param requestCountThreshold the summary is skipped while the total request count exceeds this value
+     */
+    public void logMetricsSummary(int intervalMillis, int requestCountThreshold) {
+        if (lastLogTimeMillis + intervalMillis > System.currentTimeMillis() ||
+            totalRequestCounter.getCount() > requestCountThreshold) {
+            return; // Skip logging if the interval has not passed or the threshold is exceeded
+        }
+
+        Map<String, Object> metricsMap = getMetrics();
+        double hitRate = ((Number) metricsMap.get(CACHE_HIT_RATE)).doubleValue();
+        double errorRate = ((Number) metricsMap.get(ERROR_RATE)).doubleValue();
+        double avgTimeValue = ((Number) metricsMap.get(AVG_REQUEST_TIME)).doubleValue();
+
+        // Request-time values; max is widened to double to keep the textual format (e.g. "5.0")
+        Snapshot timerSnapshot = requestTimes.getSnapshot();
+        double maxTimer = timerSnapshot.getMax();
+
+        // Error-time values
+        Snapshot errorTimerSnapshot = errorTimes.getSnapshot();
+        double errorMaxTimer = errorTimerSnapshot.getMax();
+
+        StringBuilder logMessage = new StringBuilder();
+        logMessage.append("Inference service metrics for ").append(metricsServiceKey)
+            .append(": requests=").append(metricsMap.get(TOTAL_REQUESTS))
+            .append(", hitRate=").append(hitRate).append("%")
+            .append(", errorRate=").append(errorRate).append("%")
+            .append(", avgTime=").append(avgTimeValue).append("ms")
+            .append(", maxTime=").append(metricsMap.get(MAX_REQUEST_TIME)).append("ms")
+            .append(", lastLogTime=").append(lastLogTimeMillis);
+
+        // Add percentiles
+        logMessage.append(", successPercentiles [50th=").append(timerSnapshot.getMedian()).append("ms")
+            .append(", 95th=").append(timerSnapshot.get95thPercentile()).append("ms")
+            .append(", 99th=").append(timerSnapshot.get99thPercentile()).append("ms")
+            .append(", max=").append(maxTimer).append("ms]");
+
+        // Add success rates
+        logMessage.append(", successRates [1m=").append(requests.getOneMinuteRate()).append("req/s")
+            .append(", 5m=").append(requests.getFiveMinuteRate()).append("req/s")
+            .append(", 15m=").append(requests.getFifteenMinuteRate()).append("req/s]");
+
+        // Add error rates
+        logMessage.append(", errorRates [1m=").append(errors.getOneMinuteRate()).append("err/s")
+            .append(", 5m=").append(errors.getFiveMinuteRate()).append("err/s")
+            .append(", 15m=").append(errors.getFifteenMinuteRate()).append("err/s]");
+
+        // Add error percentiles
+        logMessage.append(", errorPercentiles=[50th=").append(errorTimerSnapshot.getMedian()).append("ms")
+            .append(", 95th=").append(errorTimerSnapshot.get95thPercentile()).append("ms")
+            .append(", 99th=").append(errorTimerSnapshot.get99thPercentile()).append("ms")
+            .append(", max=").append(errorMaxTimer).append("ms]");
+
+        LOG.info(logMessage.toString());
+
+        // Restart the interval; without this every call after the first elapsed interval would log.
+        lastLogTimeMillis = System.currentTimeMillis();
+    }
+} 
\ No newline at end of file
diff --git 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java
 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java
index 7816a85ce8..48e22ac626 100644
--- 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java
+++ 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java
@@ -18,8 +18,7 @@
  */
 package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference;
 
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.JsonMappingException;
+import com.codahale.metrics.Timer;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -53,9 +52,10 @@ public class InferenceServiceUsingConfig implements 
InferenceService {
     private final long timeoutMillis;
     private final InferenceModelConfig inferenceModelConfig;
     private final String[] headersValue;
+    private final InferenceServiceMetrics metrics;
 
 
-    public InferenceServiceUsingConfig(InferenceModelConfig 
inferenceModelConfig) {
+    public InferenceServiceUsingConfig(InferenceModelConfig 
inferenceModelConfig, InferenceServiceMetrics metrics) {
         try {
             this.uri = new URI(inferenceModelConfig.getEmbeddingServiceUrl());
         } catch (URISyntaxException e) {
@@ -65,6 +65,7 @@ public class InferenceServiceUsingConfig implements 
InferenceService {
         this.httpClient = HttpClient.newHttpClient();
         this.timeoutMillis = inferenceModelConfig.getTimeoutMillis();
         this.inferenceModelConfig = inferenceModelConfig;
+        this.metrics = metrics;
         this.headersValue = 
inferenceModelConfig.getHeader().getInferenceHeaderPayload()
             .entrySet().stream()
             .flatMap(e -> Stream.of(e.getKey(), e.getValue()))
@@ -78,10 +79,17 @@ public class InferenceServiceUsingConfig implements 
InferenceService {
     }
 
     public List<Float> embeddings(String text, long timeoutMillis) {
+        // Track the request
+        Timer.Context timerContext = metrics.requestStarted();
+        long startTime = System.currentTimeMillis();
+
         if (cache.containsKey(text)) {
+            metrics.cacheHit();
+            metrics.requestCompleted(System.currentTimeMillis() - startTime, 
timerContext);
             return cache.get(text);
         }
 
+        metrics.cacheMiss();
         List<Float> result = null;
         try {
             // Create the JSON payload.
@@ -108,13 +116,23 @@ public class InferenceServiceUsingConfig implements 
InferenceService {
                 });
                 result = embeddingList;
                 cache.put(text, result);
+                metrics.requestCompleted(System.currentTimeMillis() - 
startTime, timerContext);
                 return result;
+            } else {
+                metrics.requestError(System.currentTimeMillis() - startTime, 
timerContext);
+                LOG.error("Failed to get embeddings. Status code: {}, 
Response: {}", response.statusCode(), response.body());
             }
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
+            metrics.requestError(System.currentTimeMillis() - startTime, 
timerContext);
             throw new InferenceServiceException("Failed to get embeddings", e);
         } catch (IOException e) {
+            metrics.requestError(System.currentTimeMillis() - startTime, 
timerContext);
             throw new InferenceServiceException("Unable to extract embeddings 
from inference service response", e);
+        } finally {
+            //TODO evaluate and update how often we want to log these stats.
+            // Setting it to log every 10 minutes for now.
+            metrics.logMetricsSummary(10 * 60 * 1000, Integer.MAX_VALUE);
         }
         return result;
     }
diff --git 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java
 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java
index 03badd2b8a..e5c7f42b16 100644
--- 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java
+++ 
b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java
@@ -20,6 +20,8 @@ package 
org.apache.jackrabbit.oak.plugins.index.elastic.query.inference;
 
 import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.URI;
@@ -38,15 +40,17 @@ import java.util.stream.Collectors;
  * EXPERIMENTAL: A service that sends text to an inference service and 
receives embeddings in return.
  * The embeddings are cached to avoid repeated calls to the inference service.
  */
-public class InferenceServiceUsingIndexConfig implements InferenceService{
+public class InferenceServiceUsingIndexConfig implements InferenceService {
 
+    private static final Logger LOG = 
LoggerFactory.getLogger(InferenceServiceUsingIndexConfig.class);
     private static final ObjectMapper MAPPER = new ObjectMapper();
 
     private final URI uri;
     private final Cache<String, List<Float>> cache;
     private final HttpClient httpClient;
+    private final InferenceServiceMetrics metrics;
 
-    public InferenceServiceUsingIndexConfig(String url, int cacheSize) {
+    public InferenceServiceUsingIndexConfig(String url, int cacheSize, 
InferenceServiceMetrics metrics) {
         try {
             this.uri = new URI(url);
         } catch (URISyntaxException e) {
@@ -54,6 +58,7 @@ public class InferenceServiceUsingIndexConfig implements 
InferenceService{
         }
         this.cache = new Cache<>(cacheSize);
         this.httpClient = HttpClient.newHttpClient();
+        this.metrics = metrics;
     }
 
     @Override
@@ -62,49 +67,80 @@ public class InferenceServiceUsingIndexConfig implements 
InferenceService{
     }
 
     public List<Float> embeddings(String text, long timeoutMillis) {
-        if (cache.containsKey(text)) {
-            return cache.get(text);
-        }
+        metrics.requestStarted();
+        long startTime = System.currentTimeMillis();
 
         try {
+            if (cache.containsKey(text)) {
+                metrics.cacheHit();
+                return cache.get(text);
+            }
+
+            metrics.cacheMiss();
+
             // Create the JSON payload.
             String jsonInputString = "{\"text\":\"" + text + "\"}";
 
             // Build the HttpRequest.
             HttpRequest request = HttpRequest.newBuilder()
-                    .uri(uri)
-                    .timeout(java.time.Duration.ofMillis(timeoutMillis))
-                    .header("Content-Type", "application/json; utf-8")
-                    .POST(HttpRequest.BodyPublishers.ofString(jsonInputString, 
StandardCharsets.UTF_8))
-                    .build();
+                .uri(uri)
+                .timeout(java.time.Duration.ofMillis(timeoutMillis))
+                .header("Content-Type", "application/json; utf-8")
+                .POST(HttpRequest.BodyPublishers.ofString(jsonInputString, 
StandardCharsets.UTF_8))
+                .build();
 
             // Send the request and get the response.
+            LOG.debug("Sending request to inference service: {}", uri);
             HttpResponse<String> response = httpClient.send(request, 
HttpResponse.BodyHandlers.ofString());
 
+            // Check response status
+            if (response.statusCode() != 200) {
+                metrics.requestError();
+                LOG.warn("Inference service returned non-200 status code: {} - 
{}", response.statusCode(), response.body());
+                throw new InferenceServiceException("Inference service 
returned status code: " + response.statusCode());
+            }
+
             // Parse the response string into a JsonNode.
             JsonNode jsonResponse = MAPPER.readTree(response.body());
 
             // Extract the 'embedding' property.
             JsonNode embedding = jsonResponse.get("embedding");
 
+            if (embedding == null) {
+                metrics.requestError();
+                LOG.warn("Inference service response did not contain 
'embedding' property: {}", response.body());
+                throw new InferenceServiceException("Invalid response from 
inference service: missing 'embedding' property");
+            }
+
             double[] embeddings = MAPPER.treeToValue(embedding, 
double[].class);
 
             // Convert the array of doubles to a list of floats.
             List<Float> result = Arrays.stream(embeddings)
-                    .mapToObj(d -> ((Double) d).floatValue())
-                    .collect(Collectors.toList());
+                .mapToObj(d -> ((Double) d).floatValue())
+                .collect(Collectors.toList());
 
             cache.put(text, result);
+
+            LOG.debug("Successfully retrieved embeddings for text of length 
{}", text.length());
             return result;
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
+            metrics.requestError();
+            LOG.warn("Inference service request was interrupted", e);
             throw new InferenceServiceException("Failed to get embeddings", e);
         } catch (IOException e) {
+            metrics.requestError();
+            LOG.warn("Error communicating with inference service", e);
             throw new InferenceServiceException("Unable to extract embeddings 
from inference service response", e);
+        } finally {
+            long requestTime = System.currentTimeMillis() - startTime;
+            metrics.requestCompleted(requestTime);
+            //TODO evaluate and update how often we want to log these stats.
+            // Setting it to log every 10 minutes for now.
+            metrics.logMetricsSummary(10 * 60 * 1000, Integer.MAX_VALUE);
         }
     }
 
-
     private static class Cache<K, V> extends LinkedHashMap<K, V> {
         private final int maxEntries;
 
diff --git 
a/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java
 
b/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java
new file mode 100644
index 0000000000..b29393860e
--- /dev/null
+++ 
b/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference;
+
+import ch.qos.logback.classic.Level;
+import com.codahale.metrics.Timer;
+import org.apache.jackrabbit.oak.commons.junit.LogCustomizer;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class InferenceServiceMetricsTest {
+
+    // Metrics instance under test; setUp() assigns a new one before each test.
+    private InferenceServiceMetrics metrics;
+    // Service key handed to the InferenceServiceMetrics constructor in setUp().
+    private static final String TEST_SERVICE_KEY = "testService";
+    // Cache size handed to the InferenceServiceMetrics constructor in setUp().
+    private static final int TEST_CACHE_SIZE = 100;
+    // Key that testTimingMetrics() expects to find inside the TIME_HISTOGRAM map.
+    private static final String KEY_REQUEST_PERCENTILES = "requestPercentiles";
+
+    /**
+     * Creates a new {@link InferenceServiceMetrics} before every test, so each
+     * test starts from an instance built with the test service key and cache size.
+     */
+    @Before
+    public void setUp() {
+        this.metrics = new InferenceServiceMetrics(TEST_SERVICE_KEY, TEST_CACHE_SIZE);
+    }
+
+    /**
+     * Checks the state of a newly created metrics object: every counter and
+     * rate is zero and the reported cache size is the one passed at construction.
+     */
+    @Test
+    public void testInitialState() {
+        assertEquals(0, metrics.getTotalRequests());
+
+        final Map<String, Object> snapshot = metrics.getMetrics();
+
+        // Request and cache counters.
+        final Object totalRequests = snapshot.get(InferenceServiceMetrics.TOTAL_REQUESTS);
+        assertEquals(0L, totalRequests);
+        final Object cacheHits = snapshot.get(InferenceServiceMetrics.CACHE_HITS);
+        assertEquals(0L, cacheHits);
+        final Object cacheMisses = snapshot.get(InferenceServiceMetrics.CACHE_MISSES);
+        assertEquals(0L, cacheMisses);
+
+        // Cache hit rate and configured cache size.
+        final Object cacheHitRate = snapshot.get(InferenceServiceMetrics.CACHE_HIT_RATE);
+        assertEquals(0.0, cacheHitRate);
+        final Object cacheSize = snapshot.get(InferenceServiceMetrics.CACHE_SIZE);
+        assertEquals(TEST_CACHE_SIZE, cacheSize);
+
+        // Error counter and error rate.
+        final Object requestErrors = snapshot.get(InferenceServiceMetrics.REQUEST_ERRORS);
+        assertEquals(0L, requestErrors);
+        final Object errorRate = snapshot.get(InferenceServiceMetrics.ERROR_RATE);
+        assertEquals(0.0, errorRate);
+    }
+
+    /**
+     * Checks that every call to requestStarted() is reflected in the total
+     * request count, whether or not the request is later completed with a
+     * timer context.
+     */
+    @Test
+    public void testRequestTracking() {
+        // First request: the count goes up as soon as the request is started.
+        final Timer.Context timer = metrics.requestStarted();
+        assertEquals(1, metrics.getTotalRequests());
+
+        // Completing the request with its timer context leaves the total at one.
+        metrics.requestCompleted(150, timer);
+        assertEquals(1L, metrics.getMetrics().get(InferenceServiceMetrics.TOTAL_REQUESTS));
+
+        // Second request, completed through the overload that takes no context.
+        metrics.requestStarted();
+        metrics.requestCompleted(200);
+        assertEquals(2L, metrics.getMetrics().get(InferenceServiceMetrics.TOTAL_REQUESTS));
+    }
+
+    /**
+     * Verifies the cache hit rate, expressed as a percentage, both through
+     * getCacheHitRate() and through the cache entries of the metrics map.
+     */
+    @Test
+    public void testCacheHitRate() {
+        // No cache activity
+        assertEquals(0.0, metrics.getCacheHitRate(), 0.001);
+
+        // Record some hits and misses
+        metrics.cacheHit();
+        metrics.cacheHit();
+        metrics.cacheMiss();
+
+        // Should be 2/3 = 66.67%
+        assertEquals(66.67, metrics.getCacheHitRate(), 0.01);
+
+        // Add more misses
+        metrics.cacheMiss();
+        metrics.cacheMiss();
+
+        // Should be 2/5 = 40%
+        assertEquals(40.0, metrics.getCacheHitRate(), 0.001);
+
+        Map<String, Object> metricsMap = metrics.getMetrics();
+        assertEquals(2L, metricsMap.get(InferenceServiceMetrics.CACHE_HITS));
+        assertEquals(3L, metricsMap.get(InferenceServiceMetrics.CACHE_MISSES));
+        // The rate is a computed floating-point percentage: compare it with a
+        // tolerance, like the getCacheHitRate() checks above, instead of relying
+        // on exact Double.equals() through assertEquals(Object, Object).
+        assertEquals(40.0, (double) metricsMap.get(InferenceServiceMetrics.CACHE_HIT_RATE), 0.001);
+    }
+
+    /**
+     * Verifies that failed requests show up in both the total request count
+     * and the error count, and that the error rate is reported as a percentage
+     * of all started requests.
+     */
+    @Test
+    public void testErrorTracking() {
+        // Start a request and record an error
+        Timer.Context context = metrics.requestStarted();
+        metrics.requestError(100, context);
+
+        // Start another request and record an error without timing
+        metrics.requestStarted();
+        metrics.requestError();
+
+        Map<String, Object> metricsMap = metrics.getMetrics();
+        assertEquals(2L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS));
+        assertEquals(2L, metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS));
+        // Compare the computed percentage with a tolerance, exactly like the
+        // 66.67 check further down, rather than via exact Double.equals().
+        assertEquals(100.0, (double) metricsMap.get(InferenceServiceMetrics.ERROR_RATE), 0.01);
+
+        // Add a successful request
+        metrics.requestStarted();
+        metrics.requestCompleted(150);
+
+        metricsMap = metrics.getMetrics();
+        assertEquals(3L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS));
+        assertEquals(2L, metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS));
+        // 2 errors out of 3 requests = 66.67%
+        assertEquals(66.67, (double) metricsMap.get(InferenceServiceMetrics.ERROR_RATE), 0.01);
+    }
+
+    /**
+     * Records ten timed requests and checks that the metrics map exposes a
+     * time histogram holding the sample count, summary statistics and
+     * percentile data.
+     */
+    @Test
+    public void testTimingMetrics() {
+        // Ten requests with timings 100, 150, ..., 550.
+        for (int elapsed = 100; elapsed <= 550; elapsed += 50) {
+            final Timer.Context timer = metrics.requestStarted();
+            metrics.requestCompleted(elapsed, timer);
+        }
+
+        final Map<String, Object> snapshot = metrics.getMetrics();
+
+        // The histogram entry must be present before it is unpacked.
+        assertTrue(snapshot.containsKey(InferenceServiceMetrics.TIME_HISTOGRAM));
+
+        @SuppressWarnings("unchecked")
+        final Map<String, Object> timeHistogram =
+            (Map<String, Object>) snapshot.get(InferenceServiceMetrics.TIME_HISTOGRAM);
+
+        // One sample per completed request.
+        assertEquals(10L, timeHistogram.get("count"));
+
+        // Summary statistics.
+        for (String statistic : new String[] {"min", "max", "mean", "stdDev"}) {
+            assertTrue(timeHistogram.containsKey(statistic));
+        }
+
+        // Percentile data is stored under its own key.
+        assertTrue(timeHistogram.containsKey(KEY_REQUEST_PERCENTILES));
+    }
+
+    /**
+     * Checks that the metrics object hands out a non-null metric registry.
+     */
+    @Test
+    public void testMetricRegistry() {
+        final Object registry = metrics.getMetricRegistry();
+        assertNotNull(registry);
+    }
+
+    /**
+     * End-to-end check of logMetricsSummary(): records 5 completed requests,
+     * 3 cache hits, 2 cache misses and 2 failed requests, verifies the values
+     * in the metrics map (completed requests are expected under TIME_HISTOGRAM,
+     * failed ones under KEY_ERROR_TIME_DATA), then calls logMetricsSummary(0, 100)
+     * and inspects the first INFO message captured for the
+     * InferenceServiceMetrics logger.
+     */
+    @Test
+    public void testLogMetricsSummaryOutput() {
+        // Setup the LogCustomizer to capture log messages from 
InferenceServiceMetrics
+        LogCustomizer custom = LogCustomizer
+            .forLogger(InferenceServiceMetrics.class.getName())
+            .enable(Level.INFO)
+            .create();
+
+        try {
+            custom.starting();
+
+            // 1. Generate comprehensive metrics data
+            // Multiple requests with different timing values
+            for (int i = 0; i < 5; i++) {
+                Timer.Context context = metrics.requestStarted();
+                metrics.requestCompleted(100 + (i * 50), context);
+            }
+
+            // Add cache hits and misses to test hit rate
+            for (int i = 0; i < 3; i++) {
+                metrics.cacheHit();
+            }
+            for (int i = 0; i < 2; i++) {
+                metrics.cacheMiss();
+            }
+
+            // Add some errors to test error rate
+            for (int i = 0; i < 2; i++) {
+                Timer.Context context = metrics.requestStarted();
+                metrics.requestError(75 + i * 25, context);
+            }
+
+            // At this point we should have:
+            // - 7 total requests (5 successful, 2 errors)
+            // - 3 cache hits, 2 cache misses (60% hit rate)
+            // - 2 errors (28.6% error rate)
+            // - Various timing metrics
+
+            // Verify metrics were recorded correctly
+            Map<String, Object> metricsMap = metrics.getMetrics();
+            assertEquals(7L, 
metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS));
+            assertEquals(3L, 
metricsMap.get(InferenceServiceMetrics.CACHE_HITS));
+            assertEquals(2L, 
metricsMap.get(InferenceServiceMetrics.CACHE_MISSES));
+            assertEquals(60.0, 
metricsMap.get(InferenceServiceMetrics.CACHE_HIT_RATE));
+            assertEquals(2L, 
metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS));
+            assertEquals(28.57, (double) 
metricsMap.get(InferenceServiceMetrics.ERROR_RATE), 0.01);
+
+            // Check both histograms exist
+            @SuppressWarnings("unchecked")
+            Map<String, Object> histogram = (Map<String, Object>) 
metricsMap.get(InferenceServiceMetrics.TIME_HISTOGRAM);
+            assertEquals(5L, ((Number) histogram.get("count")).longValue());
+
+            @SuppressWarnings("unchecked")
+            Map<String, Object> errorHistogram = (Map<String, Object>) 
metricsMap.get(InferenceServiceMetrics.KEY_ERROR_TIME_DATA);
+            assertNotNull("Error histogram should exist", errorHistogram);
+            assertEquals(2L, ((Number) 
errorHistogram.get("count")).longValue());
+
+            // 2. Log the metrics summary
+            // Force logging regardless of time interval by using parameters 
that won't trigger early return
+            metrics.logMetricsSummary(0, 100);
+
+            // 3. Verify the log output contains all expected metrics
+            List<String> logs = custom.getLogs();
+            assertFalse("Log should contain at least one entry", 
logs.isEmpty());
+
+            // NOTE(review): this assumes the summary is the first INFO message
+            // captured for this logger - confirm that none of the metric calls
+            // above logs at INFO before logMetricsSummary() runs.
+            String logMessage = logs.get(0);
+            assertTrue("Log should contain the service key", 
logMessage.contains(TEST_SERVICE_KEY));
+
+            // Check all metrics are in the log
+            assertTrue("Log should contain requests count", 
logMessage.contains("requests=7"));
+            assertTrue("Log should contain hit rate", 
logMessage.contains("hitRate="));
+            assertTrue("Log should contain error rate", 
logMessage.contains("errorRate="));
+            assertTrue("Log should contain avgTime", 
logMessage.contains("avgTime="));
+            assertTrue("Log should contain maxTime", 
logMessage.contains("maxTime="));
+            assertTrue("Log should contain percentiles", 
logMessage.contains("successPercentiles [50th="));
+            assertTrue("Log should contain rates", 
logMessage.contains("successRates [1m="));
+            assertTrue("Log should contain error rates", 
logMessage.contains("errorRates [1m="));
+
+            // Check additional metrics present in updated implementation
+            assertTrue("Log should contain error rate metrics", 
logMessage.contains("err/s"));
+
+            // Check for request rate metrics
+            assertTrue("Log should contain 1-minute request rate", 
logMessage.contains("1m="));
+            assertTrue("Log should contain 5-minute request rate", 
logMessage.contains("5m="));
+            assertTrue("Log should contain 15-minute request rate", 
logMessage.contains("15m="));
+
+            // Verify error histogram details are included in the log output
+            // NOTE(review): the second operand is always true at this point,
+            // because errorHistogram was read from metricsMap under
+            // KEY_ERROR_TIME_DATA and asserted non-null above. As written this
+            // assertion cannot fail and does not verify the log text.
+            assertTrue("Log should include error timing information",
+                
logMessage.contains(InferenceServiceMetrics.KEY_ERROR_TIME_DATA) ||
+                    
metricsMap.containsKey(InferenceServiceMetrics.KEY_ERROR_TIME_DATA));
+        } finally {
+            custom.finished();
+        }
+    }
+} 
\ No newline at end of file


Reply via email to