This is an automated email from the ASF dual-hosted git repository. mkataria pushed a commit to branch OAK-11716 in repository https://gitbox.apache.org/repos/asf/jackrabbit-oak.git
commit a31ddf71797265ec2d574bcc5c08111b671e7a75 Author: Mohit Kataria <[email protected]> AuthorDate: Sat May 10 17:26:36 2025 +0530 OAK-11716: Capture inference service stats --- oak-search-elastic/pom.xml | 6 + .../query/inference/InferenceServiceManager.java | 10 +- .../query/inference/InferenceServiceMetrics.java | 404 +++++++++++++++++++++ .../inference/InferenceServiceUsingConfig.java | 24 +- .../InferenceServiceUsingIndexConfig.java | 62 +++- .../inference/InferenceServiceMetricsTest.java | 259 +++++++++++++ 6 files changed, 745 insertions(+), 20 deletions(-) diff --git a/oak-search-elastic/pom.xml b/oak-search-elastic/pom.xml index 4f951259c6..7c49696898 100644 --- a/oak-search-elastic/pom.xml +++ b/oak-search-elastic/pom.xml @@ -221,6 +221,12 @@ <scope>provided</scope> </dependency> + <!-- Metrics --> + <dependency> + <groupId>io.dropwizard.metrics</groupId> + <artifactId>metrics-core</artifactId> + </dependency> + <!-- Nullability annotations --> <dependency> <groupId>org.jetbrains</groupId> diff --git a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java index a5b2fe389f..4411e0a572 100644 --- a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java +++ b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceManager.java @@ -33,8 +33,10 @@ public class InferenceServiceManager { private static final String CACHE_SIZE_PROPERTY = "oak.inference.cache.size"; private static final int CACHE_SIZE = SystemPropertySupplier.create(CACHE_SIZE_PROPERTY, 100).get(); + private static final int UNCACHED_SERVICE_CACHE_SIZE = 0; private static final ConcurrentHashMap<String, InferenceService> SERVICES = new ConcurrentHashMap<>(); + private 
static final InferenceServiceMetrics uncachedServiceMetrics = new InferenceServiceMetrics("UNCACHED_SERVICE", UNCACHED_SERVICE_CACHE_SIZE); @Deprecated public static InferenceService getInstance(@NotNull String url, String model) { @@ -43,10 +45,10 @@ public class InferenceServiceManager { if (SERVICES.size() >= MAX_CACHED_SERVICES) { LOGGER.warning("InferenceServiceManager maximum cached services reached: " + MAX_CACHED_SERVICES); LOGGER.warning("Returning a new InferenceService instance with no cache"); - return new InferenceServiceUsingIndexConfig(url, 0); + return new InferenceServiceUsingIndexConfig(url, UNCACHED_SERVICE_CACHE_SIZE, uncachedServiceMetrics); } - return SERVICES.computeIfAbsent(k, key -> new InferenceServiceUsingIndexConfig(url, CACHE_SIZE)); + return SERVICES.computeIfAbsent(k, key -> new InferenceServiceUsingIndexConfig(url, CACHE_SIZE, new InferenceServiceMetrics(k, CACHE_SIZE))); } public static InferenceService getInstance(InferenceModelConfig inferenceModelConfig) { @@ -58,8 +60,8 @@ public class InferenceServiceManager { if (SERVICES.size() >= MAX_CACHED_SERVICES) { LOGGER.warning("InferenceServiceManager maximum cached services reached: " + MAX_CACHED_SERVICES); LOGGER.warning("Returning a new InferenceService instance with no cache"); - return new InferenceServiceUsingConfig(inferenceModelConfig); + return new InferenceServiceUsingConfig(inferenceModelConfig, uncachedServiceMetrics); } - return SERVICES.computeIfAbsent(key, k -> new InferenceServiceUsingConfig(inferenceModelConfig)); + return SERVICES.computeIfAbsent(key, k -> new InferenceServiceUsingConfig(inferenceModelConfig, new InferenceServiceMetrics(k, inferenceModelConfig.getCacheSize()))); } } diff --git a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java new file mode 100644 
index 0000000000..a8ce559773 --- /dev/null +++ b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetrics.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference; + +import com.codahale.metrics.Counter; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Snapshot; +import com.codahale.metrics.Timer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Collects and reports metrics for the inference service. 
+ */ +public class InferenceServiceMetrics { + final static Logger LOG = LoggerFactory.getLogger(InferenceServiceMetrics.class); + + // Tracks the last time metrics were logged + private long lastLogTimeMillis; + private String metricsServiceKey; + private int cacheSize; + + // Metric constants for both output property names and registry base names + public static final String TOTAL_REQUESTS = "totalRequests"; + public static final String CACHE_HITS = "cacheHits"; + public static final String CACHE_MISSES = "cacheMisses"; + public static final String CACHE_HIT_RATE = "cacheHitRate"; + public static final String CACHE_SIZE = "cacheSize"; + public static final String REQUEST_ERRORS = "requestErrors"; + public static final String ERROR_RATE = "errorRate"; + public static final String REQUEST_TIMEOUTS = "requestTimeouts"; + public static final String TOTAL_REQUEST_TIME = "totalRequestTime"; + public static final String AVG_REQUEST_TIME = "avgRequestTime"; + public static final String MAX_REQUEST_TIME = "maxRequestTime"; + public static final String TIME_HISTOGRAM = "timeHistogram"; + public static final String REQUESTS = "requests"; + public static final String ERRORS = "errors"; + public static final String TIMEOUTS = "timeouts"; + + // Registry specific names (using camelCase where possible, hyphens only when needed for convention) + public static final String REQUEST_TIMER = "requestTimer"; + public static final String REQUEST_TIMES = "requestTimes"; + public static final String CACHE_HITS_METRIC = CACHE_HITS; + public static final String CACHE_MISSES_METRIC = CACHE_MISSES; + public static final String TOTAL_REQUESTS_METRIC = TOTAL_REQUESTS; + public static final String TOTAL_CACHE_HITS = "totalCacheHits"; + public static final String TOTAL_CACHE_MISSES = "totalCacheMisses"; + public static final String TOTAL_ERRORS = "totalErrors"; + + // Metric constants for histogram and percentile keys + private static final String KEY_COUNT = "count"; + private static final 
String KEY_MIN = "min"; + private static final String KEY_MAX = "max"; + private static final String KEY_MEAN = "mean"; + private static final String KEY_STD_DEV = "stdDev"; + private static final String KEY_PERCENTILES = "percentiles"; + private static final String KEY_REQUEST_PERCENTILES = "requestPercentiles"; + public static final String KEY_ERROR_TIME_DATA = "errorTimeData"; + private static final String ERROR_TIMES = "errorTimes"; + + // Metric constants for percentile keys + private static final String KEY_50TH = "50th"; + private static final String KEY_75TH = "75th"; + private static final String KEY_95TH = "95th"; + private static final String KEY_98TH = "98th"; + private static final String KEY_99TH = "99th"; + private static final String KEY_999TH = "999th"; + + // Rate keys + private static final String KEY_REQUEST_RATE_1M = "requestRate1m"; + private static final String KEY_REQUEST_RATE_5M = "requestRate5m"; + private static final String KEY_ERROR_RATE_1M = "errorRate1m"; + private static final String KEY_ERROR_RATE_5M = "errorRate5m"; + private static final String KEY_TIMEOUT_RATE_1M = "timeoutRate1m"; + private static final String KEY_TIMEOUT_RATE_5M = "timeoutRate5m"; + + private final MetricRegistry metrics = new MetricRegistry(); + + // Meters for rate measurements + private final Meter requests; + private final Meter hits; + private final Meter misses; + private final Meter errors; + private final Meter timeouts; + + // Counters for absolute counts + private final Counter totalRequestCounter; + private final Counter cacheHitCounter; + private final Counter cacheMissCounter; + private final Counter errorCounter; + + // Timers and histograms for timing statistics + private final Timer requestTimer; + private final Histogram requestTimes; + private final Histogram errorTimes; + + public InferenceServiceMetrics(String metricsServiceKey, int cacheSize) { + this.lastLogTimeMillis = System.currentTimeMillis(); + this.metricsServiceKey = 
metricsServiceKey; + this.cacheSize = cacheSize; + + // Initialize meters + this.requests = metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, REQUESTS)); + this.hits = metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, CACHE_HITS_METRIC)); + this.misses = metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, CACHE_MISSES_METRIC)); + this.errors = metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, ERRORS)); + this.timeouts = metrics.meter(MetricRegistry.name(InferenceServiceMetrics.class, TIMEOUTS)); + + // Initialize counters + this.totalRequestCounter = metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, TOTAL_REQUESTS_METRIC)); + this.cacheHitCounter = metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, TOTAL_CACHE_HITS)); + this.cacheMissCounter = metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, TOTAL_CACHE_MISSES)); + this.errorCounter = metrics.counter(MetricRegistry.name(InferenceServiceMetrics.class, TOTAL_ERRORS)); + + // Initialize timers and histograms + this.requestTimer = metrics.timer(MetricRegistry.name(InferenceServiceMetrics.class, REQUEST_TIMER)); + this.requestTimes = metrics.histogram(MetricRegistry.name(InferenceServiceMetrics.class, REQUEST_TIMES)); + this.errorTimes = metrics.histogram(MetricRegistry.name(InferenceServiceMetrics.class, ERROR_TIMES)); + } + + /** + * Records a request start + * + * @return Timer.Context that should be stopped when the request completes + */ + public Timer.Context requestStarted() { + totalRequestCounter.inc(); + requests.mark(); + return requestTimer.time(); + } + + /** + * Records a cache hit + */ + public void cacheHit() { + cacheHitCounter.inc(); + hits.mark(); + } + + /** + * Records a cache miss + */ + public void cacheMiss() { + cacheMissCounter.inc(); + misses.mark(); + } + + /** + * Records a request error + * + * @param timeMillis Time taken before the error occurred in milliseconds + * @param 
timerContext Timer context to stop, if available (can be null) + */ + public void requestError(long timeMillis, Timer.Context timerContext) { + errorCounter.inc(); + errors.mark(); + + // Record time in the error timer + errorTimes.update(timeMillis); + + // Stop the request timer context if provided (this marks the end of the entire operation, even if it's an error) + if (timerContext != null) { + timerContext.stop(); + } + + LOG.debug("Request error occurred after {} ms", timeMillis); + } + + /** + * Records a request error + */ + public void requestError() { + errorCounter.inc(); + errors.mark(); + + // Without timing information, record -1 as a sentinel value + // NOTE(review): the sentinel IS folded into the error histogram and skews its min/low percentiles + errorTimes.update(-1); + + LOG.debug("Request error occurred (timing unknown)"); + } + + /** + * Records the request completion time + * + * @param timeMillis Time taken to complete the request in milliseconds + * @param timerContext Timer context to stop, if available (can be null) + */ + public void requestCompleted(long timeMillis, Timer.Context timerContext) { + // Update histogram + requestTimes.update(timeMillis); + + // Stop timer context if provided + if (timerContext != null) { + timerContext.stop(); + } + + LOG.debug("Request completed in {} ms", timeMillis); + } + + /** + * Records the request completion time + * + * @param timeMillis Time taken to complete the request in milliseconds + */ + public void requestCompleted(long timeMillis) { + requestCompleted(timeMillis, null); + } + + /** + * Returns the cache hit rate percentage (0-100). + */ + public double getCacheHitRate() { + long hits = cacheHitCounter.getCount(); + long misses = cacheMissCounter.getCount(); + long total = hits + misses; + return total > 0 ? (hits * 100.0 / total) : 0.0; + } + + /** + * Returns metrics as a map for monitoring. 
+ */ + public Map<String, Object> getMetrics() { + Map<String, Object> metricsMap = new LinkedHashMap<>(); + long total = totalRequestCounter.getCount(); + long hits = cacheHitCounter.getCount(); + long missesCount = cacheMissCounter.getCount(); + long errorsCount = errorCounter.getCount(); + + metricsMap.put(TOTAL_REQUESTS, total); + metricsMap.put(CACHE_HITS, hits); + metricsMap.put(CACHE_MISSES, missesCount); + metricsMap.put(CACHE_HIT_RATE, getCacheHitRate()); + metricsMap.put(CACHE_SIZE, cacheSize); + metricsMap.put(REQUEST_ERRORS, errorsCount); + metricsMap.put(ERROR_RATE, total > 0 ? (errorsCount * 100.0 / total) : 0.0); + + // Timer statistics + Snapshot histSnapshot = requestTimes.getSnapshot(); + metricsMap.put(TOTAL_REQUEST_TIME, requestTimer.getCount() * histSnapshot.getMean() / 1_000_000.0); // Convert from ns to ms + metricsMap.put(AVG_REQUEST_TIME, histSnapshot.getMean()); + metricsMap.put(MAX_REQUEST_TIME, histSnapshot.getMax()); + + // Add histogram data + Map<String, Object> histogramData = new LinkedHashMap<>(); + histogramData.put(KEY_COUNT, requestTimes.getCount()); + histogramData.put(KEY_MIN, histSnapshot.getMin()); + histogramData.put(KEY_MAX, histSnapshot.getMax()); + histogramData.put(KEY_MEAN, histSnapshot.getMean()); + histogramData.put(KEY_STD_DEV, histSnapshot.getStdDev()); + + // Add percentiles + Map<String, Object> percentiles = new LinkedHashMap<>(); + percentiles.put(KEY_50TH, histSnapshot.getMedian()); + percentiles.put(KEY_75TH, histSnapshot.get75thPercentile()); + percentiles.put(KEY_95TH, histSnapshot.get95thPercentile()); + percentiles.put(KEY_98TH, histSnapshot.get98thPercentile()); + percentiles.put(KEY_99TH, histSnapshot.get99thPercentile()); + percentiles.put(KEY_999TH, histSnapshot.get999thPercentile()); + + histogramData.put(KEY_REQUEST_PERCENTILES, percentiles); + metricsMap.put(TIME_HISTOGRAM, histogramData); + + // Add error histogram data + if (errorCounter.getCount() > 0) { + Snapshot errorHistSnapshot = 
errorTimes.getSnapshot(); + + Map<String, Object> errorHistogramData = new LinkedHashMap<>(); + errorHistogramData.put(KEY_COUNT, errorTimes.getCount()); + errorHistogramData.put(KEY_MIN, errorHistSnapshot.getMin()); + errorHistogramData.put(KEY_MAX, errorHistSnapshot.getMax()); + errorHistogramData.put(KEY_MEAN, errorHistSnapshot.getMean()); + errorHistogramData.put(KEY_STD_DEV, errorHistSnapshot.getStdDev()); + + // Add percentiles + Map<String, Object> errorPercentiles = new LinkedHashMap<>(); + errorPercentiles.put(KEY_50TH, errorHistSnapshot.getMedian()); + errorPercentiles.put(KEY_75TH, errorHistSnapshot.get75thPercentile()); + errorPercentiles.put(KEY_95TH, errorHistSnapshot.get95thPercentile()); + errorPercentiles.put(KEY_98TH, errorHistSnapshot.get98thPercentile()); + errorPercentiles.put(KEY_99TH, errorHistSnapshot.get99thPercentile()); + errorPercentiles.put(KEY_999TH, errorHistSnapshot.get999thPercentile()); + + errorHistogramData.put(KEY_PERCENTILES, errorPercentiles); + metricsMap.put(KEY_ERROR_TIME_DATA, errorHistogramData); + } + + // Add rates + metricsMap.put(KEY_REQUEST_RATE_1M, requests.getOneMinuteRate()); + metricsMap.put(KEY_REQUEST_RATE_5M, requests.getFiveMinuteRate()); + metricsMap.put(KEY_ERROR_RATE_1M, errors.getOneMinuteRate()); + metricsMap.put(KEY_ERROR_RATE_5M, errors.getFiveMinuteRate()); + metricsMap.put(KEY_TIMEOUT_RATE_1M, timeouts.getOneMinuteRate()); + metricsMap.put(KEY_TIMEOUT_RATE_5M, timeouts.getFiveMinuteRate()); + return metricsMap; + } + + /** + * Returns the total number of requests processed + */ + public long getTotalRequests() { + return totalRequestCounter.getCount(); + } + + /** + * Returns the Dropwizard Metrics registry used by this class. + * This can be used to add additional metrics or to register reporters. 
+ */ + public MetricRegistry getMetricRegistry() { + return metrics; + } + + public void logMetricsSummary() { + logMetricsSummary(0, 0); + } + + public void logMetricsSummary(int intervalMillis, int requestCountThreshold) { + if (lastLogTimeMillis + intervalMillis > System.currentTimeMillis() || + totalRequestCounter.getCount() > requestCountThreshold) { + return; // Skip logging: interval not yet elapsed, or request count above threshold. NOTE(review): the count check looks inverted — the no-arg overload passes (0, 0), which skips after the first request; lastLogTimeMillis is also never refreshed after logging — confirm intent + } + + // Avoid format specifier issues by converting everything to strings + Map<String, Object> metricsMap = getMetrics(); + double hitRate = (Double) metricsMap.get(CACHE_HIT_RATE); + Object avgTime = metricsMap.get(AVG_REQUEST_TIME); + double avgTimeValue = avgTime instanceof Long ? (double) (Long) avgTime : (Double) avgTime; + + // Convert timer values to doubles for safe formatting + Snapshot timerSnapshot = requestTimes.getSnapshot(); + double median = timerSnapshot.getMedian(); + double p95 = timerSnapshot.get95thPercentile(); + double p99 = timerSnapshot.get99thPercentile(); + double maxTimer = timerSnapshot.getMax(); + double oneMinRate = requests.getOneMinuteRate(); + double fiveMinRate = requests.getFiveMinuteRate(); + double fifteenMinRate = requests.getFifteenMinuteRate(); + + // Error timer values + Snapshot errorTimerSnapshot = errorTimes.getSnapshot(); + double errorRate = (Double) metricsMap.get(ERROR_RATE); + double errorMedian = errorTimerSnapshot.getMedian(); + double errorP95 = errorTimerSnapshot.get95thPercentile(); + double errorP99 = errorTimerSnapshot.get99thPercentile(); + double errorMaxTimer = errorTimerSnapshot.getMax(); + double errorRate1m = errors.getOneMinuteRate(); + double errorRate5m = errors.getFiveMinuteRate(); + double errorRate15m = errors.getFifteenMinuteRate(); + + StringBuilder logMessage = new StringBuilder(); + logMessage.append("Inference service metrics for ").append(metricsServiceKey) + .append(": requests=").append(metricsMap.get(TOTAL_REQUESTS)) + .append(", 
hitRate=").append(Double.toString(hitRate)).append("%") + .append(", errorRate=").append(Double.toString(errorRate)).append("%") + .append(", avgTime=").append(Double.toString(avgTimeValue)).append("ms") + .append(", maxTime=").append(metricsMap.get(MAX_REQUEST_TIME)).append("ms") + .append(", lastLogTime=").append(lastLogTimeMillis); + + // Add percentiles + logMessage.append(", successPercentiles [50th=").append(Double.toString(median)).append("ms") + .append(", 95th=").append(Double.toString(p95)).append("ms") + .append(", 99th=").append(Double.toString(p99)).append("ms") + .append(", max=").append(Double.toString(maxTimer)).append("ms]"); + + // Add success rates + logMessage.append(", successRates [1m=").append(Double.toString(oneMinRate)).append("req/s") + .append(", 5m=").append(Double.toString(fiveMinRate)).append("req/s") + .append(", 15m=").append(Double.toString(fifteenMinRate)).append("req/s]"); + + // Add error rates + logMessage.append(", errorRates [1m=").append(Double.toString(errorRate1m)).append("err/s") + .append(", 5m=").append(Double.toString(errorRate5m)).append("err/s") + .append(", 15m=").append(Double.toString(errorRate15m)).append("err/s]"); + + // Add error percentiles + logMessage.append(", errorPercentiles=[50th=").append(Double.toString(errorMedian)).append("ms") + .append(", 95th=").append(Double.toString(errorP95)).append("ms") + .append(", 99th=").append(Double.toString(errorP99)).append("ms") + .append(", max=").append(Double.toString(errorMaxTimer)).append("ms]"); + + LOG.info(logMessage.toString()); + } +} \ No newline at end of file diff --git a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java index 7816a85ce8..48e22ac626 100644 --- 
a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java +++ b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingConfig.java @@ -18,8 +18,7 @@ */ package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonMappingException; +import com.codahale.metrics.Timer; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,9 +52,10 @@ public class InferenceServiceUsingConfig implements InferenceService { private final long timeoutMillis; private final InferenceModelConfig inferenceModelConfig; private final String[] headersValue; + private final InferenceServiceMetrics metrics; - public InferenceServiceUsingConfig(InferenceModelConfig inferenceModelConfig) { + public InferenceServiceUsingConfig(InferenceModelConfig inferenceModelConfig, InferenceServiceMetrics metrics) { try { this.uri = new URI(inferenceModelConfig.getEmbeddingServiceUrl()); } catch (URISyntaxException e) { @@ -65,6 +65,7 @@ public class InferenceServiceUsingConfig implements InferenceService { this.httpClient = HttpClient.newHttpClient(); this.timeoutMillis = inferenceModelConfig.getTimeoutMillis(); this.inferenceModelConfig = inferenceModelConfig; + this.metrics = metrics; this.headersValue = inferenceModelConfig.getHeader().getInferenceHeaderPayload() .entrySet().stream() .flatMap(e -> Stream.of(e.getKey(), e.getValue())) @@ -78,10 +79,17 @@ public class InferenceServiceUsingConfig implements InferenceService { } public List<Float> embeddings(String text, long timeoutMillis) { + // Track the request + Timer.Context timerContext = metrics.requestStarted(); + long startTime = System.currentTimeMillis(); + if (cache.containsKey(text)) { + metrics.cacheHit(); + 
metrics.requestCompleted(System.currentTimeMillis() - startTime, timerContext); return cache.get(text); } + metrics.cacheMiss(); List<Float> result = null; try { // Create the JSON payload. @@ -108,13 +116,23 @@ public class InferenceServiceUsingConfig implements InferenceService { }); result = embeddingList; cache.put(text, result); + metrics.requestCompleted(System.currentTimeMillis() - startTime, timerContext); return result; + } else { + metrics.requestError(System.currentTimeMillis() - startTime, timerContext); + LOG.error("Failed to get embeddings. Status code: {}, Response: {}", response.statusCode(), response.body()); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + metrics.requestError(System.currentTimeMillis() - startTime, timerContext); throw new InferenceServiceException("Failed to get embeddings", e); } catch (IOException e) { + metrics.requestError(System.currentTimeMillis() - startTime, timerContext); throw new InferenceServiceException("Unable to extract embeddings from inference service response", e); + } finally { + //TODO evaluate and update how often we want to log these stats. + // Setting it to log every 10 minutes for now. 
+ metrics.logMetricsSummary(10 * 60 * 1000, Integer.MAX_VALUE); } return result; } diff --git a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java index 03badd2b8a..e5c7f42b16 100644 --- a/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java +++ b/oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceUsingIndexConfig.java @@ -20,6 +20,8 @@ package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.URI; @@ -38,15 +40,17 @@ import java.util.stream.Collectors; * EXPERIMENTAL: A service that sends text to an inference service and receives embeddings in return. * The embeddings are cached to avoid repeated calls to the inference service. 
*/ -public class InferenceServiceUsingIndexConfig implements InferenceService{ +public class InferenceServiceUsingIndexConfig implements InferenceService { + private static final Logger LOG = LoggerFactory.getLogger(InferenceServiceUsingIndexConfig.class); private static final ObjectMapper MAPPER = new ObjectMapper(); private final URI uri; private final Cache<String, List<Float>> cache; private final HttpClient httpClient; + private final InferenceServiceMetrics metrics; - public InferenceServiceUsingIndexConfig(String url, int cacheSize) { + public InferenceServiceUsingIndexConfig(String url, int cacheSize, InferenceServiceMetrics metrics) { try { this.uri = new URI(url); } catch (URISyntaxException e) { @@ -54,6 +58,7 @@ public class InferenceServiceUsingIndexConfig implements InferenceService{ } this.cache = new Cache<>(cacheSize); this.httpClient = HttpClient.newHttpClient(); + this.metrics = metrics; } @Override @@ -62,49 +67,80 @@ public class InferenceServiceUsingIndexConfig implements InferenceService{ } public List<Float> embeddings(String text, long timeoutMillis) { - if (cache.containsKey(text)) { - return cache.get(text); - } + metrics.requestStarted(); + long startTime = System.currentTimeMillis(); try { + if (cache.containsKey(text)) { + metrics.cacheHit(); + return cache.get(text); + } + + metrics.cacheMiss(); + // Create the JSON payload. String jsonInputString = "{\"text\":\"" + text + "\"}"; // Build the HttpRequest. HttpRequest request = HttpRequest.newBuilder() - .uri(uri) - .timeout(java.time.Duration.ofMillis(timeoutMillis)) - .header("Content-Type", "application/json; utf-8") - .POST(HttpRequest.BodyPublishers.ofString(jsonInputString, StandardCharsets.UTF_8)) - .build(); + .uri(uri) + .timeout(java.time.Duration.ofMillis(timeoutMillis)) + .header("Content-Type", "application/json; utf-8") + .POST(HttpRequest.BodyPublishers.ofString(jsonInputString, StandardCharsets.UTF_8)) + .build(); // Send the request and get the response. 
+ LOG.debug("Sending request to inference service: {}", uri); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + // Check response status + if (response.statusCode() != 200) { + metrics.requestError(); + LOG.warn("Inference service returned non-200 status code: {} - {}", response.statusCode(), response.body()); + throw new InferenceServiceException("Inference service returned status code: " + response.statusCode()); + } + // Parse the response string into a JsonNode. JsonNode jsonResponse = MAPPER.readTree(response.body()); // Extract the 'embedding' property. JsonNode embedding = jsonResponse.get("embedding"); + if (embedding == null) { + metrics.requestError(); + LOG.warn("Inference service response did not contain 'embedding' property: {}", response.body()); + throw new InferenceServiceException("Invalid response from inference service: missing 'embedding' property"); + } + double[] embeddings = MAPPER.treeToValue(embedding, double[].class); // Convert the array of doubles to a list of floats. 
List<Float> result = Arrays.stream(embeddings) - .mapToObj(d -> ((Double) d).floatValue()) - .collect(Collectors.toList()); + .mapToObj(d -> ((Double) d).floatValue()) + .collect(Collectors.toList()); cache.put(text, result); + + LOG.debug("Successfully retrieved embeddings for text of length {}", text.length()); return result; } catch (InterruptedException e) { Thread.currentThread().interrupt(); + metrics.requestError(); + LOG.warn("Inference service request was interrupted", e); throw new InferenceServiceException("Failed to get embeddings", e); } catch (IOException e) { + metrics.requestError(); + LOG.warn("Error communicating with inference service", e); throw new InferenceServiceException("Unable to extract embeddings from inference service response", e); + } finally { + long requestTime = System.currentTimeMillis() - startTime; + metrics.requestCompleted(requestTime); + //TODO evaluate and update how often we want to log these stats. + // Setting it to log every 10 minutes for now. + metrics.logMetricsSummary(10 * 60 * 1000, Integer.MAX_VALUE); } } - private static class Cache<K, V> extends LinkedHashMap<K, V> { private final int maxEntries; diff --git a/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java b/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java new file mode 100644 index 0000000000..b29393860e --- /dev/null +++ b/oak-search-elastic/src/test/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/inference/InferenceServiceMetricsTest.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.jackrabbit.oak.plugins.index.elastic.query.inference; + +import ch.qos.logback.classic.Level; +import com.codahale.metrics.Timer; +import org.apache.jackrabbit.oak.commons.junit.LogCustomizer; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class InferenceServiceMetricsTest { + + private InferenceServiceMetrics metrics; + private static final String TEST_SERVICE_KEY = "testService"; + private static final int TEST_CACHE_SIZE = 100; + private static final String KEY_REQUEST_PERCENTILES = "requestPercentiles"; + + @Before + public void setUp() { + metrics = new InferenceServiceMetrics(TEST_SERVICE_KEY, TEST_CACHE_SIZE); + } + + @Test + public void testInitialState() { + assertEquals(0, metrics.getTotalRequests()); + + Map<String, Object> metricsMap = metrics.getMetrics(); + assertEquals(0L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + assertEquals(0L, metricsMap.get(InferenceServiceMetrics.CACHE_HITS)); + assertEquals(0L, metricsMap.get(InferenceServiceMetrics.CACHE_MISSES)); + assertEquals(0.0, metricsMap.get(InferenceServiceMetrics.CACHE_HIT_RATE)); + assertEquals(TEST_CACHE_SIZE, 
metricsMap.get(InferenceServiceMetrics.CACHE_SIZE)); + assertEquals(0L, metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS)); + assertEquals(0.0, metricsMap.get(InferenceServiceMetrics.ERROR_RATE)); + } + + @Test + public void testRequestTracking() { + // Start a request + Timer.Context context = metrics.requestStarted(); + assertEquals(1, metrics.getTotalRequests()); + + // Complete the request + metrics.requestCompleted(150, context); + + Map<String, Object> metricsMap = metrics.getMetrics(); + assertEquals(1L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + + // Second request without context + metrics.requestStarted(); + metrics.requestCompleted(200); + + metricsMap = metrics.getMetrics(); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + } + + @Test + public void testCacheHitRate() { + // No cache activity + assertEquals(0.0, metrics.getCacheHitRate(), 0.001); + + // Record some hits and misses + metrics.cacheHit(); + metrics.cacheHit(); + metrics.cacheMiss(); + + // Should be 2/3 = 66.67% + assertEquals(66.67, metrics.getCacheHitRate(), 0.01); + + // Add more misses + metrics.cacheMiss(); + metrics.cacheMiss(); + + // Should be 2/5 = 40% + assertEquals(40.0, metrics.getCacheHitRate(), 0.001); + + Map<String, Object> metricsMap = metrics.getMetrics(); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.CACHE_HITS)); + assertEquals(3L, metricsMap.get(InferenceServiceMetrics.CACHE_MISSES)); + assertEquals(40.0, metricsMap.get(InferenceServiceMetrics.CACHE_HIT_RATE)); + } + + @Test + public void testErrorTracking() { + // Start a request and record an error + Timer.Context context = metrics.requestStarted(); + metrics.requestError(100, context); + + // Start another request and record an error without timing + metrics.requestStarted(); + metrics.requestError(); + + Map<String, Object> metricsMap = metrics.getMetrics(); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + assertEquals(2L, 
metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS)); + assertEquals(100.0, metricsMap.get(InferenceServiceMetrics.ERROR_RATE)); + + // Add a successful request + metrics.requestStarted(); + metrics.requestCompleted(150); + + metricsMap = metrics.getMetrics(); + assertEquals(3L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS)); + assertEquals(66.67, (double) metricsMap.get(InferenceServiceMetrics.ERROR_RATE), 0.01); + } + + @Test + public void testTimingMetrics() { + // Record multiple requests with different timings + for (int i = 0; i < 10; i++) { + Timer.Context context = metrics.requestStarted(); + metrics.requestCompleted(100 + (i * 50), context); + } + + Map<String, Object> metricsMap = metrics.getMetrics(); + + // Check time histogram exists + assertTrue(metricsMap.containsKey(InferenceServiceMetrics.TIME_HISTOGRAM)); + + @SuppressWarnings("unchecked") + Map<String, Object> histogram = (Map<String, Object>) metricsMap.get(InferenceServiceMetrics.TIME_HISTOGRAM); + + // Verify histogram has expected metrics + assertEquals(10L, histogram.get("count")); + assertTrue(histogram.containsKey("min")); + assertTrue(histogram.containsKey("max")); + assertTrue(histogram.containsKey("mean")); + assertTrue(histogram.containsKey("stdDev")); + + // Verify percentiles exist + assertTrue(histogram.containsKey(KEY_REQUEST_PERCENTILES)); + } + + @Test + public void testMetricRegistry() { + // Verify metric registry is available + assertNotNull(metrics.getMetricRegistry()); + } + + @Test + public void testLogMetricsSummaryOutput() { + // Setup the LogCustomizer to capture log messages from InferenceServiceMetrics + LogCustomizer custom = LogCustomizer + .forLogger(InferenceServiceMetrics.class.getName()) + .enable(Level.INFO) + .create(); + + try { + custom.starting(); + + // 1. 
Generate comprehensive metrics data + // Multiple requests with different timing values + for (int i = 0; i < 5; i++) { + Timer.Context context = metrics.requestStarted(); + metrics.requestCompleted(100 + (i * 50), context); + } + + // Add cache hits and misses to test hit rate + for (int i = 0; i < 3; i++) { + metrics.cacheHit(); + } + for (int i = 0; i < 2; i++) { + metrics.cacheMiss(); + } + + // Add some errors to test error rate + for (int i = 0; i < 2; i++) { + Timer.Context context = metrics.requestStarted(); + metrics.requestError(75 + i * 25, context); + } + + // At this point we should have: + // - 7 total requests (5 successful, 2 errors) + // - 3 cache hits, 2 cache misses (60% hit rate) + // - 2 errors (28.6% error rate) + // - Various timing metrics + + // Verify metrics were recorded correctly + Map<String, Object> metricsMap = metrics.getMetrics(); + assertEquals(7L, metricsMap.get(InferenceServiceMetrics.TOTAL_REQUESTS)); + assertEquals(3L, metricsMap.get(InferenceServiceMetrics.CACHE_HITS)); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.CACHE_MISSES)); + assertEquals(60.0, metricsMap.get(InferenceServiceMetrics.CACHE_HIT_RATE)); + assertEquals(2L, metricsMap.get(InferenceServiceMetrics.REQUEST_ERRORS)); + assertEquals(28.57, (double) metricsMap.get(InferenceServiceMetrics.ERROR_RATE), 0.01); + + // Check both histograms exist + @SuppressWarnings("unchecked") + Map<String, Object> histogram = (Map<String, Object>) metricsMap.get(InferenceServiceMetrics.TIME_HISTOGRAM); + assertEquals(5L, ((Number) histogram.get("count")).longValue()); + + @SuppressWarnings("unchecked") + Map<String, Object> errorHistogram = (Map<String, Object>) metricsMap.get(InferenceServiceMetrics.KEY_ERROR_TIME_DATA); + assertNotNull("Error histogram should exist", errorHistogram); + assertEquals(2L, ((Number) errorHistogram.get("count")).longValue()); + + // 2. 
Log the metrics summary + // Force logging regardless of time interval by using parameters that won't trigger early return + metrics.logMetricsSummary(0, 100); + + // 3. Verify the log output contains all expected metrics + List<String> logs = custom.getLogs(); + assertFalse("Log should contain at least one entry", logs.isEmpty()); + + String logMessage = logs.get(0); + assertTrue("Log should contain the service key", logMessage.contains(TEST_SERVICE_KEY)); + + // Check all metrics are in the log + assertTrue("Log should contain requests count", logMessage.contains("requests=7")); + assertTrue("Log should contain hit rate", logMessage.contains("hitRate=")); + assertTrue("Log should contain error rate", logMessage.contains("errorRate=")); + assertTrue("Log should contain avgTime", logMessage.contains("avgTime=")); + assertTrue("Log should contain maxTime", logMessage.contains("maxTime=")); + assertTrue("Log should contain percentiles", logMessage.contains("successPercentiles [50th=")); + assertTrue("Log should contain rates", logMessage.contains("successRates [1m=")); + assertTrue("Log should contain error rates", logMessage.contains("errorRates [1m=")); + + // Check additional metrics present in updated implementation + assertTrue("Log should contain error rate metrics", logMessage.contains("err/s")); + + // Check for request rate metrics + assertTrue("Log should contain 1-minute request rate", logMessage.contains("1m=")); + assertTrue("Log should contain 5-minute request rate", logMessage.contains("5m=")); + assertTrue("Log should contain 15-minute request rate", logMessage.contains("15m=")); + + // Verify error histogram details are included in the log output + assertTrue("Log should include error timing information", + logMessage.contains(InferenceServiceMetrics.KEY_ERROR_TIME_DATA) || + metricsMap.containsKey(InferenceServiceMetrics.KEY_ERROR_TIME_DATA)); + } finally { + custom.finished(); + } + } +} \ No newline at end of file
