This is an automated email from the ASF dual-hosted git repository.
github-merge-queue[bot] pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new 1b0ec78aa9 feat(huggingFace): add HuggingFaceModelResource for model browsing and media proxy (#5124)
1b0ec78aa9 is described below
commit 1b0ec78aa954bb1eab07e79d684a94c0e5a450b5
Author: Prateek Ganigi <[email protected]>
AuthorDate: Sun May 31 13:48:02 2026 -0700
feat(huggingFace): add HuggingFaceModelResource for model browsing and media proxy (#5124)
### What changes were proposed in this PR?
Introduces `HuggingFaceModelResource` - a Jersey REST resource at
`/api/huggingface/*` that backs the upcoming HuggingFace operator's
model picker, audio upload, and media preview UI. Five endpoints:
| Endpoint | Purpose |
|---|---|
| `GET /api/huggingface/models?task=…[&search=…]` | Browse or search HF models |
| `GET /api/huggingface/tasks` | List HF pipeline tags with hosted inference |
| `POST /api/huggingface/upload-audio?filename=…` | Stream-upload audio files |
| `GET /api/huggingface/audio-preview?path=…` | Stream uploaded audio back |
| `GET /api/huggingface/media-proxy?url=…` | Proxy allowlisted remote media URLs (CORS bypass) |
The PR also registers the resource in `TexeraWebApplication` (a one-line change).
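For orientation, a minimal client-side sketch of how a caller might build requests for two of these endpoints with the JDK's `java.net.http` client. It assumes a backend at a hypothetical `http://localhost:8080` with `/api` as the Jersey root (as the table implies); `HfEndpointRequests` and its methods are illustrative names. Because the resource is annotated `@RolesAllowed("REGULAR", "ADMIN")`, real calls also need Texera's normal authentication, which this sketch omits.

```scala
import java.net.URI
import java.net.URLEncoder
import java.net.http.HttpRequest
import java.nio.charset.StandardCharsets

// Hypothetical request builders for /api/huggingface/*; BaseUrl is an assumption.
object HfEndpointRequests {
  val BaseUrl = "http://localhost:8080/api/huggingface"

  private def enc(s: String): String = URLEncoder.encode(s, StandardCharsets.UTF_8)

  /** GET /models?task=…[&search=…], optionally carrying the user's HF token. */
  def listModels(task: String, search: Option[String], hfToken: Option[String]): HttpRequest = {
    val query = s"task=${enc(task)}" + search.map(q => s"&search=${enc(q)}").getOrElse("")
    val builder = HttpRequest.newBuilder(URI.create(s"$BaseUrl/models?$query")).GET()
    // Omitting X-HF-Token makes the request anonymous (and its browse result cacheable).
    hfToken.foreach(t => builder.header("X-HF-Token", t))
    builder.build()
  }

  /** GET /media-proxy?url=…; the server only fetches allowlisted hosts. */
  def mediaProxy(remoteUrl: String): HttpRequest =
    HttpRequest.newBuilder(URI.create(s"$BaseUrl/media-proxy?url=${enc(remoteUrl)}")).GET().build()
}
```

Percent-encoding the `url` parameter keeps the remote URL's own `?` and `&` from being parsed as part of the proxy request's query string.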
**Architectural notes:**
- **Token sourcing:** the user's HF token arrives via the `X-HF-Token`
request header (forwarded by the frontend from the operator's property
panel in a follow-up PR). When absent, requests go to HF Hub
anonymously. There is no server-side env-var token.
- **Caching:** bounded Guava `Cache` (size + TTL) for `/models` and
`/tasks` results. User-token requests bypass the cache to avoid serving
one user's token-scoped list to another.
- **Streaming upload:** `/upload-audio` reads the `InputStream` straight to
disk in 8 KiB chunks with a 25 MiB cap (returning 413 once the cap is
exceeded), so the request body is never buffered in memory. An extension
allowlist rejects non-audio types up front.
- **SSRF protection:** `/media-proxy` requires the URL's host to be in
an allowlist (`huggingface.co`, `fal.media`, `replicate.delivery`,
`replicate.com`, and their subdomains), matched with a leading-dot suffix
guard so lookalike domains such as `evilhuggingface.co` are rejected.
- **Bounded fan-out:** `/tasks` uses a dedicated `ForkJoinPool(4)` for
its per-task probe instead of the JVM's global common pool, with
explicit 429/503 detection that logs at WARN.
- **Truncation visibility:** browse and search responses carry an
`X-Texera-Truncated: true` header when results were capped, so the
frontend can show "list incomplete" hints.
- **Error responses:** generic Jackson-built JSON bodies (no exception
internals leak to clients); details are logged server-side.
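The SSRF guard and the bounded fan-out above can be sketched in isolation. The host check mirrors the leading-dot suffix rule described in the notes; the fan-out helper runs a parallel stream inside a dedicated `ForkJoinPool(4)`, relying on the widely used but not formally specified JDK behavior that a parallel stream started from a ForkJoinPool task executes on that pool. `ProxyAndFanOutSketch`, `isAllowedUrl`, and `filterInPool` are hypothetical names, and `probe` stands in for the per-task HF Hub request.

```scala
import java.net.URI
import java.util.concurrent.{Callable, ForkJoinPool}
import java.util.stream.Collectors
import scala.jdk.CollectionConverters._
import scala.util.Try

object ProxyAndFanOutSketch {
  // Allowlist taken from the PR; new inference providers would be added here.
  private val AllowedSuffixes = Set("huggingface.co", "fal.media", "replicate.delivery", "replicate.com")

  /** Exact host or a true subdomain; the "." prefix rejects lookalikes such as evilhuggingface.co. */
  def isAllowedHost(host: String): Boolean =
    if (host == null || host.isEmpty) false
    else {
      val lower = host.toLowerCase
      AllowedSuffixes.exists(s => lower == s || lower.endsWith("." + s))
    }

  /** Parse first, then require an http(s) scheme and an allowlisted host; unparseable input is rejected. */
  def isAllowedUrl(url: String): Boolean =
    Try(new URI(url.trim)).toOption.exists { u =>
      (u.getScheme == "http" || u.getScheme == "https") && isAllowedHost(u.getHost)
    }

  // Dedicated pool so per-task probes never occupy ForkJoinPool.commonPool().
  private val probePool = new ForkJoinPool(4)

  /** Keep the tags for which `probe` returns true, evaluating probes on probePool (order preserved). */
  def filterInPool(tags: Seq[String], probe: String => Boolean): Seq[String] =
    probePool
      .submit(new Callable[java.util.List[String]] {
        override def call(): java.util.List[String] =
          tags.asJava.parallelStream().filter(t => probe(t)).collect(Collectors.toList[String]())
      })
      .get()
      .asScala
      .toSeq
}
```

A plain `endsWith(suffix)` without the leading dot would accept `evilhuggingface.co`, which is the lookalike case the guard exists to block.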
### Any related issues, documentation, or discussions?
Tracked in #5134 and #5041 (the umbrella issue for the HuggingFace operator
end-to-end implementation). This PR is the backend foundation;
subsequent PRs will add the operator class, frontend property panel,
result-panel media rendering, and developer documentation.
Closes #5134
### How was this PR tested?
- Unit tests:
`amber/src/test/scala/.../HuggingFaceModelResourceSpec.scala` - 86
ScalaTest cases covering token sanitization, SSRF allowlist (including
lookalike-domain rejection), JSON error escaping, MIME type inference,
the audio-upload validation/size-cap/extension paths, audio-preview path
validation and traversal rejection, media-proxy rejection paths, cache
hit/bypass semantics, and the temp-dir sweep. Run with `sbt 'WorkflowExecutionService/testOnly org.apache.texera.web.resource.HuggingFaceModelResourceSpec'` - all 86
pass in ~6 seconds, no external network required.
- Manual smoke tests against a local backend:
  - `GET /api/huggingface/tasks` returns the expected JSON task list.
  - `GET /api/huggingface/models?task=text-generation` returns the paginated model list; `text-generation` shows the `X-Texera-Truncated: true` header when `MAX_PAGES=50` is hit.
  - `POST /upload-audio?filename=evil.sh` → 400 (extension allowlist).
  - `POST /upload-audio` with a 30 MiB body → 413 (size cap).
  - `GET /media-proxy?url=http://localhost:8080/` → 403 (SSRF allowlist).
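The two upload rejection paths exercised above (400 for a disallowed extension, 413 for an oversized body) can be modeled without a server. This is a minimal sketch that follows the validation order of `/upload-audio` in the diff: strip any directory part from the filename, check the lower-cased extension, then copy in 8 KiB chunks while counting bytes, treating an empty body as 400. `UploadValidationSketch`, `copyWithCap`, and `validate` are hypothetical names, and `OutputStream.nullOutputStream()` stands in for the real temp file.

```scala
import java.io.{InputStream, OutputStream}
import java.nio.file.Paths

object UploadValidationSketch {
  val MaxBytes: Long = 25L * 1024 * 1024 // 25 MiB, as in the PR
  val AllowedExt: Set[String] =
    Set(".wav", ".mp3", ".mpeg", ".flac", ".ogg", ".oga", ".webm", ".opus", ".amr", ".m4a", ".aac")

  /** Drop any directory part, then return the lower-cased extension ("" if none). */
  def extensionOf(filename: String): String = {
    val base = Option(Paths.get(filename.trim).getFileName).map(_.toString).getOrElse("")
    val idx = base.lastIndexOf('.')
    if (idx >= 0 && idx < base.length - 1) base.substring(idx).toLowerCase else ""
  }

  /** Copy in 8 KiB chunks; Left(413) once the running total exceeds `cap`, Left(400) if empty. */
  def copyWithCap(in: InputStream, out: OutputStream, cap: Long): Either[Int, Long] = {
    val buf = new Array[Byte](8 * 1024)
    var total = 0L
    var read = in.read(buf)
    while (read != -1) {
      total += read
      if (total > cap) return Left(413) // stop before writing past the cap
      out.write(buf, 0, read)
      read = in.read(buf)
    }
    if (total == 0L) Left(400) else Right(total)
  }

  /** Status code the endpoint would answer with, in this simplified model. */
  def validate(filename: String, body: InputStream, cap: Long = MaxBytes): Int =
    if (!AllowedExt.contains(extensionOf(filename))) 400
    else
      copyWithCap(body, OutputStream.nullOutputStream(), cap) match {
        case Left(code) => code
        case Right(_)   => 200
      }
}
```

Checking the extension before reading the body means a rejected file type costs no disk I/O at all.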
### Was this PR authored or co-authored using generative AI tooling?
Co-authored with Claude Opus 4.7 in compliance with ASF
---------
Co-authored-by: Elliot Lin <[email protected]>
Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
Co-authored-by: Xuan Gu <[email protected]>
---
.../apache/texera/web/TexeraWebApplication.scala | 1 +
.../web/resource/HuggingFaceModelResource.scala | 750 +++++++++++++++++++++
.../resource/HuggingFaceModelResourceSpec.scala | 731 ++++++++++++++++++++
3 files changed, 1482 insertions(+)
diff --git a/amber/src/main/scala/org/apache/texera/web/TexeraWebApplication.scala b/amber/src/main/scala/org/apache/texera/web/TexeraWebApplication.scala
index 98b7c68c97..5438eea4d0 100644
--- a/amber/src/main/scala/org/apache/texera/web/TexeraWebApplication.scala
+++ b/amber/src/main/scala/org/apache/texera/web/TexeraWebApplication.scala
@@ -160,6 +160,7 @@ class TexeraWebApplication
environment.jersey.register(classOf[UserQuotaResource])
environment.jersey.register(classOf[AdminSettingsResource])
environment.jersey.register(classOf[AIAssistantResource])
+ environment.jersey.register(classOf[HuggingFaceModelResource])
AuthResource.createAdminUser()
diff --git a/amber/src/main/scala/org/apache/texera/web/resource/HuggingFaceModelResource.scala b/amber/src/main/scala/org/apache/texera/web/resource/HuggingFaceModelResource.scala
new file mode 100644
index 0000000000..bdda6e8ffc
--- /dev/null
+++ b/amber/src/main/scala/org/apache/texera/web/resource/HuggingFaceModelResource.scala
@@ -0,0 +1,750 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.web.resource
+
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+import com.google.common.cache.{Cache, CacheBuilder}
+import kong.unirest.Unirest
+import org.slf4j.{Logger, LoggerFactory}
+
+import java.io.InputStream
+import java.net.URI
+import java.nio.file.{Files, Path => NioPath, Paths}
+import java.util.concurrent.{Callable, ForkJoinPool, TimeUnit}
+import java.util.stream.Collectors
+import javax.annotation.security.RolesAllowed
+import javax.ws.rs._
+import javax.ws.rs.core.{MediaType, Response}
+import scala.jdk.CollectionConverters._
+
+/**
+ * REST resource that proxies the Hugging Face Hub API for the HuggingFace operator.
+ *
+ * - GET /api/huggingface/models?task=…[&search=…] browse or search HF models
+ * - GET /api/huggingface/tasks list HF pipeline tags with hosted inference
+ * - POST /api/huggingface/upload-audio?filename=… stream-upload an audio file
+ * - GET /api/huggingface/audio-preview?path=… stream back an uploaded audio file
+ * - GET /api/huggingface/media-proxy?url=… proxy an allowlisted remote media URL
+ *
+ * Token sourcing: the user supplies their own HF token via the `X-HF-Token`
+ * request header (forwarded by the frontend from the operator's property
+ * panel). If the header is absent, requests go to HF Hub anonymously —
+ * HF serves public model/task lists at public rate limits without auth.
+ * The browse cache is bypassed whenever a user token is supplied, so one
+ * user's private-model visibility never leaks into another user's response.
+ */
+@Path("/huggingface")
+@Produces(Array(MediaType.APPLICATION_JSON))
+@RolesAllowed(Array("REGULAR", "ADMIN"))
+class HuggingFaceModelResource {
+
+ import HuggingFaceModelResource._
+
+ @GET
+ @Path("/models")
+ def listModels(
+ @QueryParam("task") @DefaultValue("text-generation") task: String,
+ @QueryParam("search") search: String,
+ @HeaderParam("X-HF-Token") userToken: String
+ ): Response = {
+ try {
+ val hfToken = sanitizeToken(userToken)
+ val isUserToken = hfToken.nonEmpty
+
+ // ── Search mode: forward query to HF Hub, return results directly ──
+ if (search != null && search.trim.nonEmpty) {
+ return fetchSearchResults(task, search.trim, hfToken)
+ }
+
+ // ── Browse mode: return ALL models for this task ──
+ // Only cache anonymous results, so a user with private-model visibility
+ // can't have their token-scoped list served to a different user.
+ if (!isUserToken) {
+ val cached = modelCache.getIfPresent(task)
+ if (cached != null) {
+ return Response.ok(cached).build()
+ }
+ }
+
+ val pageResult = fetchAllModelsForTask(task, hfToken)
+ val json = objectMapper.writeValueAsString(pageResult.models)
+ if (!isUserToken) modelCache.put(task, json)
+
+ val builder = Response.ok(json)
+ if (pageResult.truncated) builder.header(TRUNCATED_HEADER, "true")
+ builder.build()
+ } catch {
+ case e: Exception =>
+ logger.error(s"Failed to fetch HF models for task '$task'", e)
+ errorResponse(Response.Status.INTERNAL_SERVER_ERROR, "Failed to fetch models.")
+ }
+ }
+
+ /**
+ * Streams an audio file from the request body to a temp file under
+ * `${java.io.tmpdir}/texera-hf-audio`. Enforces an extension allowlist
+ * and a max payload size (rejected with 413 once exceeded). Old files
+ * in the temp dir are best-effort cleaned on each upload.
+ */
+ @POST
+ @Path("/upload-audio")
+ @Consumes(Array(MediaType.WILDCARD))
+ def uploadAudioReference(
+ @QueryParam("filename") filename: String,
+ stream: InputStream
+ ): Response = {
+ try {
+ if (stream == null) {
+ return errorResponse(Response.Status.BAD_REQUEST, "Audio payload is required.")
+ }
+
+ val safeFileName = Option(filename)
+ .map(_.trim)
+ .filter(_.nonEmpty)
+ .map(name => Paths.get(name).getFileName.toString)
+ .getOrElse("audio.bin")
+ val extension = {
+ val idx = safeFileName.lastIndexOf('.')
+ if (idx >= 0 && idx < safeFileName.length - 1)
+ safeFileName.substring(idx).toLowerCase
+ else ""
+ }
+ if (!ALLOWED_AUDIO_EXTENSIONS.contains(extension)) {
+ return errorResponse(
+ Response.Status.BAD_REQUEST,
+ "Unsupported audio file extension."
+ )
+ }
+
+ val tempDir = audioTempDir
+ Files.createDirectories(tempDir)
+ sweepOldAudioFiles(tempDir)
+
+ val tempFile: NioPath = Files.createTempFile(tempDir, "hf-audio-", extension)
+ tempFile.toFile.deleteOnExit()
+
+ val out = Files.newOutputStream(tempFile)
+ var totalBytes = 0L
+ try {
+ val buf = new Array[Byte](8 * 1024)
+ var read = stream.read(buf)
+ while (read != -1) {
+ totalBytes += read
+ if (totalBytes > MAX_AUDIO_BYTES) {
+ out.close()
+ Files.deleteIfExists(tempFile)
+ return errorResponse(
+ Response.Status.REQUEST_ENTITY_TOO_LARGE,
+ "Audio payload exceeds the size limit."
+ )
+ }
+ out.write(buf, 0, read)
+ read = stream.read(buf)
+ }
+ } finally {
+ out.close()
+ }
+
+ if (totalBytes == 0L) {
+ Files.deleteIfExists(tempFile)
+ return errorResponse(Response.Status.BAD_REQUEST, "Audio payload is empty.")
+ }
+
+ val json = objectMapper.writeValueAsString(
+ Map(
+ "path" -> tempFile.toAbsolutePath.toString,
+ "fileName" -> safeFileName
+ ).asJava
+ )
+ Response.ok(json).build()
+ } catch {
+ case e: Exception =>
+ logger.error("Failed to upload audio", e)
+ errorResponse(Response.Status.INTERNAL_SERVER_ERROR, "Failed to upload audio.")
+ }
+ }
+
+ @GET
+ @Path("/audio-preview")
+ def previewUploadedAudio(@QueryParam("path") path: String): Response = {
+ try {
+ val trimmedPath = Option(path).map(_.trim).getOrElse("")
+ if (trimmedPath.isEmpty) {
+ return errorResponse(Response.Status.BAD_REQUEST, "Audio path is required.")
+ }
+
+ val tempDir = audioTempDir.toAbsolutePath.normalize()
+ val requestedPath = Paths.get(trimmedPath).toAbsolutePath.normalize()
+ if (!requestedPath.startsWith(tempDir)) {
+ return errorResponse(
+ Response.Status.FORBIDDEN,
+ "Audio path is outside the allowed preview directory."
+ )
+ }
+ if (!Files.exists(requestedPath) || !Files.isRegularFile(requestedPath)) {
+ return errorResponse(Response.Status.NOT_FOUND, "Uploaded audio file was not found.")
+ }
+
+ // Defense-in-depth: even though /upload-audio enforces MAX_AUDIO_BYTES on
+ // ingest, refuse to buffer an oversized file into heap on the response
+ // side. Catches files placed via a future-bug or out-of-band write.
+ val size = Files.size(requestedPath)
+ if (size > MAX_AUDIO_BYTES) {
+ logger.warn(
+ s"Uploaded audio file size $size exceeds cap $MAX_AUDIO_BYTES; rejecting."
+ )
+ return errorResponse(
+ Response.Status.REQUEST_ENTITY_TOO_LARGE,
+ "Uploaded audio file exceeds the size limit."
+ )
+ }
+
+ val contentType = Option(Files.probeContentType(requestedPath))
+ .filter(_.trim.nonEmpty)
+ .getOrElse(inferAudioContentType(requestedPath))
+ Response.ok(Files.readAllBytes(requestedPath), contentType).build()
+ } catch {
+ case e: Exception =>
+ logger.error("Failed to read uploaded audio", e)
+ errorResponse(Response.Status.INTERNAL_SERVER_ERROR, "Failed to read uploaded audio.")
+ }
+ }
+
+ /**
+ * Proxies a remote media URL to bypass browser CORS for HF inference responses.
+ * Only http(s) URLs whose host is in ALLOWED_MEDIA_HOST_SUFFIXES are accepted,
+ * blocking SSRF probes against internal services.
+ */
+ @GET
+ @Path("/media-proxy")
+ def proxyRemoteMedia(@QueryParam("url") url: String): Response = {
+ try {
+ val trimmedUrl = Option(url).map(_.trim).getOrElse("")
+ if (trimmedUrl.isEmpty) {
+ return errorResponse(Response.Status.BAD_REQUEST, "Media URL is required.")
+ }
+ if (!trimmedUrl.startsWith("http://") && !trimmedUrl.startsWith("https://")) {
+ return errorResponse(
+ Response.Status.BAD_REQUEST,
+ "Only http(s) media URLs are supported."
+ )
+ }
+
+ val parsed =
+ try { new URI(trimmedUrl) }
+ catch { case _: Exception => null }
+ if (parsed == null || parsed.getHost == null || !isAllowedMediaHost(parsed.getHost)) {
+ return errorResponse(Response.Status.FORBIDDEN, "Media URL host is not allowed.")
+ }
+
+ // Stream the upstream response via asObject(Function<RawResponse, T>) so
+ // we never have to materialise it into a heap-resident byte[] before we
+ // can enforce the size cap. The function returns a MediaProxyOutcome
+ // that this method then converts into a Jersey Response.
+ val outcome = Unirest
+ .get(trimmedUrl)
+ .connectTimeout(CONNECT_TIMEOUT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_LONG_MS)
+ .asObject((raw: kong.unirest.RawResponse) => streamMediaWithCap(raw))
+ .getBody
+
+ outcome match {
+ case MediaProxyOk(bytes, contentType) =>
+ Response.ok(bytes, contentType.getOrElse(MediaType.APPLICATION_OCTET_STREAM)).build()
+ case MediaProxyError(status, message) =>
+ errorResponse(status, message)
+ }
+ } catch {
+ case e: Exception =>
+ logger.error("Failed to proxy remote media", e)
+ errorResponse(Response.Status.INTERNAL_SERVER_ERROR, "Failed to proxy remote media.")
+ }
+ }
+
+ /**
+ * Read an upstream media response with a hard size cap. Aborts early both
+ * when the declared Content-Length exceeds the cap and when the body crosses
+ * it mid-read (in case the upstream lies about Content-Length or omits it).
+ */
+ private def streamMediaWithCap(raw: kong.unirest.RawResponse): MediaProxyOutcome = {
+ if (raw.getStatus != 200) {
+ logger.warn(s"Upstream media fetch returned ${raw.getStatus}: ${raw.getStatusText}")
+ return MediaProxyError(raw.getStatus, "Failed to fetch remote media.")
+ }
+
+ val declaredLength = Option(raw.getHeaders.getFirst("Content-Length"))
+ .flatMap(s => scala.util.Try(s.trim.toLong).toOption)
+ if (declaredLength.exists(_ > MAX_MEDIA_PROXY_BYTES)) {
+ logger.warn(
+ s"Upstream Content-Length ${declaredLength.get} exceeds cap $MAX_MEDIA_PROXY_BYTES; rejecting."
+ )
+ return MediaProxyError(
+ Response.Status.REQUEST_ENTITY_TOO_LARGE.getStatusCode,
+ "Remote media exceeds the size limit."
+ )
+ }
+
+ val buffered = new java.io.ByteArrayOutputStream()
+ val buf = new Array[Byte](8 * 1024)
+ val in = raw.getContent
+ var totalBytes = 0L
+ var exceeded = false
+ var read = in.read(buf)
+ while (read != -1 && !exceeded) {
+ totalBytes += read
+ if (totalBytes > MAX_MEDIA_PROXY_BYTES) {
+ exceeded = true
+ } else {
+ buffered.write(buf, 0, read)
+ read = in.read(buf)
+ }
+ }
+ if (exceeded) {
+ logger.warn(s"Upstream media exceeded cap $MAX_MEDIA_PROXY_BYTES mid-stream; rejecting.")
+ return MediaProxyError(
+ Response.Status.REQUEST_ENTITY_TOO_LARGE.getStatusCode,
+ "Remote media exceeds the size limit."
+ )
+ }
+
+ val contentType = Option(raw.getContentType).map(_.trim).filter(_.nonEmpty)
+ MediaProxyOk(buffered.toByteArray, contentType)
+ }
+
+ /** Search HF Hub for models matching a query within a task. */
+ private def fetchSearchResults(task: String, query: String, hfToken: String): Response = {
+ var request = Unirest
+ .get("https://huggingface.co/api/models")
+ .queryString("pipeline_tag", task)
+ .queryString("sort", "downloads")
+ .queryString("direction", "-1")
+ .queryString("limit", SEARCH_LIMIT.toString)
+ .queryString("filter", task)
+ .queryString("inference", "warm")
+ .queryString("search", query)
+ .connectTimeout(CONNECT_TIMEOUT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_MS)
+
+ if (hfToken.nonEmpty) {
+ request = request.header("Authorization", s"Bearer $hfToken")
+ }
+
+ val hfResponse = request.asString()
+
+ if (hfResponse.getStatus != 200) {
+ logger.warn(
+ s"HF search returned ${hfResponse.getStatus}: ${hfResponse.getStatusText}"
+ )
+ return errorResponse(hfResponse.getStatus, "Hugging Face API error.")
+ }
+
+ val rawModels = objectMapper.readValue(hfResponse.getBody, listOfMapsType)
+ val out = buildSimplifiedList(rawModels)
+ val truncated = rawModels.size() >= SEARCH_LIMIT
+ val builder = Response.ok(objectMapper.writeValueAsString(out))
+ if (truncated) builder.header(TRUNCATED_HEADER, "true")
+ builder.build()
+ }
+
+ /** GET /api/huggingface/tasks — list HF pipeline tags that have models with hosted inference. */
+ @GET
+ @Path("/tasks")
+ def listTasks(@HeaderParam("X-HF-Token") userToken: String): Response = {
+ try {
+ val hfToken = sanitizeToken(userToken)
+ val isUserToken = hfToken.nonEmpty
+
+ // Only cache anonymous results — see comment in listModels.
+ if (!isUserToken) {
+ val cached = taskCache.getIfPresent(TASKS_CACHE_KEY)
+ if (cached != null) {
+ return Response.ok(cached).build()
+ }
+ }
+
+ var request = Unirest
+ .get("https://huggingface.co/api/tasks")
+ .connectTimeout(CONNECT_TIMEOUT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_MS)
+
+ if (hfToken.nonEmpty) {
+ request = request.header("Authorization", s"Bearer $hfToken")
+ }
+
+ val hfResponse = request.asString()
+
+ if (hfResponse.getStatus != 200) {
+ logger.warn(
+ s"HF tasks endpoint returned ${hfResponse.getStatus}: ${hfResponse.getStatusText}"
+ )
+ return errorResponse(hfResponse.getStatus, "Hugging Face API error.")
+ }
+
+ // /api/tasks returns { "<pipeline_tag>": { "label": "...", ... }, ... }
+ val root: JsonNode = objectMapper.readTree(hfResponse.getBody)
+ val taskList = new java.util.ArrayList[java.util.Map[String, Object]]()
+ val iter = root.fields()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val tag = entry.getKey
+ val info: JsonNode = entry.getValue
+ val label =
+ if (info != null && info.isObject && info.has("label")) info.get("label").asText(tag)
+ else tag
+ val taskEntry = new java.util.LinkedHashMap[String, Object]()
+ taskEntry.put("tag", tag)
+ taskEntry.put("label", label)
+ taskList.add(taskEntry)
+ }
+
+ // Bounded fan-out: scope the parallelStream to our own ForkJoinPool
+ // (size = TASK_FETCH_PARALLELISM) instead of the global common pool.
+ val availableTasks =
+ taskCheckPool
+ .submit(new Callable[java.util.List[java.util.Map[String, Object]]] {
+ override def call(): java.util.List[java.util.Map[String, Object]] = {
+ taskList
+ .parallelStream()
+ .filter(t => hasModelsForTask(t.get("tag").toString, hfToken))
+ .collect(Collectors.toList())
+ }
+ })
+ .get()
+
+ val json = objectMapper.writeValueAsString(availableTasks)
+ if (!isUserToken) taskCache.put(TASKS_CACHE_KEY, json)
+ Response.ok(json).build()
+ } catch {
+ case e: Exception =>
+ logger.error("Failed to fetch HF tasks", e)
+ errorResponse(Response.Status.INTERNAL_SERVER_ERROR, "Failed to fetch tasks.")
+ }
+ }
+
+ /**
+ * Fetch all models for a given task by paginating the HF Hub Link header.
+ * Stops at MAX_PAGES pages; sets `truncated = true` if pagination stopped
+ * early (either by hitting MAX_PAGES or an upstream error mid-pagination).
+ */
+ private def fetchAllModelsForTask(
+ task: String,
+ hfToken: String
+ ): PageResult = {
+ val allResults = new java.util.ArrayList[java.util.Map[String, Object]]()
+ var nextUrl: String = null
+ var pageCount = 0
+
+ var request = Unirest
+ .get("https://huggingface.co/api/models")
+ .queryString("pipeline_tag", task)
+ .queryString("sort", "downloads")
+ .queryString("direction", "-1")
+ .queryString("limit", PAGE_SIZE.toString)
+ .queryString("filter", task)
+ .queryString("inference", "warm")
+ .connectTimeout(CONNECT_TIMEOUT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_MS)
+
+ if (hfToken.nonEmpty) {
+ request = request.header("Authorization", s"Bearer $hfToken")
+ }
+
+ var hfResponse = request.asString()
+
+ if (hfResponse.getStatus != 200) {
+ throw new RuntimeException(
+ s"HF API returned ${hfResponse.getStatus} for task '$task'"
+ )
+ }
+
+ var rawModels = objectMapper.readValue(hfResponse.getBody, listOfMapsType)
+ allResults.addAll(buildSimplifiedList(rawModels))
+ pageCount += 1
+
+ nextUrl = extractNextLink(hfResponse.getHeaders.getFirst("Link"))
+
+ while (nextUrl != null && pageCount < MAX_PAGES) {
+ var nextRequest = Unirest
+ .get(nextUrl)
+ .connectTimeout(CONNECT_TIMEOUT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_MS)
+ if (hfToken.nonEmpty) {
+ nextRequest = nextRequest.header("Authorization", s"Bearer $hfToken")
+ }
+
+ hfResponse = nextRequest.asString()
+
+ if (hfResponse.getStatus != 200) {
+ logger.warn(
+ s"HF pagination stopped early at page $pageCount for task '$task' with status ${hfResponse.getStatus}"
+ )
+ return PageResult(allResults, truncated = true)
+ }
+
+ rawModels = objectMapper.readValue(hfResponse.getBody, listOfMapsType)
+ allResults.addAll(buildSimplifiedList(rawModels))
+ pageCount += 1
+
+ nextUrl = extractNextLink(hfResponse.getHeaders.getFirst("Link"))
+ }
+
+ val truncated = nextUrl != null && pageCount >= MAX_PAGES
+ if (truncated) {
+ logger.warn(s"HF pagination stopped at MAX_PAGES=$MAX_PAGES for task '$task'")
+ }
+
+ PageResult(allResults, truncated)
+ }
+
+ /**
+ * Parse the Link header to extract the URL with rel="next".
+ * Format: <https://huggingface.co/api/models?...>; rel="next"
+ */
+ private def extractNextLink(linkHeader: String): String = {
+ if (linkHeader == null || linkHeader.isEmpty) return null
+
+ val parts = linkHeader.split(",")
+ for (part <- parts) {
+ val trimmed = part.trim
+ if (trimmed.contains("rel=\"next\"")) {
+ val start = trimmed.indexOf('<')
+ val end = trimmed.indexOf('>')
+ if (start >= 0 && end > start) {
+ return trimmed.substring(start + 1, end)
+ }
+ }
+ }
+ null
+ }
+
+ /**
+ * Returns true if at least one model exists for the given task with hosted inference.
+ * Logs 429/503 explicitly so callers can spot HF rate-limit pressure.
+ */
+ private def hasModelsForTask(task: String, hfToken: String): Boolean = {
+ try {
+ var request = Unirest
+ .get("https://huggingface.co/api/models")
+ .queryString("pipeline_tag", task)
+ .queryString("filter", task)
+ .queryString("limit", "1")
+ .queryString("inference", "warm")
+ .connectTimeout(CONNECT_TIMEOUT_SHORT_MS)
+ .socketTimeout(SOCKET_TIMEOUT_SHORT_MS)
+
+ if (hfToken.nonEmpty) {
+ request = request.header("Authorization", s"Bearer $hfToken")
+ }
+
+ val response = request.asString()
+ response.getStatus match {
+ case 200 =>
+ val models = objectMapper.readValue(response.getBody, listOfMapsType)
+ !models.isEmpty
+ case 429 | 503 =>
+ logger.warn(
+ s"HF rate-limit/unavailable (status ${response.getStatus}) when checking task '$task'"
+ )
+ false
+ case other =>
+ logger.debug(s"HF returned status $other when checking task '$task'")
+ false
+ }
+ } catch {
+ case e: Exception =>
+ logger.debug(s"hasModelsForTask failed for '$task': ${e.getMessage}")
+ false
+ }
+ }
+
+ /** Convert raw HF model maps into simplified maps for the frontend. */
+ private def buildSimplifiedList(
+ rawModels: java.util.List[java.util.Map[String, Object]]
+ ): java.util.List[java.util.Map[String, Object]] = {
+ val out = new java.util.ArrayList[java.util.Map[String, Object]]()
+ val iter = rawModels.iterator()
+ while (iter.hasNext) {
+ val model = iter.next()
+ val id = if (model.get("id") != null) model.get("id").toString else ""
+ val downloads: java.lang.Long = model.get("downloads") match {
+ case n: java.lang.Number => n.longValue()
+ case _ => 0L
+ }
+ val likes: java.lang.Long = model.get("likes") match {
+ case n: java.lang.Number => n.longValue()
+ case _ => 0L
+ }
+ val pipelineTag =
+ if (model.get("pipeline_tag") != null) model.get("pipeline_tag").toString else ""
+
+ val entry = new java.util.LinkedHashMap[String, Object]()
+ entry.put("id", id)
+ entry.put("label", id)
+ entry.put("pipeline_tag", pipelineTag)
+ entry.put("downloads", downloads)
+ entry.put("likes", likes)
+ out.add(entry)
+ }
+ out
+ }
+}
+
+object HuggingFaceModelResource {
+ private val logger: Logger = LoggerFactory.getLogger(classOf[HuggingFaceModelResource])
+
+ private val objectMapper: ObjectMapper = new ObjectMapper()
+
+ private val listOfMapsType =
+ new TypeReference[java.util.List[java.util.Map[String, Object]]]() {}
+
+ // ── Network timeouts (ms) ──
+ private val CONNECT_TIMEOUT_MS = 10000
+ private val SOCKET_TIMEOUT_MS = 30000
+ private val CONNECT_TIMEOUT_SHORT_MS = 5000
+ private val SOCKET_TIMEOUT_SHORT_MS = 10000
+ private val SOCKET_TIMEOUT_LONG_MS = 120000
+
+ // ── Pagination ──
+ private val PAGE_SIZE = 1000
+ private val MAX_PAGES = 50
+ private val SEARCH_LIMIT = 100
+
+ /** Response header set when a list response was truncated (server-side limit hit). */
+ private[resource] val TRUNCATED_HEADER = "X-Texera-Truncated"
+
+ // ── Caches: bounded with TTL ──
+ private val MODEL_CACHE_MAX_SIZE = 100L
+ private val MODEL_CACHE_TTL_MINUTES = 60L
+ private val TASK_CACHE_MAX_SIZE = 8L
+ private val TASK_CACHE_TTL_MINUTES = 60L
+
+ private[resource] val modelCache: Cache[String, String] = CacheBuilder
+ .newBuilder()
+ .maximumSize(MODEL_CACHE_MAX_SIZE)
+ .expireAfterWrite(MODEL_CACHE_TTL_MINUTES, TimeUnit.MINUTES)
+ .build()
+
+ private[resource] val taskCache: Cache[String, String] = CacheBuilder
+ .newBuilder()
+ .maximumSize(TASK_CACHE_MAX_SIZE)
+ .expireAfterWrite(TASK_CACHE_TTL_MINUTES, TimeUnit.MINUTES)
+ .build()
+
+ private[resource] val TASKS_CACHE_KEY = "all"
+
+ // ── /tasks fan-out throttle: bounded ForkJoinPool instead of the global common pool ──
+ private val TASK_FETCH_PARALLELISM = 4
+ private val taskCheckPool = new ForkJoinPool(TASK_FETCH_PARALLELISM)
+
+ // ── /upload-audio constraints ──
+ private[resource] val MAX_AUDIO_BYTES: Long = 25L * 1024L * 1024L // 25 MiB
+ private[resource] val ALLOWED_AUDIO_EXTENSIONS: Set[String] =
+ Set(".wav", ".mp3", ".mpeg", ".flac", ".ogg", ".oga", ".webm", ".opus", ".amr", ".m4a", ".aac")
+ private[resource] val AUDIO_TEMP_TTL_MS: Long = 60L * 60L * 1000L // 1 hour
+
+ // ── /media-proxy size cap: bounds the upstream response we buffer in heap ──
+ // Sized to cover HF inference media outputs (text-to-image ~5 MiB,
+ // text-to-video ~30 MiB) with headroom. Bumps should land in their own PR.
+ private[resource] val MAX_MEDIA_PROXY_BYTES: Long = 50L * 1024L * 1024L // 50 MiB
+
+ /** Outcome of streaming an upstream media response with the size cap. */
+ private[resource] sealed trait MediaProxyOutcome
+ private[resource] case class MediaProxyOk(bytes: Array[Byte], contentType: Option[String])
+ extends MediaProxyOutcome
+ private[resource] case class MediaProxyError(status: Int, message: String)
+ extends MediaProxyOutcome
+
+ // ── /media-proxy allowlist (SSRF protection) ──
+ // Add new hosts here when integrating with a new HF inference provider.
+ private val ALLOWED_MEDIA_HOST_SUFFIXES: Set[String] = Set(
+ "huggingface.co",
+ "fal.media",
+ "replicate.delivery",
+ "replicate.com"
+ )
+
+ private[resource] def audioTempDir: NioPath =
+ Paths.get(System.getProperty("java.io.tmpdir"), "texera-hf-audio")
+
+ /** Delete audio temp files older than AUDIO_TEMP_TTL_MS. Best-effort. */
+ private[resource] def sweepOldAudioFiles(tempDir: NioPath): Unit = {
+ val cutoff = System.currentTimeMillis() - AUDIO_TEMP_TTL_MS
+ try {
+ val stream = Files.list(tempDir)
+ try {
+ stream.forEach { p =>
+ try {
+ if (Files.isRegularFile(p) && Files.getLastModifiedTime(p).toMillis < cutoff) {
+ Files.deleteIfExists(p)
+ }
+ } catch {
+ case _: Exception => // skip files we can't stat/delete
+ }
+ }
+ } finally {
+ stream.close()
+ }
+ } catch {
+ case e: Exception =>
+ logger.debug(s"Audio temp dir sweep failed: ${e.getMessage}")
+ }
+ }
+
+ /** Allow exact host or subdomain of any entry in ALLOWED_MEDIA_HOST_SUFFIXES. */
+ private[resource] def isAllowedMediaHost(host: String): Boolean = {
+ if (host == null || host.isEmpty) return false
+ val lower = host.toLowerCase
+ ALLOWED_MEDIA_HOST_SUFFIXES.exists(suffix => lower == suffix || lower.endsWith("." + suffix))
+ }
+
+ /** Trim and null-coalesce the X-HF-Token header value; empty means anonymous. */
+ private[resource] def sanitizeToken(headerValue: String): String =
+ Option(headerValue).map(_.trim).filter(_.nonEmpty).getOrElse("")
+
+ /** Build a JSON error body using Jackson so the message is properly escaped. */
+ private[resource] def errorJson(message: String): String =
+ objectMapper.writeValueAsString(Map("error" -> message).asJava)
+
+ private def errorResponse(status: Response.Status, message: String): Response =
+ Response.status(status).entity(errorJson(message)).build()
+
+ private def errorResponse(statusCode: Int, message: String): Response =
+ Response.status(statusCode).entity(errorJson(message)).build()
+
+ private[resource] def inferAudioContentType(path: NioPath): String = {
+ val fileName = Option(path.getFileName).map(_.toString.toLowerCase).getOrElse("")
+ if (fileName.endsWith(".mp3") || fileName.endsWith(".mpeg")) "audio/mpeg"
+ else if (fileName.endsWith(".wav")) "audio/wav"
+ else if (fileName.endsWith(".flac")) "audio/flac"
+ else if (fileName.endsWith(".ogg") || fileName.endsWith(".oga")) "audio/ogg"
+ else if (fileName.endsWith(".webm")) "audio/webm"
+ else if (fileName.endsWith(".opus")) "audio/webm;codecs=opus"
+ else if (fileName.endsWith(".amr")) "audio/amr"
+ else if (fileName.endsWith(".m4a")) "audio/m4a"
+ else "application/octet-stream"
+ }
+
+ /** Result of a paginated fetch — `truncated` is true if pagination stopped early. */
+ private case class PageResult(
+ models: java.util.List[java.util.Map[String, Object]],
+ truncated: Boolean
+ )
+}
diff --git a/amber/src/test/scala/org/apache/texera/web/resource/HuggingFaceModelResourceSpec.scala b/amber/src/test/scala/org/apache/texera/web/resource/HuggingFaceModelResourceSpec.scala
new file mode 100644
index 0000000000..38402bd647
--- /dev/null
+++ b/amber/src/test/scala/org/apache/texera/web/resource/HuggingFaceModelResourceSpec.scala
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.web.resource
+
+import com.fasterxml.jackson.databind.ObjectMapper
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.io.{ByteArrayInputStream, InputStream}
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Path, Paths}
+import javax.ws.rs.core.Response
+
+/**
+ * Tests for [[HuggingFaceModelResource]] covering the validation, security,
+ * caching, and filesystem behavior that can be exercised without contacting
+ * Hugging Face Hub. Paths that require live HF API calls (the actual fetch
+ * loops in `listModels` browse-mode-uncached and `listTasks` uncached) are
+ * left to integration testing.
+ */
+class HuggingFaceModelResourceSpec extends AnyFunSuite with BeforeAndAfterEach {
+
+ import HuggingFaceModelResource._
+
+ private val mapper = new ObjectMapper()
+ private var resource: HuggingFaceModelResource = _
+
+ override def beforeEach(): Unit = {
+ resource = new HuggingFaceModelResource()
+ // Reset caches between tests so cache hits from one test can't leak into another.
+ modelCache.invalidateAll()
+ taskCache.invalidateAll()
+ // Make sure the audio temp dir exists for tests that read from it.
+ Files.createDirectories(audioTempDir)
+ }
+
+ override def afterEach(): Unit = {
+ // Clean up any temp files this test created.
+ if (Files.exists(audioTempDir)) {
+ val stream = Files.list(audioTempDir)
+ try {
+ stream.forEach { p =>
+ try Files.deleteIfExists(p)
+ catch { case _: Exception => () }
+ }
+ } finally {
+ stream.close()
+ }
+ }
+ modelCache.invalidateAll()
+ taskCache.invalidateAll()
+ }
+
+ // Helper: read a Response's string entity (assumes the body is a String).
+ private def entityString(response: Response): String =
+ response.getEntity match {
+ case s: String => s
+ case other => other.toString
+ }
+
+ // Helper: read a Response's byte entity (assumes the body is a byte array).
+ private def entityBytes(response: Response): Array[Byte] =
+ response.getEntity.asInstanceOf[Array[Byte]]
+
+ // Helper: assert that a Response carries a JSON error body shaped { "error": "..." }.
+ private def assertErrorBody(response: Response): Unit = {
+ val body = entityString(response)
+ val node = mapper.readTree(body)
+ assert(node.has("error"), s"expected JSON error body, got: $body")
+ }
+
+ // Helper: build a small in-memory InputStream from a UTF-8 string.
+ private def streamOf(s: String): InputStream =
+ new ByteArrayInputStream(s.getBytes(StandardCharsets.UTF_8))
+
+ // Helper: build an InputStream of `n` zero-bytes.
+ private def streamOfBytes(n: Int): InputStream =
+ new ByteArrayInputStream(new Array[Byte](n))
+
+ // ────────────────────────────────────────────────────────────────────────
+ // sanitizeToken
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("sanitizeToken returns empty string when input is null") {
+ assert(sanitizeToken(null) == "")
+ }
+
+ test("sanitizeToken returns empty string when input is empty") {
+ assert(sanitizeToken("") == "")
+ }
+
+ test("sanitizeToken returns empty string when input is whitespace only") {
+ assert(sanitizeToken(" ") == "")
+ assert(sanitizeToken("\t\n") == "")
+ }
+
+ test("sanitizeToken trims surrounding whitespace") {
+ assert(sanitizeToken(" hf_abc123 ") == "hf_abc123")
+ }
+
+ test("sanitizeToken preserves a valid token unchanged") {
+ assert(sanitizeToken("hf_abc123XYZ") == "hf_abc123XYZ")
+ }
+
+ test("sanitizeToken preserves tokens containing special characters") {
+ assert(sanitizeToken("abc-xyz_123.45") == "abc-xyz_123.45")
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // isAllowedMediaHost — SSRF allowlist
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("isAllowedMediaHost rejects null host") {
+ assert(!isAllowedMediaHost(null))
+ }
+
+ test("isAllowedMediaHost rejects empty host") {
+ assert(!isAllowedMediaHost(""))
+ }
+
+ test("isAllowedMediaHost accepts exact match on huggingface.co") {
+ assert(isAllowedMediaHost("huggingface.co"))
+ }
+
+ test("isAllowedMediaHost accepts HF Hub CDN subdomains") {
+ assert(isAllowedMediaHost("cdn-uploads.huggingface.co"))
+ assert(isAllowedMediaHost("cdn-lfs.huggingface.co"))
+ }
+
+ test("isAllowedMediaHost is case-insensitive") {
+ assert(isAllowedMediaHost("HUGGINGFACE.CO"))
+ assert(isAllowedMediaHost("Cdn-LFS.HuggingFace.co"))
+ }
+
+ test("isAllowedMediaHost accepts fal.media and its subdomains") {
+ assert(isAllowedMediaHost("fal.media"))
+ assert(isAllowedMediaHost("v3b.fal.media"))
+ }
+
+ test("isAllowedMediaHost accepts replicate.delivery and its subdomains") {
+ assert(isAllowedMediaHost("replicate.delivery"))
+ assert(isAllowedMediaHost("cdn.replicate.delivery"))
+ }
+
+ test("isAllowedMediaHost accepts replicate.com and its subdomains") {
+ assert(isAllowedMediaHost("replicate.com"))
+ assert(isAllowedMediaHost("api.replicate.com"))
+ }
+
+ test("isAllowedMediaHost rejects lookalike domains (leading-dot guard)") {
+ // The critical security test: evilhuggingface.co must NOT match huggingface.co.
+ assert(!isAllowedMediaHost("evilhuggingface.co"))
+ assert(!isAllowedMediaHost("notfal.media"))
+ assert(!isAllowedMediaHost("xreplicate.com"))
+ }
+
+ test("isAllowedMediaHost rejects unrelated public domains") {
+ assert(!isAllowedMediaHost("google.com"))
+ assert(!isAllowedMediaHost("example.org"))
+ }
+
+ test("isAllowedMediaHost rejects localhost") {
+ assert(!isAllowedMediaHost("localhost"))
+ assert(!isAllowedMediaHost("LOCALHOST"))
+ }
+
+ test("isAllowedMediaHost rejects loopback IPs") {
+ assert(!isAllowedMediaHost("127.0.0.1"))
+ assert(!isAllowedMediaHost("0.0.0.0"))
+ }
+
+ test("isAllowedMediaHost rejects private IP ranges") {
+ assert(!isAllowedMediaHost("10.0.0.1"))
+ assert(!isAllowedMediaHost("192.168.1.1"))
+ assert(!isAllowedMediaHost("172.16.0.1"))
+ }
+
+ test("isAllowedMediaHost rejects cloud metadata IP") {
+ assert(!isAllowedMediaHost("169.254.169.254"))
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // errorJson — JSON escaping
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("errorJson produces well-formed JSON for a simple message") {
+ val out = errorJson("Failed to fetch models.")
+ val node = mapper.readTree(out)
+ assert(node.get("error").asText() == "Failed to fetch models.")
+ }
+
+ test("errorJson escapes double quotes in the message") {
+ val out = errorJson("She said \"hi\"")
+ // Must round-trip cleanly back to the original — Jackson handled the escaping.
+ val node = mapper.readTree(out)
+ assert(node.get("error").asText() == "She said \"hi\"")
+ }
+
+ test("errorJson escapes backslashes in the message") {
+ val out = errorJson("path C:\\Users\\evil")
+ val node = mapper.readTree(out)
+ assert(node.get("error").asText() == "path C:\\Users\\evil")
+ }
+
+ test("errorJson escapes newlines and tabs in the message") {
+ val out = errorJson("line1\nline2\tindented")
+ val node = mapper.readTree(out)
+ assert(node.get("error").asText() == "line1\nline2\tindented")
+ }
+
+ test("errorJson handles empty message") {
+ val out = errorJson("")
+ val node = mapper.readTree(out)
+ assert(node.get("error").asText() == "")
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // inferAudioContentType — extension → MIME type
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("inferAudioContentType returns audio/mpeg for .mp3") {
+ assert(inferAudioContentType(Paths.get("clip.mp3")) == "audio/mpeg")
+ }
+
+ test("inferAudioContentType returns audio/mpeg for .mpeg") {
+ assert(inferAudioContentType(Paths.get("clip.mpeg")) == "audio/mpeg")
+ }
+
+ test("inferAudioContentType returns audio/wav for .wav") {
+ assert(inferAudioContentType(Paths.get("clip.wav")) == "audio/wav")
+ }
+
+ test("inferAudioContentType returns audio/flac for .flac") {
+ assert(inferAudioContentType(Paths.get("clip.flac")) == "audio/flac")
+ }
+
+ test("inferAudioContentType returns audio/ogg for .ogg") {
+ assert(inferAudioContentType(Paths.get("clip.ogg")) == "audio/ogg")
+ }
+
+ test("inferAudioContentType returns audio/ogg for .oga") {
+ assert(inferAudioContentType(Paths.get("clip.oga")) == "audio/ogg")
+ }
+
+ test("inferAudioContentType returns audio/webm for .webm") {
+ assert(inferAudioContentType(Paths.get("clip.webm")) == "audio/webm")
+ }
+
+ test("inferAudioContentType returns audio/webm;codecs=opus for .opus") {
+ assert(inferAudioContentType(Paths.get("clip.opus")) == "audio/webm;codecs=opus")
+ }
+
+ test("inferAudioContentType returns audio/amr for .amr") {
+ assert(inferAudioContentType(Paths.get("clip.amr")) == "audio/amr")
+ }
+
+ test("inferAudioContentType returns audio/m4a for .m4a") {
+ assert(inferAudioContentType(Paths.get("clip.m4a")) == "audio/m4a")
+ }
+
+ test("inferAudioContentType falls back to octet-stream for unknown extension") {
+ assert(inferAudioContentType(Paths.get("clip.xyz")) == "application/octet-stream")
+ assert(inferAudioContentType(Paths.get("noextension")) == "application/octet-stream")
+ }
+
+ test("inferAudioContentType is case-insensitive") {
+ assert(inferAudioContentType(Paths.get("clip.WAV")) == "audio/wav")
+ assert(inferAudioContentType(Paths.get("clip.MP3")) == "audio/mpeg")
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // uploadAudioReference — input validation & size cap
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("uploadAudioReference returns 400 when stream is null") {
+ val response = resource.uploadAudioReference("voice.wav", null)
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference returns 400 when stream is empty") {
+ val response = resource.uploadAudioReference("voice.wav", streamOfBytes(0))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects .sh extension") {
+ val response = resource.uploadAudioReference("evil.sh", streamOf("payload"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects .html extension") {
+ val response = resource.uploadAudioReference("trick.html", streamOf("<script>"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects .bat extension") {
+ val response = resource.uploadAudioReference("run.bat", streamOf("@echo off"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects .exe extension") {
+ val response = resource.uploadAudioReference("malware.exe", streamOf("MZ"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects files with no extension") {
+ val response = resource.uploadAudioReference("recording", streamOf("data"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects null filename (default audio.bin not in allowlist)") {
+ val response = resource.uploadAudioReference(null, streamOf("data"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects empty filename (default audio.bin not in allowlist)") {
+ val response = resource.uploadAudioReference("", streamOf("data"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference rejects whitespace-only filename") {
+ val response = resource.uploadAudioReference(" ", streamOf("data"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference accepts a valid .wav upload") {
+ val payload = "RIFF....WAVE....fake-wav-content".getBytes(StandardCharsets.UTF_8)
+ val response = resource.uploadAudioReference("voice.wav", new ByteArrayInputStream(payload))
+ assert(response.getStatus == 200)
+
+ val node = mapper.readTree(entityString(response))
+ assert(node.has("path"))
+ assert(node.has("fileName"))
+ assert(node.get("fileName").asText() == "voice.wav")
+
+ // Verify the file was actually written with the right contents.
+ val savedPath = Paths.get(node.get("path").asText())
+ assert(Files.exists(savedPath))
+ assert(Files.readAllBytes(savedPath).sameElements(payload))
+ // The saved file should land inside the audioTempDir.
+ assert(savedPath.toAbsolutePath.normalize().startsWith(audioTempDir.toAbsolutePath.normalize()))
+ }
+
+ test("uploadAudioReference lowercases the extension for the temp file") {
+ val response = resource.uploadAudioReference("voice.WAV", streamOf("RIFF"))
+ assert(response.getStatus == 200)
+
+ val node = mapper.readTree(entityString(response))
+ val savedPath = Paths.get(node.get("path").asText())
+ assert(savedPath.getFileName.toString.endsWith(".wav"))
+ }
+
+ test("uploadAudioReference strips path components from filename") {
+ // ?filename=../../etc/passwd should be reduced to passwd (no extension) — rejected
+ val response = resource.uploadAudioReference("../../etc/passwd", streamOf("data"))
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference returns 413 for payload exceeding MAX_AUDIO_BYTES") {
+ val oversize = MAX_AUDIO_BYTES.toInt + 1
+ val response = resource.uploadAudioReference("big.wav", streamOfBytes(oversize))
+ assert(response.getStatus == 413)
+ assertErrorBody(response)
+ }
+
+ test("uploadAudioReference cleans up partial file when size cap is exceeded") {
+ val sweepBefore = listAudioTempFiles()
+ val oversize = MAX_AUDIO_BYTES.toInt + 1
+ val response = resource.uploadAudioReference("big.wav", streamOfBytes(oversize))
+ assert(response.getStatus == 413)
+ val sweepAfter = listAudioTempFiles()
+ // No new file should remain after the rejection (existing files unchanged).
+ assert(
+ sweepAfter.length <= sweepBefore.length,
+ s"oversize upload left a partial file: before=$sweepBefore after=$sweepAfter"
+ )
+ }
+
+ test("uploadAudioReference accepts all allowlisted extensions") {
+ ALLOWED_AUDIO_EXTENSIONS.foreach { ext =>
+ val response = resource.uploadAudioReference(s"clip$ext", streamOf("data"))
+ assert(response.getStatus == 200, s"extension $ext should have been accepted")
+ }
+ }
+
+ private def listAudioTempFiles(): Array[Path] = {
+ if (!Files.exists(audioTempDir)) return Array.empty
+ val stream = Files.list(audioTempDir)
+ try {
+ val arr = stream.toArray.asInstanceOf[Array[Object]].map(_.asInstanceOf[Path])
+ arr
+ } finally {
+ stream.close()
+ }
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // previewUploadedAudio — path validation
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("previewUploadedAudio returns 400 when path is null") {
+ val response = resource.previewUploadedAudio(null)
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio returns 400 when path is empty") {
+ val response = resource.previewUploadedAudio("")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio returns 400 when path is whitespace") {
+ val response = resource.previewUploadedAudio(" ")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio returns 403 when path is outside the temp directory") {
+ val response = resource.previewUploadedAudio("/etc/passwd")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio rejects path traversal attempts") {
+ val traversalPath =
+ audioTempDir.toAbsolutePath.toString + "/../../etc/passwd"
+ val response = resource.previewUploadedAudio(traversalPath)
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio returns 404 for a non-existent file inside temp dir") {
+ val ghost = audioTempDir.resolve("does-not-exist.wav").toAbsolutePath.toString
+ val response = resource.previewUploadedAudio(ghost)
+ assert(response.getStatus == 404)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio returns 404 when path points to a directory, not a file") {
+ val response = resource.previewUploadedAudio(audioTempDir.toAbsolutePath.toString)
+ assert(response.getStatus == 404)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio streams back a valid file with correct content-type") {
+ val payload = "fake-wav-bytes".getBytes(StandardCharsets.UTF_8)
+ val file = Files.createTempFile(audioTempDir, "test-preview-", ".wav")
+ Files.write(file, payload)
+
+ val response = resource.previewUploadedAudio(file.toAbsolutePath.toString)
+ assert(response.getStatus == 200)
+ val bytes = entityBytes(response)
+ assert(bytes.sameElements(payload))
+ }
+
+ test(
+ "previewUploadedAudio returns 413 when the on-disk file exceeds MAX_AUDIO_BYTES (defense-in-depth)"
+ ) {
+ // /upload-audio caps ingest at MAX_AUDIO_BYTES, but the preview endpoint
+ // shouldn't trust that invariant — a future bug or out-of-band write could
+ // leave an oversized file in the temp dir. Reads of those files must not
+ // OOM the JVM.
+ val file = Files.createTempFile(audioTempDir, "test-oversize-", ".wav")
+ // Create a sparse file of size MAX_AUDIO_BYTES + 1 without actually
+ // writing that many bytes to disk.
+ val raf = new java.io.RandomAccessFile(file.toFile, "rw")
+ try raf.setLength(MAX_AUDIO_BYTES + 1)
+ finally raf.close()
+
+ val response = resource.previewUploadedAudio(file.toAbsolutePath.toString)
+ assert(response.getStatus == 413)
+ assertErrorBody(response)
+ }
+
+ test("previewUploadedAudio normalizes the path before checking containment") {
+ val payload = "ok".getBytes(StandardCharsets.UTF_8)
+ val file = Files.createTempFile(audioTempDir, "test-norm-", ".wav")
+ Files.write(file, payload)
+
+ // Same file referenced via a non-normalized path (extra slashes / dot-segments).
+ val weird = audioTempDir.toAbsolutePath.toString + "/./" + file.getFileName.toString
+ val response = resource.previewUploadedAudio(weird)
+ assert(response.getStatus == 200)
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // proxyRemoteMedia — input validation & SSRF
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("proxyRemoteMedia returns 400 for null URL") {
+ val response = resource.proxyRemoteMedia(null)
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia returns 400 for empty URL") {
+ val response = resource.proxyRemoteMedia("")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia returns 400 for whitespace URL") {
+ val response = resource.proxyRemoteMedia(" ")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects file:// URLs") {
+ val response = resource.proxyRemoteMedia("file:///etc/passwd")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects ftp:// URLs") {
+ val response = resource.proxyRemoteMedia("ftp://example.com/data")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects javascript: URLs") {
+ val response = resource.proxyRemoteMedia("javascript:alert(1)")
+ assert(response.getStatus == 400)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects localhost via SSRF allowlist (403)") {
+ val response = resource.proxyRemoteMedia("http://localhost:8080/admin")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects 127.0.0.1 via SSRF allowlist (403)") {
+ val response = resource.proxyRemoteMedia("http://127.0.0.1:9200/_cat/indices")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects AWS metadata IP via SSRF allowlist (403)") {
+ val response =
+ resource.proxyRemoteMedia("http://169.254.169.254/latest/meta-data/iam/")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects private IP ranges via SSRF allowlist (403)") {
+ val response = resource.proxyRemoteMedia("http://10.0.0.5/admin")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects lookalike huggingface domain (leading-dot guard)") {
+ val response = resource.proxyRemoteMedia("https://evilhuggingface.co/payload")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects arbitrary public domains not on the allowlist") {
+ val response = resource.proxyRemoteMedia("https://example.com/anything")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ test("proxyRemoteMedia rejects URLs with missing host") {
+ val response = resource.proxyRemoteMedia("http:///no-host-here")
+ assert(response.getStatus == 403)
+ assertErrorBody(response)
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // listModels — cache hit paths (no HF traffic required)
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("listModels returns 200 with cached body when cache hits and no user token") {
+ val cachedBody = """[{"id":"test-model","label":"test-model"}]"""
+ modelCache.put("text-generation", cachedBody)
+
+ val response = resource.listModels("text-generation", null, null)
+ assert(response.getStatus == 200)
+ assert(entityString(response) == cachedBody)
+ }
+
+ test("listModels cache hit does NOT carry the truncated header") {
+ val cachedBody = """[{"id":"x"}]"""
+ modelCache.put("foo", cachedBody)
+
+ val response = resource.listModels("foo", null, null)
+ assert(response.getHeaderString(TRUNCATED_HEADER) == null)
+ }
+
+ test("listModels cache hit is keyed by task — different task is a miss") {
+ modelCache.put("text-classification", """[{"id":"a"}]""")
+ // We don't want to actually hit HF, so we just assert that the cache for "image-classification"
+ // is empty after `put` — i.e., Guava cache lookup is task-specific.
+ assert(modelCache.getIfPresent("image-classification") == null)
+ assert(modelCache.getIfPresent("text-classification") != null)
+ }
+
+ test("listModels with X-HF-Token header bypasses the cache (does not read from it)") {
+ val cachedBody = """[{"id":"only-cached"}]"""
+ modelCache.put("text-generation", cachedBody)
+
+ // We can't easily assert the resource then *successfully* calls HF without a mock,
+ // but we can verify the cache content is unchanged after a user-token call
+ // (i.e., user-token requests don't populate the same cache slot).
+ val before = modelCache.getIfPresent("text-generation")
+ try {
+ resource.listModels("text-generation", null, "hf_user_token_xyz")
+ } catch {
+ case _: Throwable => () // network may fail in unit tests; we only care about cache state
+ }
+ val after = modelCache.getIfPresent("text-generation")
+ assert(before == after, "user-token request should not alter the anonymous cache slot")
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // listTasks — cache hit paths (no HF traffic required)
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("listTasks returns 200 with cached body when cache hits and no user token") {
+ val cachedBody = """[{"tag":"text-generation","label":"Text Generation"}]"""
+ taskCache.put(TASKS_CACHE_KEY, cachedBody)
+
+ val response = resource.listTasks(null)
+ assert(response.getStatus == 200)
+ assert(entityString(response) == cachedBody)
+ }
+
+ test("listTasks with empty token header still reads from cache (sanitized to anonymous)") {
+ val cachedBody = """[{"tag":"x","label":"X"}]"""
+ taskCache.put(TASKS_CACHE_KEY, cachedBody)
+
+ val response = resource.listTasks(" ")
+ assert(response.getStatus == 200)
+ assert(entityString(response) == cachedBody)
+ }
+
+ test("listTasks with X-HF-Token header bypasses the cache") {
+ val cachedBody = """[{"tag":"only-cached"}]"""
+ taskCache.put(TASKS_CACHE_KEY, cachedBody)
+
+ val before = taskCache.getIfPresent(TASKS_CACHE_KEY)
+ try {
+ resource.listTasks("hf_user_token_xyz")
+ } catch {
+ case _: Throwable => ()
+ }
+ val after = taskCache.getIfPresent(TASKS_CACHE_KEY)
+ assert(before == after, "user-token request should not alter the anonymous task cache slot")
+ }
+
+ // ────────────────────────────────────────────────────────────────────────
+ // sweepOldAudioFiles — temp directory cleanup
+ // ────────────────────────────────────────────────────────────────────────
+
+ test("sweepOldAudioFiles deletes files older than the TTL") {
+ val oldFile = Files.createTempFile(audioTempDir, "test-sweep-old-", ".wav")
+ Files.write(oldFile, "old".getBytes(StandardCharsets.UTF_8))
+ // Force the lastModified time to be older than the TTL window.
+ val oldTime = java.nio.file.attribute.FileTime.fromMillis(
+ System.currentTimeMillis() - AUDIO_TEMP_TTL_MS - 60000L
+ )
+ Files.setLastModifiedTime(oldFile, oldTime)
+
+ sweepOldAudioFiles(audioTempDir)
+
+ assert(!Files.exists(oldFile), "old file should have been swept")
+ }
+
+ test("sweepOldAudioFiles preserves files newer than the TTL") {
+ val freshFile = Files.createTempFile(audioTempDir, "test-sweep-fresh-", ".wav")
+ Files.write(freshFile, "fresh".getBytes(StandardCharsets.UTF_8))
+ // Default mtime is now; explicitly set to be safe.
+ val recentTime = java.nio.file.attribute.FileTime.fromMillis(System.currentTimeMillis())
+ Files.setLastModifiedTime(freshFile, recentTime)
+
+ sweepOldAudioFiles(audioTempDir)
+
+ assert(Files.exists(freshFile), "fresh file should have been preserved")
+ }
+
+ test("sweepOldAudioFiles handles a missing directory gracefully") {
+ val ghostDir =
+ Paths.get(System.getProperty("java.io.tmpdir"), "texera-hf-audio-ghost-" + System.nanoTime())
+ // Don't create it. The sweep should swallow the IOException and not throw.
+ sweepOldAudioFiles(ghostDir)
+ // (no assertion needed — reaching this line means no exception escaped)
+ succeed
+ }
+
+ test("sweepOldAudioFiles only deletes regular files, not subdirectories") {
+ val subdir = Files.createTempDirectory(audioTempDir, "test-sweep-subdir-")
+ val oldTime = java.nio.file.attribute.FileTime.fromMillis(
+ System.currentTimeMillis() - AUDIO_TEMP_TTL_MS - 60000L
+ )
+ Files.setLastModifiedTime(subdir, oldTime)
+
+ sweepOldAudioFiles(audioTempDir)
+
+ assert(Files.exists(subdir), "subdirectory should be preserved (sweep only deletes files)")
+ // cleanup
+ Files.deleteIfExists(subdir)
+ }
+}
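The spec above checks the upload size cap, partial-file cleanup, and path containment only through `uploadAudioReference` and `previewUploadedAudio`. Their bodies are not part of this excerpt. The sketch below illustrates the two underlying checks. It assumes the behaviour described in the PR summary (8 KB chunks, a byte cap, delete on overflow) and the normalize-then-`startsWith` containment rule the tests imply. `CappedAudioIoSketch`, `copyWithCap`, and `isInsideDir` are hypothetical names, not the committed API.

```scala
import java.io.InputStream
import java.nio.file.{Files, Path}

// Hypothetical sketch; names and signatures are illustrative, not the committed API.
object CappedAudioIoSketch {

  /** Streams `in` to `target` in 8 KB chunks without buffering the whole body.
    * Returns Right(bytesWritten), or Left("too large") after deleting the
    * partial file as soon as more than `maxBytes` bytes have been read. */
  def copyWithCap(in: InputStream, target: Path, maxBytes: Long): Either[String, Long] = {
    val out = Files.newOutputStream(target)
    val buf = new Array[Byte](8192)
    var total = 0L
    var exceeded = false
    try {
      var n = in.read(buf)
      while (n != -1 && !exceeded) {
        total += n
        if (total > maxBytes) exceeded = true // stop before writing the overflow chunk
        else {
          out.write(buf, 0, n)
          n = in.read(buf)
        }
      }
    } finally out.close()
    if (exceeded) {
      Files.deleteIfExists(target) // never leave a partial upload behind
      Left("too large")
    } else Right(total)
  }

  /** Normalizes both paths before comparing, so a string such as
    * "<dir>/../../etc/passwd" is judged by where it actually points. */
  def isInsideDir(dir: Path, candidate: Path): Boolean =
    candidate.toAbsolutePath.normalize().startsWith(dir.toAbsolutePath.normalize())
}
```

`Path.startsWith` compares whole name elements, so `/tmp/audio-evil` is not treated as being inside `/tmp/audio`. `normalize()` is purely syntactic and does not follow symbolic links. The sketch also does not reject empty input, which the tests show the real endpoint answers with 400.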