This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 0ba9509 SUBMARINE-1053. Create model management REST API
0ba9509 is described below
commit 0ba950964456d94ccfe74dca967d3f78c71dcd87
Author: jeff-901 <[email protected]>
AuthorDate: Thu Nov 4 10:07:22 2021 +0800
SUBMARINE-1053. Create model management REST API
### What is this PR for?
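Add model management REST APIs: create, list, get, update, and delete endpoints for registered models; list, get, update, and delete endpoints for model versions; and create/delete tag endpoints for both.
Delete the unused MLflow database utility (ModelBatisUtil.java and modelbatis-config.xml).
Enable allowMultiQueries=true in the MySQL JDBC URLs.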
### What type of PR is it?
Feature
### Todos
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1053
### How should this be tested?
Run the new Java unit tests: ModelVersionRestApiTest and RegisteredModelRestApiTest.
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? No
Author: jeff-901 <[email protected]>
Signed-off-by: Kevin <[email protected]>
Closes #787 from jeff-901/SUBMARINE-1053 and squashes the following commits:
f1fbe661 [jeff-901] fix db url
a95ada29 [jeff-901] fix test
25c5ae3f [jeff-901] add test
9f67e695 [jeff-901] add example entity
43a86558 [jeff-901] delete unused file
a58d675e [jeff-901] edit descruption string
50ef2d2d [jeff-901] add tag test
31116018 [jeff-901] add tag api
2163c487 [jeff-901] add rest api without tags
---
conf/submarine-site.xml | 2 +-
conf/submarine-site.xml.template | 2 +-
.../server/database/utils/ModelBatisUtil.java | 72 -----
.../database/service/RegisteredModelService.java | 2 +-
.../submarine/server/rest/ModelVersionRestApi.java | 302 ++++++++++++++++++++
.../server/rest/RegisteredModelRestApi.java | 307 +++++++++++++++++++++
.../submarine/server/rest/RestConstants.java | 13 +
.../src/main/resources/hibernate.cfg.xml | 2 +-
.../src/main/resources/mbgConfiguration.xml | 2 +-
.../src/main/resources/modelbatis-config.xml | 60 ----
.../src/main/resources/submarine-site.xml | 2 +-
.../server/rest/ModelVersionRestApiTest.java | 178 ++++++++++++
.../server/rest/RegisteredModelRestApiTest.java | 158 +++++++++++
13 files changed, 964 insertions(+), 138 deletions(-)
diff --git a/conf/submarine-site.xml b/conf/submarine-site.xml
index eba2529..a742d47 100755
--- a/conf/submarine-site.xml
+++ b/conf/submarine-site.xml
@@ -113,7 +113,7 @@
</property>
<property>
<name>jdbc.url</name>
-
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false</value>
+
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false&allowMultiQueries=true</value>
</property>
<property>
<name>jdbc.username</name>
diff --git a/conf/submarine-site.xml.template b/conf/submarine-site.xml.template
index ccb2aa0..e234d53 100755
--- a/conf/submarine-site.xml.template
+++ b/conf/submarine-site.xml.template
@@ -113,7 +113,7 @@
</property>
<property>
<name>jdbc.url</name>
-
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false</value>
+
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false&allowMultiQueries=true</value>
</property>
<property>
<name>jdbc.username</name>
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/database/utils/ModelBatisUtil.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/database/utils/ModelBatisUtil.java
deleted file mode 100644
index 58e2d9f..0000000
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/database/utils/ModelBatisUtil.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.submarine.server.database.utils;
-
-import org.apache.ibatis.io.Resources;
-import org.apache.ibatis.session.SqlSession;
-import org.apache.ibatis.session.SqlSessionFactory;
-import org.apache.ibatis.session.SqlSessionFactoryBuilder;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.io.Reader;
-import java.util.Properties;
-
-public class ModelBatisUtil {
- private static final Logger LOG =
LoggerFactory.getLogger(ModelBatisUtil.class);
-
- private static final SqlSessionFactory sqlSessionFactory;
-
- static {
-
- try (Reader reader =
- Resources.getResourceAsReader("modelbatis-config.xml");
- ) {
- String jdbcClassName = "com.mysql.jdbc.Driver";
- String jdbcUrl = "jdbc:mysql://127.0.0.1:3306/mlflowdb";
- String jdbcUserName = "mlflow";
- String jdbcPassword = "password";
- LOG.info("MyBatisUtil -> jdbcClassName: {}, jdbcUrl: {}, jdbcUserName:
{}, jdbcPassword: {}",
- jdbcClassName, jdbcUrl, jdbcUserName, jdbcPassword);
-
- Properties props = new Properties();
- props.setProperty("jdbc.driverClassName", jdbcClassName);
- props.setProperty("jdbc.url", jdbcUrl);
- props.setProperty("jdbc.username", jdbcUserName);
- props.setProperty("jdbc.password", jdbcPassword);
-
- sqlSessionFactory = new SqlSessionFactoryBuilder().build(reader, props);
- } catch (IOException e) {
- LOG.error(e.getMessage(), e);
- throw new RuntimeException(e.getMessage());
- }
- }
-
- /**
- * Get Session.
- *
- * @return SqlSession
- */
- public static SqlSession getSqlSession() {
- return sqlSessionFactory.openSession();
- }
-
-
-}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/service/RegisteredModelService.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/service/RegisteredModelService.java
index 02dae36..76ffd8a 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/service/RegisteredModelService.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/service/RegisteredModelService.java
@@ -20,7 +20,7 @@
package org.apache.submarine.server.model.database.service;
import org.apache.ibatis.session.SqlSession;
-import
org.apache.submarine.commons.runtime.exception.SubmarineRuntimeException;
+import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.database.utils.MyBatisUtil;
import
org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
import
org.apache.submarine.server.model.database.mappers.RegisteredModelMapper;
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
new file mode 100644
index 0000000..56c405d
--- /dev/null
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
@@ -0,0 +1,302 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+package org.apache.submarine.server.rest;
+
+import java.util.List;
+import javax.ws.rs.Consumes;
+import javax.ws.rs.DELETE;
+import javax.ws.rs.DefaultValue;
+import javax.ws.rs.GET;
+import javax.ws.rs.PATCH;
+import javax.ws.rs.POST;
+import javax.ws.rs.Path;
+import javax.ws.rs.PathParam;
+import javax.ws.rs.Produces;
+import javax.ws.rs.QueryParam;
+import javax.ws.rs.core.MediaType;
+import javax.ws.rs.core.Response;
+import io.swagger.v3.oas.annotations.Operation;
+import io.swagger.v3.oas.annotations.media.Content;
+import io.swagger.v3.oas.annotations.media.Schema;
+import io.swagger.v3.oas.annotations.responses.ApiResponse;
+
+
+import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
+import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
+import
org.apache.submarine.server.model.database.entities.ModelVersionTagEntity;
+import org.apache.submarine.server.model.database.service.ModelVersionService;
+
+
+import
org.apache.submarine.server.model.database.service.ModelVersionTagService;
+import org.apache.submarine.server.response.JsonResponse;
+
+/**
+ * Model version REST API v1.
+ *
+ * <p>Exposes list/get/update/delete endpoints for model versions and create/delete endpoints
+ * for model version tags. Every handler converts a {@link SubmarineRuntimeException} into a
+ * JSON response carrying the exception's code and message.
+ */
+@Path(RestConstants.V1 + "/" + RestConstants.MODEL_VERSION)
+@Produces({ MediaType.APPLICATION_JSON + "; " + RestConstants.CHARSET_UTF8 })
+public class ModelVersionRestApi {
+
+  /* Model version service */
+  private final ModelVersionService modelVersionService = new ModelVersionService();
+
+  /* Model version tag service */
+  private final ModelVersionTagService modelVersionTagService = new ModelVersionTagService();
+
+  /**
+   * Return the Pong message for testing the connectivity.
+   *
+   * @return Pong message
+   */
+  @GET
+  @Path(RestConstants.PING)
+  @Consumes(MediaType.APPLICATION_JSON)
+  @Operation(summary = "Ping submarine server", tags = {"model-version"},
+      description = "Return the Pong message for test the connectivity",
+      responses = {
+          @ApiResponse(responseCode = "200", description = "successful operation",
+              content = @Content(schema = @Schema(implementation = String.class)))})
+  public Response ping() {
+    return new JsonResponse.Builder<String>(Response.Status.OK)
+        .success(true).result("Pong").build();
+  }
+
+  /**
+   * List all model versions under the same registered model name.
+   *
+   * @param name registered model name
+   * @return model version list
+   */
+  @GET
+  @Path("/{name}")
+  @Operation(summary = "List model versions", tags = {"model-version"}, responses = {
+      @ApiResponse(description = "successful operation",
+          content = @Content(schema = @Schema(implementation = JsonResponse.class)))})
+  public Response listModelVersions(@PathParam(RestConstants.MODEL_VERSION_NAME) String name) {
+    try {
+      List<ModelVersionEntity> modelVersionList = modelVersionService.selectAllVersions(name);
+      return new JsonResponse.Builder<List<ModelVersionEntity>>(Response.Status.OK).success(true)
+          .message("List all model version instances").result(modelVersionList).build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Get detailed info about the model version by name and version.
+   *
+   * @param name model version's name
+   * @param version model version's version
+   * @return detailed info about the model version, or a 404 response if it does not exist
+   */
+  @GET
+  @Path("/{name}/{version}")
+  @Operation(summary = "Get detailed info about the model version", tags = {"model-version"},
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+          @ApiResponse(responseCode = "404", description = "ModelVersionEntity not found")})
+  public Response getModelVersion(@PathParam(RestConstants.MODEL_VERSION_NAME) String name,
+                                  @PathParam(RestConstants.MODEL_VERSION_VERSION) Integer version) {
+    try {
+      // Produce the 404 that the @ApiResponse above advertises instead of a 200 with no entity.
+      checkModelVersionExists(name, version);
+      ModelVersionEntity modelVersion = modelVersionService.selectWithTag(name, version);
+      return new JsonResponse.Builder<ModelVersionEntity>(Response.Status.OK).success(true)
+          .message("Get the model version instance").result(modelVersion).build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Delete the model version with model version name and version.
+   *
+   * @param name model version's name
+   * @param version model version's version
+   * @return success message, or a 404 response if the model version does not exist
+   */
+  @DELETE
+  @Path("/{name}/{version}")
+  @Operation(summary = "Delete the model version", tags = {"model-version"}, responses = {
+      @ApiResponse(description = "successful operation",
+          content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+      @ApiResponse(responseCode = "404", description = "ModelVersionEntity not found")})
+  public Response deleteModelVersion(@PathParam(RestConstants.MODEL_VERSION_NAME) String name,
+                                     @PathParam(RestConstants.MODEL_VERSION_VERSION) Integer version) {
+    try {
+      // Without this check, deleting a missing version would report success.
+      checkModelVersionExists(name, version);
+      modelVersionService.delete(name, version);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Delete the model version instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Update the model version.
+   *
+   * @param entity model version entity, for example:
+   *     {
+   *       "name": "example_name",
+   *       "version": 1,
+   *       "description": "new_description",
+   *       "currentStage": "production",
+   *       "dataset": "new_dataset"
+   *     }
+   * @return success message
+   */
+  @PATCH
+  @Path("")
+  @Operation(summary = "Update the model version", tags = {"model-version"}, responses = {
+      @ApiResponse(description = "successful operation",
+          content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+      @ApiResponse(responseCode = "404", description = "ModelVersionEntity not found")})
+  public Response updateModelVersion(ModelVersionEntity entity) {
+    try {
+      checkModelVersion(entity);
+      modelVersionService.update(entity);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Update the model version instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Create a model version tag.
+   *
+   * @param name model version's name
+   * @param version model version's version (must parse as an integer greater than 0)
+   * @param tag tag name
+   * @return success message
+   */
+  @POST
+  @Path("/tag")
+  @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+  @Operation(summary = "Create a model version tag instance", tags = { "model-version" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response createModelVersionTag(@DefaultValue("") @QueryParam("name") String name,
+                                        @DefaultValue("") @QueryParam("version") String version,
+                                        @DefaultValue("") @QueryParam("tag") String tag) {
+    try {
+      int versionNum = checkModelVersionTag(name, version, tag);
+      ModelVersionTagEntity modelVersionTag = new ModelVersionTagEntity();
+      modelVersionTag.setName(name);
+      modelVersionTag.setVersion(versionNum);
+      modelVersionTag.setTag(tag);
+      modelVersionTagService.insert(modelVersionTag);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Create a model version tag instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Delete a model version tag.
+   *
+   * @param name model version's name
+   * @param version model version's version (must parse as an integer greater than 0)
+   * @param tag tag name
+   * @return success message
+   */
+  @DELETE
+  @Path("/tag")
+  @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+  @Operation(summary = "Delete a model version tag instance", tags = { "model-version" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response deleteModelVersionTag(@DefaultValue("") @QueryParam("name") String name,
+                                        @DefaultValue("") @QueryParam("version") String version,
+                                        @DefaultValue("") @QueryParam("tag") String tag) {
+    try {
+      int versionNum = checkModelVersionTag(name, version, tag);
+      ModelVersionTagEntity modelVersionTag = new ModelVersionTagEntity();
+      modelVersionTag.setName(name);
+      modelVersionTag.setVersion(versionNum);
+      modelVersionTag.setTag(tag);
+      modelVersionTagService.delete(modelVersionTag);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Delete a model version tag instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseModelVersionServiceException(e);
+    }
+  }
+
+  /**
+   * Convert a service exception into a JSON response using the exception's code and message.
+   */
+  private Response parseModelVersionServiceException(SubmarineRuntimeException e) {
+    return new JsonResponse.Builder<String>(e.getCode()).message(e.getMessage()).build();
+  }
+
+  /**
+   * Check that a model version entity is well-formed and refers to an existing model version.
+   *
+   * <p>NOTE(review): validation failures use HTTP 200 (the error response is built without
+   * success(true)); consider 400 BAD_REQUEST if clients do not depend on 200.
+   *
+   * @param entity model version entity from the request body
+   * @throws SubmarineRuntimeException if the entity is invalid (code 200) or missing (code 404)
+   */
+  private void checkModelVersion(ModelVersionEntity entity) {
+    if (entity == null) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version entity object is null.");
+    }
+    if (entity.getName() == null || entity.getName().equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's name is null.");
+    }
+    if (entity.getVersion() == null) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's version is null.");
+    }
+    ModelVersionEntity modelVersion =
+        modelVersionService.select(entity.getName(), entity.getVersion());
+    if (modelVersion == null) {
+      throw new SubmarineRuntimeException(Response.Status.NOT_FOUND.getStatusCode(),
+          "Invalid. Model version entity with same name and version is not existed.");
+    }
+  }
+
+  /**
+   * Check that model version tag parameters are valid and refer to an existing model version.
+   *
+   * @param name model version's name
+   * @param version model version's version as received from the query string
+   * @param tag tag name
+   * @return the parsed version number, so callers do not have to parse it again
+   * @throws SubmarineRuntimeException if a parameter is invalid (code 200) or the model version
+   *     does not exist (code 404)
+   */
+  private int checkModelVersionTag(String name, String version, String tag) {
+    if (name.equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's name is null.");
+    }
+    if (version.equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's version is null.");
+    }
+    int versionNum;
+    try {
+      versionNum = Integer.parseInt(version);
+    } catch (NumberFormatException e) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's version must be an integer.");
+    }
+    if (versionNum < 1) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Model version's version must be bigger than 0.");
+    }
+    if (tag.equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Tag name is null.");
+    }
+    checkModelVersionExists(name, versionNum);
+    return versionNum;
+  }
+
+  /**
+   * Throw a NOT_FOUND exception unless the model version (name, version) exists.
+   *
+   * @param name model version's name
+   * @param version model version's version
+   */
+  private void checkModelVersionExists(String name, Integer version) {
+    ModelVersionEntity modelVersion = modelVersionService.select(name, version);
+    if (modelVersion == null) {
+      throw new SubmarineRuntimeException(Response.Status.NOT_FOUND.getStatusCode(),
+          "Invalid. Model version " + name + " version " + version + " is not existed.");
+    }
+  }
+}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RegisteredModelRestApi.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RegisteredModelRestApi.java
new file mode 100644
index 0000000..56d0086
--- /dev/null
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RegisteredModelRestApi.java
@@ -0,0 +1,307 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+package org.apache.submarine.server.rest;
+
+import java.util.List;
+import javax.ws.rs.Consumes;
+import javax.ws.rs.DELETE;
+import javax.ws.rs.DefaultValue;
+import javax.ws.rs.GET;
+import javax.ws.rs.PATCH;
+import javax.ws.rs.POST;
+import javax.ws.rs.Path;
+import javax.ws.rs.PathParam;
+import javax.ws.rs.Produces;
+import javax.ws.rs.QueryParam;
+import javax.ws.rs.core.MediaType;
+import javax.ws.rs.core.Response;
+import io.swagger.v3.oas.annotations.Operation;
+import io.swagger.v3.oas.annotations.media.Content;
+import io.swagger.v3.oas.annotations.media.Schema;
+import io.swagger.v3.oas.annotations.responses.ApiResponse;
+
+
+import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
+import
org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
+import
org.apache.submarine.server.model.database.entities.RegisteredModelTagEntity;
+import
org.apache.submarine.server.model.database.service.RegisteredModelService;
+
+import
org.apache.submarine.server.model.database.service.RegisteredModelTagService;
+import org.apache.submarine.server.response.JsonResponse;
+
+
+
+
+/**
+ * Registered model REST API v1.
+ *
+ * <p>Exposes create/list/get/update/delete endpoints for registered models and create/delete
+ * endpoints for registered model tags. Every handler converts a
+ * {@link SubmarineRuntimeException} into a JSON response carrying the exception's code and
+ * message.
+ */
+@Path(RestConstants.V1 + "/" + RestConstants.REGISTERED_MODEL)
+@Produces({ MediaType.APPLICATION_JSON + "; " + RestConstants.CHARSET_UTF8 })
+public class RegisteredModelRestApi {
+
+  /* Registered model service */
+  private final RegisteredModelService registeredModelService = new RegisteredModelService();
+
+  /* Registered model tag service */
+  private final RegisteredModelTagService registeredModelTagService =
+      new RegisteredModelTagService();
+
+  /**
+   * Return the Pong message for testing the connectivity.
+   *
+   * @return Pong message
+   */
+  @GET
+  @Path(RestConstants.PING)
+  @Consumes(MediaType.APPLICATION_JSON)
+  @Operation(summary = "Ping submarine server", tags = { "registered-model" },
+      description = "Return the Pong message for test the connectivity",
+      responses = {
+          @ApiResponse(responseCode = "200", description = "successful operation",
+              content = @Content(schema = @Schema(implementation = String.class))) })
+  public Response ping() {
+    return new JsonResponse.Builder<String>(Response.Status.OK)
+        .success(true).result("Pong").build();
+  }
+
+  /**
+   * Create a registered model.
+   *
+   * @param entity registered model entity, for example:
+   *     {
+   *       "name": "example_name",
+   *       "description": "example_description",
+   *       "tags": ["123", "456"]
+   *     }
+   * @return success message
+   */
+  @POST
+  @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+  @Operation(summary = "Create a registered model instance", tags = { "registered-model" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response createRegisteredModel(RegisteredModelEntity entity) {
+    try {
+      checkRegisteredModel(entity);
+      registeredModelService.insert(entity);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Create a registered model instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * List all registered models.
+   *
+   * @return registered model list
+   */
+  @GET
+  @Operation(summary = "List registered models", tags = { "registered-model" }, responses = {
+      @ApiResponse(description = "successful operation",
+          content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response listRegisteredModels() {
+    try {
+      List<RegisteredModelEntity> registeredModelList = registeredModelService.selectAll();
+      return new JsonResponse.Builder<List<RegisteredModelEntity>>(Response.Status.OK)
+          .success(true)
+          .message("List all registered model instances").result(registeredModelList).build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * Get detailed info about the registered model by registered model name.
+   *
+   * @param name registered model name
+   * @return detailed info about the registered model, or a 404 response if it does not exist
+   */
+  @GET
+  @Path("/{name}")
+  @Operation(summary = "Get detailed info about the registered model",
+      tags = { "registered-model" }, responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+          @ApiResponse(responseCode = "404", description = "RegisteredModelEntity not found") })
+  public Response getRegisteredModel(
+      @PathParam(RestConstants.REGISTERED_MODEL_NAME) String name) {
+    try {
+      // Produce the 404 that the @ApiResponse above advertises instead of a 200 with no entity.
+      checkRegisteredModelExists(name);
+      RegisteredModelEntity registeredModel = registeredModelService.selectWithTag(name);
+      return new JsonResponse.Builder<RegisteredModelEntity>(Response.Status.OK).success(true)
+          .message("Get the registered model instance").result(registeredModel).build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * Update the registered model with registered model name.
+   *
+   * @param name old registered model name
+   * @param entity registered model entity, for example:
+   *     {
+   *       "name": "new_name",
+   *       "description": "new_description"
+   *     }
+   * @return success message
+   */
+  @PATCH
+  @Path("/{name}")
+  @Operation(summary = "Update the registered model", tags = { "registered-model" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+          @ApiResponse(responseCode = "404", description = "RegisteredModelEntity not found") })
+  public Response updateRegisteredModel(
+      @PathParam(RestConstants.REGISTERED_MODEL_NAME) String name, RegisteredModelEntity entity) {
+    try {
+      checkRegisteredModelExists(name);
+      checkRegisteredModel(entity);
+      // Rename first, then update(entity) runs with the new name. This presumes update() is
+      // keyed by name; confirm in RegisteredModelService. NOTE(review): the two calls are not
+      // atomic.
+      if (!name.equals(entity.getName())) {
+        registeredModelService.rename(name, entity.getName());
+      }
+      registeredModelService.update(entity);
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+
+    return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+        .message("Update the registered model instance").build();
+  }
+
+  /**
+   * Delete the registered model with registered model name.
+   *
+   * @param name registered model name
+   * @return success message, or a 404 response if the registered model does not exist
+   */
+  @DELETE
+  @Path("/{name}")
+  @Operation(summary = "Delete the registered model", tags = { "registered-model" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))),
+          @ApiResponse(responseCode = "404", description = "RegisteredModelEntity not found") })
+  public Response deleteRegisteredModel(
+      @PathParam(RestConstants.REGISTERED_MODEL_NAME) String name) {
+    try {
+      // Without this check, deleting a missing model would report success.
+      checkRegisteredModelExists(name);
+      registeredModelService.delete(name);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Delete the registered model instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * Create a registered model tag.
+   *
+   * @param name registered model name
+   * @param tag tag name
+   * @return success message
+   */
+  @POST
+  @Path("/tag")
+  @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+  @Operation(summary = "Create a registered model tag instance", tags = { "registered-model" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response createRegisteredModelTag(@DefaultValue("") @QueryParam("name") String name,
+                                           @DefaultValue("") @QueryParam("tag") String tag) {
+    try {
+      checkRegisteredModelTag(name, tag);
+      RegisteredModelTagEntity registeredModelTag = new RegisteredModelTagEntity();
+      registeredModelTag.setName(name);
+      registeredModelTag.setTag(tag);
+      registeredModelTagService.insert(registeredModelTag);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Create a registered model tag instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * Delete a registered model tag.
+   *
+   * @param name registered model name
+   * @param tag tag name
+   * @return success message
+   */
+  @DELETE
+  @Path("/tag")
+  @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+  @Operation(summary = "Delete a registered model tag instance", tags = { "registered-model" },
+      responses = {
+          @ApiResponse(description = "successful operation",
+              content = @Content(schema = @Schema(implementation = JsonResponse.class))) })
+  public Response deleteRegisteredModelTag(@DefaultValue("") @QueryParam("name") String name,
+                                           @DefaultValue("") @QueryParam("tag") String tag) {
+    try {
+      checkRegisteredModelTag(name, tag);
+      RegisteredModelTagEntity registeredModelTag = new RegisteredModelTagEntity();
+      registeredModelTag.setName(name);
+      registeredModelTag.setTag(tag);
+      registeredModelTagService.delete(registeredModelTag);
+      return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+          .message("Delete a registered model tag instance").build();
+    } catch (SubmarineRuntimeException e) {
+      return parseRegisteredModelServiceException(e);
+    }
+  }
+
+  /**
+   * Convert a service exception into a JSON response using the exception's code and message.
+   */
+  private Response parseRegisteredModelServiceException(SubmarineRuntimeException e) {
+    return new JsonResponse.Builder<String>(e.getCode()).message(e.getMessage()).build();
+  }
+
+  /**
+   * Check that the registered model entity is well-formed.
+   *
+   * <p>NOTE(review): validation failures use HTTP 200 (the error response is built without
+   * success(true)); consider 400 BAD_REQUEST if clients do not depend on 200.
+   *
+   * @param entity registered model entity
+   * @throws SubmarineRuntimeException if the entity or its name is missing
+   */
+  private void checkRegisteredModel(RegisteredModelEntity entity) {
+    if (entity == null) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Registered model entity object is null.");
+    }
+    if (entity.getName() == null || entity.getName().equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Registered model name is null.");
+    }
+  }
+
+  /**
+   * Check that registered model tag parameters are valid and refer to an existing model.
+   *
+   * @param name registered model name
+   * @param tag tag name
+   * @throws SubmarineRuntimeException if a parameter is empty (code 200) or the registered
+   *     model does not exist (code 404)
+   */
+  private void checkRegisteredModelTag(String name, String tag) {
+    if (name.equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Registered model name is null.");
+    }
+    if (tag.equals("")) {
+      throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
+          "Invalid. Tag name is null.");
+    }
+    checkRegisteredModelExists(name);
+  }
+
+  /**
+   * Throw a NOT_FOUND exception unless a registered model with this name exists.
+   *
+   * @param name registered model name
+   */
+  private void checkRegisteredModelExists(String name) {
+    RegisteredModelEntity registeredModel = registeredModelService.select(name);
+    if (registeredModel == null) {
+      throw new SubmarineRuntimeException(Response.Status.NOT_FOUND.getStatusCode(),
+          "Invalid. Registered model " + name + " is not existed.");
+    }
+  }
+}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
index 0d6d3fc..e6555fa 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
@@ -76,4 +76,17 @@ public class RestConstants {
public static final String LOG_DIR_KEY = "SUBMARINE_TENSORBOARD_LOG_DIR";
public static final String LOG_DIR_VALUE = "/logs/mylog";
+ /**
+ * Registered Model.
+ *
+ * <p>Constants for the registered model REST API. NOTE(review): presumably
+ * REGISTERED_MODEL is the URL path segment of the resource and
+ * REGISTERED_MODEL_NAME is the name of the parameter that identifies a model.
+ * Confirm against RegisteredModelRestApi.
+ */
+ public static final String REGISTERED_MODEL = "registered-model";
+ public static final String REGISTERED_MODEL_NAME = "name";
+
+ /**
+ * Model Version.
+ *
+ * <p>Constants for the model version REST API. NOTE(review): presumably
+ * MODEL_VERSION is the URL path segment, and MODEL_VERSION_NAME and
+ * MODEL_VERSION_VERSION name the parameters that identify a version (model
+ * name and version number). Confirm against ModelVersionRestApi.
+ * MODEL_VERSION_NAME has the same value ("name") as REGISTERED_MODEL_NAME.
+ */
+ public static final String MODEL_VERSION = "model-version";
+ public static final String MODEL_VERSION_NAME = "name";
+ public static final String MODEL_VERSION_VERSION = "version";
+
}
diff --git a/submarine-server/server-core/src/main/resources/hibernate.cfg.xml
b/submarine-server/server-core/src/main/resources/hibernate.cfg.xml
index 0682748..318bc96 100644
--- a/submarine-server/server-core/src/main/resources/hibernate.cfg.xml
+++ b/submarine-server/server-core/src/main/resources/hibernate.cfg.xml
@@ -28,7 +28,7 @@
com.mysql.jdbc.Driver
</property>
<property name="hibernate.connection.url">
-
jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false
+
jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false&allowMultiQueries=true
</property>
<property name="hibernate.connection.username">
submarine
diff --git
a/submarine-server/server-core/src/main/resources/mbgConfiguration.xml
b/submarine-server/server-core/src/main/resources/mbgConfiguration.xml
index 888d6fe..2d4487d 100644
--- a/submarine-server/server-core/src/main/resources/mbgConfiguration.xml
+++ b/submarine-server/server-core/src/main/resources/mbgConfiguration.xml
@@ -28,7 +28,7 @@
</commentGenerator>
<jdbcConnection driverClass="com.mysql.jdbc.Driver"
- connectionURL="jdbc:mysql://127.0.0.1:3306/submarine"
+
connectionURL="jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false&allowMultiQueries=true"
userId="submarine"
password="password">
</jdbcConnection>
diff --git
a/submarine-server/server-core/src/main/resources/modelbatis-config.xml
b/submarine-server/server-core/src/main/resources/modelbatis-config.xml
deleted file mode 100644
index 04352c4..0000000
--- a/submarine-server/server-core/src/main/resources/modelbatis-config.xml
+++ /dev/null
@@ -1,60 +0,0 @@
-<?xml version='1.0' encoding='UTF-8' ?>
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-<!DOCTYPE configuration PUBLIC '-//mybatis.org//DTD Config 3.0//EN'
- 'http://mybatis.org/dtd/mybatis-3-config.dtd'>
-<configuration>
- <settings>
- <setting name="cacheEnabled" value="true"/>
- <setting name="lazyLoadingEnabled" value="false"/>
- <setting name="aggressiveLazyLoading" value="true"/>
- <setting name="logImpl" value="STDOUT_LOGGING"/>
- </settings>
-
- <typeAliases>
- <package name="com.github.pagehelper.model"/>
- </typeAliases>
-
- <plugins>
- <plugin interceptor="com.github.pagehelper.PageInterceptor">
- <property name="helperDialect" value="mysql"/>
- <property name="offsetAsPageNum" value="true"/>
- <property name="rowBoundsWithCount" value="true"/>
- </plugin>
- </plugins>
-
- <environments default="development">
- <environment id="development">
- <transactionManager type="JDBC"/>
- <dataSource type="POOLED">
- <property name="driver" value="com.mysql.jdbc.Driver"/>
- <property name="url"
value="jdbc:mysql://127.0.0.1:3306/mlflowdb?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"/>
- <property name="username" value="mlflow"/>
- <property name="password" value="password"/>
- <property name="poolPingQuery" value="SELECT NOW()"/>
- <property name="poolPingEnabled" value="true"/>
- </dataSource>
- </environment>
- </environments>
-
- <mappers>
- <mapper
resource='org/apache/submarine/database/mappers/RegisteredModelNameMapper.xml'/>
- <mapper
resource='org/apache/submarine/database/mappers/ModelVersionMapper.xml'/>
- </mappers>
-</configuration>
diff --git a/submarine-server/server-core/src/main/resources/submarine-site.xml
b/submarine-server/server-core/src/main/resources/submarine-site.xml
index e79cf70..c03bda5 100755
--- a/submarine-server/server-core/src/main/resources/submarine-site.xml
+++ b/submarine-server/server-core/src/main/resources/submarine-site.xml
@@ -113,7 +113,7 @@
</property>
<property>
<name>jdbc.url</name>
-
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false</value>
+
<value>jdbc:mysql://127.0.0.1:3306/submarine?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false&allowMultiQueries=true</value>
</property>
<property>
<name>jdbc.username</name>
diff --git
a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
new file mode 100644
index 0000000..ea682ef
--- /dev/null
+++
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.rest;
+
+import static org.junit.Assert.assertEquals;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+import javax.ws.rs.core.Response;
+import org.apache.submarine.server.api.experiment.ExperimentId;
+import org.apache.submarine.server.gson.ExperimentIdDeserializer;
+import org.apache.submarine.server.gson.ExperimentIdSerializer;
+import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
+import
org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
+import org.apache.submarine.server.model.database.service.ModelVersionService;
+import
org.apache.submarine.server.model.database.service.RegisteredModelService;
+
+
+public class ModelVersionRestApiTest {
+ private ModelVersionRestApi modelVersionRestApi = new ModelVersionRestApi();
+ private final String registeredModelName = "testRegisteredModel";
+ private final String registeredModelDescription = "test registered model
description";
+ private final String modelVersionDescription = "test model version
description";
+ private final String newModelVersionDescription = "new test registered model
description";
+ private final String modelVersionSource = "s3://submarine/test";
+ private final String modelVersionUid = "test123";
+ private final String modelVersionExperimentId = "experiment_123";
+ private final String modelVersionTag = "testTag";
+
+ private final RegisteredModelService registeredModelService = new
RegisteredModelService();
+
+ private final ModelVersionService modelVersionService = new
ModelVersionService();
+
+ private static final GsonBuilder gsonBuilder = new GsonBuilder()
+ .registerTypeAdapter(ExperimentId.class, new ExperimentIdSerializer())
+ .registerTypeAdapter(ExperimentId.class, new ExperimentIdDeserializer());
+ private static Gson gson = gsonBuilder.setDateFormat("yyyy-MM-dd
HH:mm:ss").create();
+
+ private ModelVersionEntity modelVersion1 = new ModelVersionEntity();
+
+ private ModelVersionEntity modelVersion2 = new ModelVersionEntity();
+
+
+ @Before
+ public void createModelVersion() {
+ RegisteredModelEntity registeredModel = new RegisteredModelEntity();
+ registeredModel.setName(registeredModelName);
+ registeredModel.setDescription(registeredModelDescription);
+ registeredModelService.insert(registeredModel);
+ modelVersion1.setName(registeredModelName);
+ modelVersion1.setDescription(modelVersionDescription + "1");
+ modelVersion1.setVersion(1);
+ modelVersion1.setSource(modelVersionSource + "1");
+ modelVersion1.setUserId(modelVersionUid);
+ modelVersion1.setExperimentId(modelVersionExperimentId);
+ modelVersionService.insert(modelVersion1);
+ modelVersion2.setName(registeredModelName);
+ modelVersion2.setDescription(modelVersionDescription + "2");
+ modelVersion2.setVersion(2);
+ modelVersion2.setSource(modelVersionSource + "2");
+ modelVersion2.setUserId(modelVersionUid);
+ modelVersion2.setExperimentId(modelVersionExperimentId);
+ modelVersionService.insert(modelVersion2);
+ }
+
+ @Test
+ public void testListModelVersion(){
+ Response listModelVersionResponse =
modelVersionRestApi.listModelVersions(registeredModelName);
+ List<ModelVersionEntity> result = getResultListFromResponse(
+ listModelVersionResponse, ModelVersionEntity.class);
+ assertEquals(2, result.size());
+ verifyResult(modelVersion1, result.get(0));
+ verifyResult(modelVersion2, result.get(1));
+ }
+
+ @Test
+ public void testGetModelVersion(){
+ Response getModelVersionResponse =
modelVersionRestApi.getModelVersion(registeredModelName, 1);
+ ModelVersionEntity result = getResultFromResponse(getModelVersionResponse,
ModelVersionEntity.class);
+ verifyResult(modelVersion1, result);
+ }
+
+ @Test
+ public void testAddAndDeleteModelVersionTag(){
+ modelVersionRestApi.createModelVersionTag(registeredModelName, "1",
modelVersionTag);
+ Response getModelVersionResponse =
modelVersionRestApi.getModelVersion(registeredModelName, 1);
+ ModelVersionEntity result = getResultFromResponse(
+ getModelVersionResponse, ModelVersionEntity.class);
+ assertEquals(1, result.getTags().size());
+ assertEquals(modelVersionTag, result.getTags().get(0));
+
+ modelVersionRestApi.deleteModelVersionTag(registeredModelName, "1",
modelVersionTag);
+ getModelVersionResponse =
modelVersionRestApi.getModelVersion(registeredModelName, 1);
+ result = getResultFromResponse(
+ getModelVersionResponse , ModelVersionEntity.class);
+ assertEquals(0, result.getTags().size());
+ }
+
+ @Test
+ public void testUpdateModelVersion(){
+ ModelVersionEntity newModelVersion = new ModelVersionEntity();
+ newModelVersion.setName(registeredModelName);
+ newModelVersion.setVersion(1);
+ newModelVersion.setDescription(newModelVersionDescription);
+ modelVersionRestApi.updateModelVersion(newModelVersion);
+ Response getModelVersionResponse =
modelVersionRestApi.getModelVersion(registeredModelName, 1);
+ ModelVersionEntity result = getResultFromResponse(
+ getModelVersionResponse , ModelVersionEntity.class);
+ assertEquals(newModelVersionDescription, result.getDescription());
+ }
+
+ @Test
+ public void testDeleteModelVersion(){
+ modelVersionRestApi.deleteModelVersion(registeredModelName, 1);
+ Response listModelVersionResponse =
modelVersionRestApi.listModelVersions(registeredModelName);
+ List<ModelVersionEntity> result = getResultListFromResponse(
+ listModelVersionResponse, ModelVersionEntity.class);
+ assertEquals(1, result.size());
+ verifyResult(modelVersion2, result.get(0));
+ }
+
+ @After
+ public void tearDown(){
+ registeredModelService.deleteAll();
+ }
+
+ private <T> T getResultFromResponse(Response response, Class<T> typeT) {
+ String entity = (String) response.getEntity();
+ JsonObject object = new JsonParser().parse(entity).getAsJsonObject();
+ JsonElement result = object.get("result");
+ return gson.fromJson(result, typeT);
+ }
+
+ private <T> List<T> getResultListFromResponse(Response response, Class<T>
typeT) {
+ String entity = (String) response.getEntity();
+ JsonObject object = new JsonParser().parse(entity).getAsJsonObject();
+ JsonElement result = object.get("result");
+ List<T> list = new ArrayList<T>();
+ JsonArray array = result.getAsJsonArray();
+ for (JsonElement jsonElement : array) {
+ list.add(gson.fromJson(jsonElement, typeT));
+ }
+ return list;
+ }
+
+ private void verifyResult(ModelVersionEntity result, ModelVersionEntity
actual){
+ assertEquals(result.getName(), actual.getName());
+ assertEquals(result.getDescription(), actual.getDescription());
+ assertEquals(result.getVersion(), actual.getVersion());
+ assertEquals(result.getSource(), actual.getSource());
+ assertEquals(result.getExperimentId(), actual.getExperimentId());
+ }
+}
diff --git
a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/RegisteredModelRestApiTest.java
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/RegisteredModelRestApiTest.java
new file mode 100644
index 0000000..a6cee07
--- /dev/null
+++
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/RegisteredModelRestApiTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.rest;
+
+import static org.junit.Assert.assertEquals;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+import javax.ws.rs.core.Response;
+import org.apache.submarine.server.api.experiment.ExperimentId;
+import org.apache.submarine.server.gson.ExperimentIdDeserializer;
+import org.apache.submarine.server.gson.ExperimentIdSerializer;
+import
org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
+import
org.apache.submarine.server.model.database.service.RegisteredModelService;
+
+public class RegisteredModelRestApiTest {
+ private final RegisteredModelService registeredModelService = new
RegisteredModelService();
+ private final String registeredModelName = "testRegisteredModel";
+ private final String newRegisteredModelName = "newTestRegisteredModel";
+ private final String registeredModelDescription = "test registered model
description";
+ private final String newRegisteredModelDescription = "new test registered
model description";
+ private final String registeredModelTag = "testTag";
+ private final String defaultRegisteredModelTag = "defaultTag";
+ private static final GsonBuilder gsonBuilder = new GsonBuilder()
+ .registerTypeAdapter(ExperimentId.class, new ExperimentIdSerializer())
+ .registerTypeAdapter(ExperimentId.class, new ExperimentIdDeserializer());
+ private static Gson gson = gsonBuilder.setDateFormat("yyyy-MM-dd
HH:mm:ss").create();
+ private RegisteredModelEntity registeredModel = new RegisteredModelEntity();
+
+ private final RegisteredModelRestApi registeredModelRestApi = new
RegisteredModelRestApi();
+
+
+ @Before
+ public void testCreateRegisteredModel() {
+ registeredModel.setName(registeredModelName);
+ registeredModel.setDescription(registeredModelDescription);
+ List<String> tags = new ArrayList<>();
+ tags.add(defaultRegisteredModelTag);
+ registeredModel.setTags(tags);
+ registeredModelService.insert(registeredModel);
+ }
+
+ @Test
+ public void testListRegisteredModel(){
+ Response listRegisteredModelResponse =
registeredModelRestApi.listRegisteredModels();
+ List<RegisteredModelEntity> result = getResultListFromResponse(
+ listRegisteredModelResponse, RegisteredModelEntity.class);
+ assertEquals(1, result.size());
+ verifyResult(registeredModel, result.get(0));
+ }
+
+ @Test
+ public void testGetModelRegisteredModel(){
+ Response getRegisteredModelResponse =
registeredModelRestApi.getRegisteredModel(registeredModelName);
+ RegisteredModelEntity result = getResultFromResponse(
+ getRegisteredModelResponse, RegisteredModelEntity.class);
+ verifyResult(registeredModel, result);
+ }
+
+ @Test
+ public void testAddAndDeleteRegisteredModelTag(){
+ registeredModelRestApi.deleteRegisteredModelTag(registeredModelName,
defaultRegisteredModelTag);
+ Response getRegisteredModelResponse =
registeredModelRestApi.getRegisteredModel(registeredModelName);
+ RegisteredModelEntity result = getResultFromResponse(
+ getRegisteredModelResponse , RegisteredModelEntity.class);
+ assertEquals(0, result.getTags().size());
+
+ registeredModelRestApi.createRegisteredModelTag(registeredModelName,
registeredModelTag);
+ getRegisteredModelResponse =
registeredModelRestApi.getRegisteredModel(registeredModelName);
+ result = getResultFromResponse(
+ getRegisteredModelResponse, RegisteredModelEntity.class);
+ assertEquals(1, result.getTags().size());
+ assertEquals(registeredModelTag, result.getTags().get(0));
+
+
+ }
+
+ @Test
+ public void testUpdateRegisteredModel(){
+ RegisteredModelEntity newRegisteredModel = new RegisteredModelEntity();
+ newRegisteredModel.setName(newRegisteredModelName);
+ newRegisteredModel.setDescription(newRegisteredModelDescription);
+ List<String> tags = new ArrayList<>();
+ tags.add(defaultRegisteredModelTag);
+ newRegisteredModel.setTags(tags);
+ registeredModelRestApi.updateRegisteredModel(registeredModelName,
newRegisteredModel);
+ Response getRegisteredModelResponse =
registeredModelRestApi.getRegisteredModel(newRegisteredModelName);
+ RegisteredModelEntity result = getResultFromResponse(
+ getRegisteredModelResponse , RegisteredModelEntity.class);
+ verifyResult(newRegisteredModel, result);
+ }
+
+ @Test
+ public void testDeleteRegisteredModel(){
+ registeredModelRestApi.deleteRegisteredModel(registeredModelName);
+ Response listRegisteredModelResponse =
registeredModelRestApi.listRegisteredModels();
+ List<RegisteredModelEntity> result = getResultListFromResponse(
+ listRegisteredModelResponse, RegisteredModelEntity.class);
+ assertEquals(0, result.size());
+ }
+
+ @After
+ public void tearDown(){
+ registeredModelService.deleteAll();
+ }
+
+ private <T> T getResultFromResponse(Response response, Class<T> typeT) {
+ String entity = (String) response.getEntity();
+ JsonObject object = new JsonParser().parse(entity).getAsJsonObject();
+ JsonElement result = object.get("result");
+ return gson.fromJson(result, typeT);
+ }
+
+ private <T> List<T> getResultListFromResponse(Response response, Class<T>
typeT) {
+ String entity = (String) response.getEntity();
+ JsonObject object = new JsonParser().parse(entity).getAsJsonObject();
+ JsonElement result = object.get("result");
+ List<T> list = new ArrayList<T>();
+ JsonArray array = result.getAsJsonArray();
+ for (JsonElement jsonElement : array) {
+ list.add(gson.fromJson(jsonElement, typeT));
+ }
+ return list;
+ }
+
+ private void verifyResult(RegisteredModelEntity result,
RegisteredModelEntity actual){
+ assertEquals(result.getName(), actual.getName());
+ assertEquals(result.getDescription(), actual.getDescription());
+ for ( int i = 0; i < result.getTags().size(); i++ ){
+ assertEquals(result.getTags().get(i), actual.getTags().get(i));
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]