carloea2 commented on code in PR #4136:
URL: https://github.com/apache/texera/pull/4136#discussion_r2656970453


##########
file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala:
##########
@@ -1372,4 +1405,378 @@ class DatasetResource {
         Right(response)
     }
   }
+
+  // === Multipart helpers ===
+
+  private def getDatasetBy(ownerEmail: String, datasetName: String) = {
+    val dataset = context
+      .select(DATASET.fields: _*)
+      .from(DATASET)
+      .leftJoin(USER)
+      .on(USER.UID.eq(DATASET.OWNER_UID))
+      .where(USER.EMAIL.eq(ownerEmail))
+      .and(DATASET.NAME.eq(datasetName))
+      .fetchOneInto(classOf[Dataset])
+    if (dataset == null) {
+      throw new BadRequestException("Dataset not found")
+    }
+    dataset
+  }
+
+  private def validateFilePathOrThrow(filePath: String): String = {
+    val p = Option(filePath).getOrElse("")
+    val s = p.replace("\\", "/")
+    if (
+      p.isEmpty ||
+      s.startsWith("/") ||
+      s.split("/").exists(seg => seg == "." || seg == "..") ||
+      s.exists(ch => ch == 0.toChar || ch < 0x20.toChar || ch == 0x7f.toChar)
+    ) throw new BadRequestException("Invalid filePath")
+    p
+  }
+
+  private def initMultipartUpload(
+      did: Integer,
+      encodedFilePath: String,
+      numParts: Optional[Integer],
+      uid: Integer
+  ): Response = {
+
+    withTransaction(context) { ctx =>
+      if (!userHasWriteAccess(ctx, did, uid)) {
+        throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE)
+      }
+
+      val dataset = getDatasetByID(ctx, did)
+      val repositoryName = dataset.getRepositoryName
+
+      val filePath =
+        validateFilePathOrThrow(URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()))
+
+      val numPartsValue = numParts.toScala.getOrElse {
+        throw new BadRequestException("numParts is required for initialization")
+      }
+      if (numPartsValue < 1 || numPartsValue > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS) {
+        throw new BadRequestException(
+          "numParts must be between 1 and " + MAXIMUM_NUM_OF_MULTIPART_S3_PARTS
+        )
+      }
+
+      // Reject if a session already exists
+      val exists = ctx.fetchExists(
+        ctx
+          .selectOne()
+          .from(DATASET_UPLOAD_SESSION)
+          .where(
+            DATASET_UPLOAD_SESSION.UID
+              .eq(uid)
+              .and(DATASET_UPLOAD_SESSION.DID.eq(did))
+              .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath))
+          )
+      )
+      if (exists) {
+        throw new WebApplicationException(
+          "Upload already in progress for this filePath",
+          Response.Status.CONFLICT
+        )
+      }
+
+      val presign = LakeFSStorageClient.initiatePresignedMultipartUploads(
+        repositoryName,
+        filePath,
+        numPartsValue
+      )
+
+      val uploadIdStr = presign.getUploadId
+      val physicalAddr = presign.getPhysicalAddress
+
+      // If anything fails after this point, abort LakeFS multipart
+      try {
+        val rowsInserted = ctx
+          .insertInto(DATASET_UPLOAD_SESSION)
+          .set(DATASET_UPLOAD_SESSION.FILE_PATH, filePath)
+          .set(DATASET_UPLOAD_SESSION.DID, did)
+          .set(DATASET_UPLOAD_SESSION.UID, uid)
+          .set(DATASET_UPLOAD_SESSION.UPLOAD_ID, uploadIdStr)
+          .set(DATASET_UPLOAD_SESSION.PHYSICAL_ADDRESS, physicalAddr)
+          .set(DATASET_UPLOAD_SESSION.NUM_PARTS_REQUESTED, numPartsValue)
+          .onDuplicateKeyIgnore()
+          .execute()
+
+        if (rowsInserted != 1) {
+          LakeFSStorageClient.abortPresignedMultipartUploads(
+            repositoryName,
+            filePath,
+            uploadIdStr,
+            physicalAddr
+          )
+          throw new WebApplicationException(
+            "Upload already in progress for this filePath",
+            Response.Status.CONFLICT
+          )
+        }
+
+        // Pre-create part rows 1..numPartsValue with empty ETag.
+        // This makes per-part locking cheap and deterministic.
+
+        val gs = DSL.generateSeries(1, numPartsValue).asTable("gs", "pn")
+        val PN = gs.field("pn", classOf[Integer])

Review Comment:
   Fixed it. Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to