This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 0c2b915 SUBMARINE-1106. refactor CRUD operations of TFJob and
PyTorchJob
0c2b915 is described below
commit 0c2b915393fe01424c551007de4622ca4bb76d28
Author: FatalLin <[email protected]>
AuthorDate: Thu Mar 3 11:46:40 2022 +0800
SUBMARINE-1106. refactor CRUD operations of TFJob and PyTorchJob
### What is this PR for?
Following the earlier refactoring of the notebook CRUD operations, this PR also
refactors the CRUD operations of TFJob and PyTorchJob to go through the agent.
### What type of PR is it?
Feature
### Todos
N/A
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1106
### How should this be tested?
This PR should pass all existing test cases.
### Screenshots (if appropriate)

### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? No
Author: FatalLin <[email protected]>
Signed-off-by: Kevin <[email protected]>
Closes #893 from FatalLin/SUBMARINE-1106 and squashes the following commits:
d410e7fc [FatalLin] fix
85399ffa [FatalLin] using datetime compare function to avoid timezone issue
0c146dc7 [FatalLin] fix
bbd0a143 [FatalLin] rollback path
c227a08e [FatalLin] add uid to entity
8d72dc7a [FatalLin] fix
b016dcb4 [FatalLin] add apache header
5f4fd2ae [FatalLin] refactor curd operation of TFJob and PyTorchJob
913436a2 [FatalLin] merge
39a4b0c6 [FatalLin] refactor tfjob
---
dev-support/database/submarine.sql | 4 +
.../server/api/common/CustomResourceType.java | 2 +-
.../server/api/experiment/Experiment.java | 8 ++
.../server/experiment/ExperimentManager.java | 82 ++++++++++------
.../database/entity/ExperimentEntity.java | 52 ++++++++++
.../server/internal/InternalServiceManager.java | 27 ++++-
.../server/rest/InternalServiceRestApi.java | 3 +-
.../database/mappers/ExperimentMapper.xml | 16 ++-
.../server/experiment/ExperimentManagerTest.java | 47 +++++++--
.../submarine/server/k8s/agent/SubmarineAgent.java | 1 -
.../server/k8s/agent/handler/NotebookHandler.java | 6 +-
.../k8s/agent/handler/PyTorchJobHandler.java | 109 +++++++++++++++++++++
.../server/k8s/agent/handler/TFJobHandler.java | 108 ++++++++++++++++++++
.../server/submitter/k8s/K8sSubmitter.java | 15 ++-
.../server/submitter/k8s/model/AgentPod.java | 95 ++++++++++++++++++
.../k8s/model/pytorchjob/PyTorchJobSpec.java | 3 +-
.../server/submitter/k8s/model/tfjob/TFJob.java | 5 +-
.../submitter/k8s/model/tfjob/TFJobSpec.java | 3 +-
.../submitter/k8s/parser/ExperimentSpecParser.java | 16 ++-
19 files changed, 542 insertions(+), 60 deletions(-)
diff --git a/dev-support/database/submarine.sql
b/dev-support/database/submarine.sql
index 41f3446..44ad116 100644
--- a/dev-support/database/submarine.sql
+++ b/dev-support/database/submarine.sql
@@ -248,6 +248,10 @@ CREATE TABLE `experiment` (
`update_by` varchar(32) DEFAULT NULL COMMENT 'last update user',
`update_time` datetime DEFAULT NULL COMMENT 'last update time',
`experiment_status` varchar(20) DEFAULT NULL COMMENT 'experiment status',
+ `accepted_time` datetime DEFAULT NULL COMMENT 'accept time',
+ `running_time` datetime DEFAULT NULL COMMENT 'running time',
+ `finished_time` datetime DEFAULT NULL COMMENT 'finished time',
+ `uid` varchar(64) DEFAULT NULL COMMENT 'uid of experiment',
PRIMARY KEY `id` (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
diff --git
a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
index 634291e..e2873b9 100644
---
a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
+++
b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
@@ -20,7 +20,7 @@
package org.apache.submarine.server.api.common;
public enum CustomResourceType {
- TFJob("TFJob"), PYTORCHJob("PYTORCHJob"), Notebook("Notebook");
+ TFJob("TFJob"), PyTorchJob("PyTorchJob"), Notebook("Notebook");
private String customResourceType;
CustomResourceType(String customResourceType) {
diff --git
a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/experiment/Experiment.java
b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/experiment/Experiment.java
index 74c8786..17991fc 100644
---
a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/experiment/Experiment.java
+++
b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/experiment/Experiment.java
@@ -157,15 +157,23 @@ public class Experiment {
}
if (experiment.getAcceptedTime() != null) {
this.setAcceptedTime(experiment.getAcceptedTime());
+ } else {
+ this.setAcceptedTime(null);
}
if (experiment.getCreatedTime() != null) {
this.setCreatedTime(experiment.getCreatedTime());
+ } else {
+ this.setCreatedTime(null);
}
if (experiment.getRunningTime() != null) {
this.setRunningTime(experiment.getRunningTime());
+ } else {
+ this.setRunningTime(null);
}
if (experiment.getFinishedTime() != null) {
this.setFinishedTime(experiment.getFinishedTime());
+ } else {
+ this.setFinishedTime(null);
}
}
}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/ExperimentManager.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/ExperimentManager.java
index e0764d1..45d8eb0 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/ExperimentManager.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/ExperimentManager.java
@@ -45,6 +45,7 @@ import org.apache.submarine.server.api.spec.ExperimentSpec;
import org.apache.submarine.server.experiment.database.entity.ExperimentEntity;
import
org.apache.submarine.server.experiment.database.service.ExperimentService;
import org.apache.submarine.server.rest.RestConstants;
+import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.mlflow.tracking.MlflowClient;
@@ -126,8 +127,8 @@ public class ExperimentManager {
experiment.setSpec(spec);
ExperimentEntity entity = buildEntityFromExperiment(experiment);
+ entity.setExperimentStatus(Experiment.Status.STATUS_ACCEPTED.toString());
experimentService.insert(entity);
-
return experiment;
}
@@ -143,8 +144,6 @@ public class ExperimentManager {
ExperimentEntity entity = experimentService.select(id);
Experiment experiment = buildExperimentFromEntity(entity);
- Experiment foundExperiment =
submitter.findExperiment(experiment.getSpec());
- experiment.rebuild(foundExperiment);
return experiment;
}
@@ -162,19 +161,7 @@ public class ExperimentManager {
for (ExperimentEntity entity : entities) {
Experiment experiment = buildExperimentFromEntity(entity);
- Experiment foundExperiment;
- try {
- foundExperiment = submitter.findExperiment(experiment.getSpec());
- } catch (SubmarineRuntimeException e) {
- LOG.warn("Submitter can not find experiment: {}, will delete it",
entity.getId());
- experimentService.delete(entity.getId());
- continue;
- }
- LOG.info("Found experiment: {}", foundExperiment.getStatus());
- if (status == null ||
status.toLowerCase().equals(foundExperiment.getStatus().toLowerCase())) {
- experiment.rebuild(foundExperiment);
- experimentList.add(experiment);
- }
+ experimentList.add(experiment);
}
LOG.info("List experiment: {}", experimentList.size());
return experimentList;
@@ -193,22 +180,11 @@ public class ExperimentManager {
for (ExperimentEntity entity : entities) {
Experiment experiment = buildExperimentFromEntity(entity);
- Experiment foundExperiment;
- try {
- foundExperiment = submitter.findExperiment(experiment.getSpec());
- } catch (SubmarineRuntimeException e) {
- LOG.warn("Submitter can not find experiment: {}, will delete it",
entity.getId());
- experimentService.delete(entity.getId());
- continue;
- }
- LOG.info("Found experiment: {}",
foundExperiment.getSpec().getMeta().getTags());
if (searchTag == null) {
- experiment.rebuild(foundExperiment);
experimentList.add(experiment);
} else {
for (String tag: experiment.getSpec().getMeta().getTags()) {
if (tag.equalsIgnoreCase(searchTag)) {
- experiment.rebuild(foundExperiment);
experimentList.add(experiment);
break;
}
@@ -291,12 +267,8 @@ public class ExperimentManager {
for (ExperimentEntity entity : entities) {
Experiment experiment = buildExperimentFromEntity(entity);
- Experiment foundExperiment =
submitter.findExperiment(experiment.getSpec());
-
- LOG.info("Found experiment: {}", foundExperiment.getStatus());
- if (status == null ||
status.toLowerCase().equals(foundExperiment.getStatus().toLowerCase())) {
- experiment.rebuild(foundExperiment);
+ if (status == null ||
status.toLowerCase().equals(experiment.getStatus().toLowerCase())) {
experimentLogList.add(submitter.getExperimentLogName(
experiment.getSpec(),
@@ -381,8 +353,33 @@ public class ExperimentManager {
*/
private Experiment buildExperimentFromEntity(ExperimentEntity entity) {
Experiment experiment = new Experiment();
+
experiment.setExperimentId(ExperimentId.fromString(entity.getId()));
experiment.setSpec(new Gson().fromJson(entity.getExperimentSpec(),
ExperimentSpec.class));
+ experiment.setStatus(entity.getExperimentStatus());
+
+ if (entity.getCreateTime() != null) {
+ experiment.setCreatedTime(new
DateTime(entity.getCreateTime()).toString());
+ } else {
+ experiment.setCreatedTime(null);
+ }
+ if (entity.getAcceptedTime() != null) {
+ experiment.setAcceptedTime(new
DateTime(entity.getAcceptedTime()).toString());
+ } else {
+ experiment.setAcceptedTime(null);
+ }
+ if (entity.getRunningTime() != null) {
+ experiment.setRunningTime(new
DateTime(entity.getRunningTime()).toString());
+ } else {
+ experiment.setRunningTime(null);
+ }
+ if (entity.getFinishedTime() != null) {
+ experiment.setFinishedTime(new
DateTime(entity.getFinishedTime()).toString());
+ } else {
+ experiment.setFinishedTime(null);
+ }
+ experiment.setUid(entity.getUid());
+
return experiment;
}
@@ -396,6 +393,27 @@ public class ExperimentManager {
ExperimentEntity entity = new ExperimentEntity();
entity.setId(experiment.getSpec().getMeta().getExperimentId());
entity.setExperimentSpec(new
GsonBuilder().disableHtmlEscaping().create().toJson(experiment.getSpec()));
+ if (experiment.getCreatedTime() != null) {
+
entity.setCreateTime(DateTime.parse(experiment.getCreatedTime()).toDate());
+ } else {
+ entity.setCreateTime(null);
+ }
+ if (experiment.getAcceptedTime() != null) {
+
entity.setAcceptedTime(DateTime.parse(experiment.getAcceptedTime()).toDate());
+ } else {
+ entity.setAcceptedTime(null);
+ }
+ if (experiment.getRunningTime() != null) {
+
entity.setRunningTime(DateTime.parse(experiment.getRunningTime()).toDate());
+ } else {
+ entity.setRunningTime(null);
+ }
+ if (experiment.getFinishedTime() != null) {
+
entity.setFinishedTime(DateTime.parse(experiment.getFinishedTime()).toDate());
+ } else {
+ entity.setFinishedTime(null);
+ }
+ entity.setUid(experiment.getUid());
return entity;
}
}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/database/entity/ExperimentEntity.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/database/entity/ExperimentEntity.java
index 3bad368..197ed01 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/database/entity/ExperimentEntity.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/experiment/database/entity/ExperimentEntity.java
@@ -19,7 +19,12 @@
package org.apache.submarine.server.experiment.database.entity;
+import java.util.Date;
+
import org.apache.submarine.server.database.entity.BaseEntity;
+import
org.apache.submarine.server.workbench.database.utils.CustomJsonDateDeserializer;
+
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
public class ExperimentEntity extends BaseEntity {
/*
@@ -29,6 +34,17 @@ public class ExperimentEntity extends BaseEntity {
private String experimentStatus;
+ @JsonDeserialize(using = CustomJsonDateDeserializer.class)
+ private Date acceptedTime;
+
+ @JsonDeserialize(using = CustomJsonDateDeserializer.class)
+ private Date runningTime;
+
+ @JsonDeserialize(using = CustomJsonDateDeserializer.class)
+ private Date finishedTime;
+
+ private String uid;
+
public ExperimentEntity() {}
public String getExperimentSpec() {
@@ -46,6 +62,38 @@ public class ExperimentEntity extends BaseEntity {
public void setExperimentStatus(String experimentStatus) {
this.experimentStatus = experimentStatus;
}
+
+ public Date getAcceptedTime() {
+ return acceptedTime;
+ }
+
+ public void setAcceptedTime(Date acceptedTime) {
+ this.acceptedTime = acceptedTime;
+ }
+
+ public Date getRunningTime() {
+ return runningTime;
+ }
+
+ public void setRunningTime(Date runningTime) {
+ this.runningTime = runningTime;
+ }
+
+ public Date getFinishedTime() {
+ return finishedTime;
+ }
+
+ public void setFinishedTime(Date finishedTime) {
+ this.finishedTime = finishedTime;
+ }
+
+ public String getUid() {
+ return uid;
+ }
+
+ public void setUid(String uid) {
+ this.uid = uid;
+ }
@Override
public String toString() {
@@ -57,6 +105,10 @@ public class ExperimentEntity extends BaseEntity {
", updateBy='" + updateBy + '\'' +
", updateTime='" + updateTime + '\'' +
", experimentStatus='" + experimentStatus + "\'" +
+ ", acceptedTime='" + acceptedTime + '\'' +
+ ", runningTime='" + runningTime + '\'' +
+ ", finishedTime='" + finishedTime + '\'' +
+ ", uid='" + uid + '\'' +
'}';
}
}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
index 7451fce..c4ec202 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
@@ -28,6 +28,7 @@ import org.apache.submarine.server.api.notebook.Notebook;
import org.apache.submarine.server.experiment.database.entity.ExperimentEntity;
import
org.apache.submarine.server.experiment.database.service.ExperimentService;
import org.apache.submarine.server.notebook.database.service.NotebookService;
+import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -56,8 +57,8 @@ public class InternalServiceManager {
Map<String, Object> updateObject) {
if (crType.equals(CustomResourceType.Notebook)) {
return updateNotebookStatus(resourceId, updateObject);
- } else if (crType.equals(CustomResourceType.TFJob) ||
crType.equals(CustomResourceType.PYTORCHJob)) {
- return updateExperimentStatus(resourceId, null);
+ } else if (crType.equals(CustomResourceType.TFJob) ||
crType.equals(CustomResourceType.PyTorchJob)) {
+ return updateExperimentStatus(resourceId, updateObject);
}
return false;
}
@@ -68,7 +69,27 @@ public class InternalServiceManager {
throw new SubmarineRuntimeException(Status.NOT_FOUND.getStatusCode(),
String.format("cannot find experiment with id:%s", resourceId));
}
- // experimentEntity.setExperimentStatus(status);
+
+ if (updateObject.get("status") != null) {
+
experimentEntity.setExperimentStatus(updateObject.get("status").toString());
+ }
+ if (updateObject.get("acceptedTime") != null) {
+ experimentEntity.setAcceptedTime(
+
DateTime.parse(updateObject.get("acceptedTime").toString()).toDate());
+ }
+ if (updateObject.get("createdTime") != null) {
+ experimentEntity.setCreateTime(
+
DateTime.parse(updateObject.get("createdTime").toString()).toDate());
+ }
+ if (updateObject.get("runningTime") != null) {
+ experimentEntity.setRunningTime(
+
DateTime.parse(updateObject.get("runningTime").toString()).toDate());
+ }
+ if (updateObject.get("finishedTime") != null) {
+ experimentEntity.setFinishedTime(
+
DateTime.parse(updateObject.get("finishedTime").toString()).toDate());
+ }
+
return experimentService.update(experimentEntity);
}
diff --git
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/InternalServiceRestApi.java
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/InternalServiceRestApi.java
index d546ee1..07608bd 100644
---
a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/InternalServiceRestApi.java
+++
b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/InternalServiceRestApi.java
@@ -36,7 +36,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.annotations.VisibleForTesting;
-
+import com.google.gson.Gson;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
@@ -82,6 +82,7 @@ public class InternalServiceRestApi {
@PathParam(RestConstants.CUSTOM_RESOURCE_ID) String resourceId,
Map<String, Object> updatedCustomObject) {
try {
+ LOG.info("In:" + new Gson().toJson(updatedCustomObject));
internalServiceManager.updateCRStatus(CustomResourceType.valueOf(type)
, resourceId, updatedCustomObject);
return new JsonResponse.Builder<String>(Response.Status.OK)
diff --git
a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ExperimentMapper.xml
b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ExperimentMapper.xml
index d455430..17996f5 100644
---
a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ExperimentMapper.xml
+++
b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ExperimentMapper.xml
@@ -30,10 +30,14 @@
<resultMap id="ExperimentEntityResultMap"
type="org.apache.submarine.server.experiment.database.entity.ExperimentEntity"
extends="BaseEntityResultMap">
<result column="experiment_spec" jdbcType="VARCHAR"
property="experimentSpec" />
<result column="experiment_status" property="experimentStatus"/>
+ <result column="accepted_time" property="acceptedTime"/>
+ <result column="running_time" property="runningTime"/>
+ <result column="finished_time" property="finishedTime"/>
+ <result column="uid" property="uid"/>
</resultMap>
<sql id="Base_Column_List">
- id, experiment_spec, create_by, create_time, update_by, update_time,
experiment_status
+ id, experiment_spec, create_by, create_time, update_by, update_time,
experiment_status, accepted_time, running_time, finished_time, uid
</sql>
<select id="selectAll" parameterType="java.lang.String"
resultMap="ExperimentEntityResultMap">
@@ -55,17 +59,23 @@
</delete>
<insert id="insert"
parameterType="org.apache.submarine.server.experiment.database.entity.ExperimentEntity">
- insert into experiment (id, experiment_spec, create_by, create_time,
update_by, update_time, experiment_status)
+ insert into experiment (id, experiment_spec, create_by, create_time,
update_by, update_time, experiment_status, accepted_time, running_time,
finished_time, uid)
values (#{id,jdbcType=VARCHAR}, #{experimentSpec,jdbcType=VARCHAR},
- #{createBy,jdbcType=VARCHAR}, now(), #{updateBy,jdbcType=VARCHAR},
now(), #{experimentStatus,jdbcType=VARCHAR})
+ #{createBy,jdbcType=VARCHAR}, now(), #{updateBy,jdbcType=VARCHAR},
now(), #{experimentStatus,jdbcType=VARCHAR},
+ #{acceptedTime,jdbcType=TIMESTAMP},
#{runningTime,jdbcType=TIMESTAMP}, #{finishedTime,jdbcType=TIMESTAMP},
#{uid,jdbcType=VARCHAR})
</insert>
<update id="update"
parameterType="org.apache.submarine.server.experiment.database.entity.ExperimentEntity">
update experiment
<set>
<if test="experimentSpec != null and experimentStatus != null">
+ create_time = #{createTime, jdbcType=TIMESTAMP},
experiment_spec = #{experimentSpec,jdbcType=VARCHAR},
experiment_status = #{experimentStatus, jdbcType=VARCHAR},
+ accepted_time = #{acceptedTime,jdbcType=TIMESTAMP},
+ running_time = #{runningTime,jdbcType=TIMESTAMP},
+ finished_time = #{finishedTime,jdbcType=TIMESTAMP},
+ uid = #{uid,jdbcType=VARCHAR},
</if>
update_time = now()
</set>
diff --git
a/submarine-server/server-core/src/test/java/org/apache/submarine/server/experiment/ExperimentManagerTest.java
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/experiment/ExperimentManagerTest.java
index ff97662..601f6c5 100644
---
a/submarine-server/server-core/src/test/java/org/apache/submarine/server/experiment/ExperimentManagerTest.java
+++
b/submarine-server/server-core/src/test/java/org/apache/submarine/server/experiment/ExperimentManagerTest.java
@@ -30,6 +30,7 @@ import
org.apache.submarine.server.api.experiment.ExperimentId;
import org.apache.submarine.server.api.spec.ExperimentSpec;
import org.apache.submarine.server.experiment.database.entity.ExperimentEntity;
import
org.apache.submarine.server.experiment.database.service.ExperimentService;
+import org.joda.time.DateTime;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
@@ -45,6 +46,7 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@@ -153,7 +155,29 @@ public class ExperimentManagerTest {
ExperimentEntity entity = new ExperimentEntity();
entity.setExperimentSpec(toJson(spec));
entity.setId(experimentId.toString());
-
+ entity.setUid(result.getUid());
+ if (result.getCreatedTime() != null) {
+ entity.setCreateTime(DateTime.parse(result.getCreatedTime()).toDate());
+ } else {
+ entity.setCreateTime(null);
+ }
+ if (result.getAcceptedTime() != null) {
+
entity.setAcceptedTime(DateTime.parse(result.getAcceptedTime()).toDate());
+ } else {
+ entity.setAcceptedTime(null);
+ }
+ if (result.getRunningTime() != null) {
+ entity.setRunningTime(DateTime.parse(result.getRunningTime()).toDate());
+ } else {
+ entity.setRunningTime(null);
+ }
+ if (result.getFinishedTime() != null) {
+
entity.setFinishedTime(DateTime.parse(result.getFinishedTime()).toDate());
+ } else {
+ entity.setFinishedTime(null);
+ }
+ entity.setExperimentStatus(result.getStatus());
+
// Construct expected result
Experiment expectedExperiment = new Experiment();
expectedExperiment.setSpec(spec);
@@ -225,8 +249,6 @@ public class ExperimentManagerTest {
expectedExperiment.setSpec(spec);
expectedExperiment.setExperimentId(experimentId);
expectedExperiment.rebuild(status);
-
-
// Stub service select
// Pretend there is a entity in db
when(mockService.select(any(String.class))).thenReturn(entity);
@@ -257,12 +279,12 @@ public class ExperimentManagerTest {
private void verifyResult(Experiment expected, Experiment actual) {
assertEquals(expected.getUid(), actual.getUid());
- assertEquals(expected.getCreatedTime(), actual.getCreatedTime());
- assertEquals(expected.getRunningTime(), actual.getRunningTime());
- assertEquals(expected.getAcceptedTime(), actual.getAcceptedTime());
+ verifyTimeResult(expected.getCreatedTime(), actual.getCreatedTime());
+ verifyTimeResult(expected.getRunningTime(), actual.getRunningTime());
+ verifyTimeResult(expected.getAcceptedTime(), actual.getAcceptedTime());
assertEquals(expected.getStatus(), actual.getStatus());
assertEquals(expected.getExperimentId(), actual.getExperimentId());
- assertEquals(expected.getFinishedTime(), actual.getFinishedTime());
+ verifyTimeResult(expected.getFinishedTime(), actual.getFinishedTime());
assertEquals(expected.getSpec().getMeta().getName(),
actual.getSpec().getMeta().getName());
assertEquals(expected.getSpec().getMeta().getFramework(),
actual.getSpec().getMeta().getFramework());
assertEquals(expected.getSpec().getMeta().getNamespace(),
actual.getSpec().getMeta().getNamespace());
@@ -277,6 +299,17 @@ public class ExperimentManagerTest {
;
}
+ private void verifyTimeResult(String expected, String actual) {
+ if ((expected == null && actual == null) || ((expected != null && actual
== null) ||
+ (expected == null && actual != null))) {
+ assertEquals(expected, actual);
+ } else {
+ DateTime expectedTime = DateTime.parse(expected);
+ DateTime actualTime = DateTime.parse(actual);
+ assertTrue(expectedTime.isEqual(actualTime));
+ }
+ }
+
private Object buildFromJsonFile(Object obj, String filePath) throws
SubmarineException {
Gson gson = new GsonBuilder().create();
try (Reader reader =
Files.newBufferedReader(getCustomJobSpecFile(filePath).toPath(),
diff --git
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/SubmarineAgent.java
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/SubmarineAgent.java
index 148ef05..7e1f4f6 100644
---
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/SubmarineAgent.java
+++
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/SubmarineAgent.java
@@ -79,7 +79,6 @@ public class SubmarineAgent {
customResourceType, customResourceName, customResourceId);
agent.start();
-
}
}
diff --git
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/NotebookHandler.java
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/NotebookHandler.java
index 6d43a5a..fa4fb51 100644
---
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/NotebookHandler.java
+++
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/NotebookHandler.java
@@ -32,8 +32,6 @@ import
org.apache.submarine.server.submitter.k8s.model.NotebookCR;
import org.apache.submarine.server.submitter.k8s.model.NotebookCRList;
import org.apache.submarine.server.submitter.k8s.util.NotebookUtils;
import io.kubernetes.client.util.generic.GenericKubernetesApi;
-
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -53,6 +51,7 @@ public class NotebookHandler extends CustomResourceHandler {
private GenericKubernetesApi<NotebookCR, NotebookCRList> notebookCRClient;
private String uid;
+
public NotebookHandler() throws IOException {
super();
}
@@ -88,8 +87,10 @@ public class NotebookHandler extends CustomResourceHandler {
this.uid = podList.getItems().get(podList.getItems().size() -
1).getMetadata().getUid();
+
listOptions = new ListOptions();
String fieldSelector = String.format("involvedObject.uid=%s", this.uid);
+
listOptions.setFieldSelector(fieldSelector);
watcher = eventClient.watch(namespace, listOptions);
@@ -105,6 +106,7 @@ public class NotebookHandler extends CustomResourceHandler {
while (true) {
for (Response<CoreV1Event> event: watcher) {
String reason = event.object.getReason();
+
Object object = null;
try {
switch (reason) {
diff --git
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/PyTorchJobHandler.java
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/PyTorchJobHandler.java
new file mode 100644
index 0000000..6e744b5
--- /dev/null
+++
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/PyTorchJobHandler.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.k8s.agent.handler;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.submarine.server.api.common.CustomResourceType;
+import org.apache.submarine.server.api.experiment.Experiment;
+import org.apache.submarine.server.k8s.agent.util.RestClient;
+import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
+import
org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobList;
+import org.apache.submarine.server.submitter.k8s.util.MLJobConverter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+
+import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.CoreV1Event;
+import io.kubernetes.client.openapi.models.V1JobCondition;
+import io.kubernetes.client.util.Watch.Response;
+import io.kubernetes.client.util.Watch;
+import io.kubernetes.client.util.Watchable;
+import io.kubernetes.client.util.generic.GenericKubernetesApi;
+import okhttp3.Call;
+
+public class PyTorchJobHandler extends CustomResourceHandler {
+ private static final Logger LOG =
LoggerFactory.getLogger(PyTorchJobHandler.class);
+ private GenericKubernetesApi<PyTorchJob, PyTorchJobList> pytorchJobClient;
+ private Watchable<CoreV1Event> watcher;
+ public PyTorchJobHandler() throws IOException {
+ super();
+ }
+
+
+ @Override
+ public void init(String serverHost, Integer serverPort,
+ String namespace, String crName, String resourceId) {
+ this.serverHost = serverHost;
+ this.serverPort = serverPort;
+ this.namespace = namespace;
+ this.crName = crName;
+ this.resourceId = resourceId;
+ pytorchJobClient =
+ new GenericKubernetesApi<>(
+ PyTorchJob.class, PyTorchJobList.class,
+ PyTorchJob.CRD_PYTORCH_GROUP_V1,
PyTorchJob.CRD_PYTORCH_VERSION_V1,
+ PyTorchJob.CRD_PYTORCH_PLURAL_V1, client);
+ try {
+ String fieldSelector = String.format("involvedObject.name=%s",
resourceId);
+ LOG.info("fieldSelector:" + fieldSelector);
+ Call call = coreV1Api.listNamespacedEventCall(namespace, null, null,
null, fieldSelector,
+ null, null, null, null, null, true, null);
+
+ watcher = Watch.createWatch(client, call, new
TypeToken<Response<CoreV1Event>>(){}.getType());
+ } catch (ApiException e) {
+ e.printStackTrace();
+ }
+ restClient = new RestClient(serverHost, serverPort);
+ }
+
+ @Override
+ public void run() {
+ Gson gson = new Gson();
+ while (true) {
+ for (Response<CoreV1Event> event: watcher) {
+ PyTorchJob job = pytorchJobClient.get(this.namespace,
this.resourceId).getObject();
+ List<V1JobCondition> conditionList = job.getStatus().getConditions();
+ V1JobCondition lastCondition = conditionList.get(conditionList.size()
- 1);
+ Experiment experiment = MLJobConverter.toJobFromMLJob(job);
+
+ this.restClient.callStatusUpdate(CustomResourceType.PyTorchJob,
resourceId, experiment);
+ LOG.info(String.format("receiving condition:%s",
lastCondition.getReason()));
+ LOG.info(String.format("current status of PyTorchjob:%s is %s",
resourceId, experiment.getStatus()));
+
+ switch (lastCondition.getReason()) {
+ case "PyTorchJobSucceeded":
+ LOG.info(String.format("PyTorchjob:%s is succeeded, exit",
this.resourceId));
+ return;
+ case "PyTorchJobFailed":
+ LOG.info(String.format("PyTorchjob:%s is failed, exit",
this.resourceId));
+ return;
+ default:
+ break;
+ }
+
+ }
+ }
+ }
+}
diff --git
a/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/TFJobHandler.java
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/TFJobHandler.java
new file mode 100644
index 0000000..13b157f
--- /dev/null
+++
b/submarine-server/server-submitter/submarine-k8s-agent/src/main/java/org/apache/submarine/server/k8s/agent/handler/TFJobHandler.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.k8s.agent.handler;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.submarine.server.api.common.CustomResourceType;
+import org.apache.submarine.server.api.experiment.Experiment;
+import org.apache.submarine.server.k8s.agent.util.RestClient;
+import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
+import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobList;
+import org.apache.submarine.server.submitter.k8s.util.MLJobConverter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.reflect.TypeToken;
+
+import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.CoreV1Event;
+import io.kubernetes.client.openapi.models.V1JobCondition;
+import io.kubernetes.client.util.Watch.Response;
+import io.kubernetes.client.util.Watch;
+import io.kubernetes.client.util.Watchable;
+import io.kubernetes.client.util.generic.GenericKubernetesApi;
+import okhttp3.Call;
+
+/**
+ * Handler that follows a single TFJob: it watches the Kubernetes events whose
+ * involved object is named after {@code resourceId}, re-reads the TFJob on every
+ * event and pushes the converted {@link Experiment} to the server through
+ * {@link RestClient}. The handler stops once the job reports success or failure.
+ */
+public class TFJobHandler extends CustomResourceHandler {
+  private static final Logger LOG = LoggerFactory.getLogger(TFJobHandler.class);
+  private GenericKubernetesApi<TFJob, TFJobList> tfJobClient;
+  private Watchable<CoreV1Event> watcher;
+
+  public TFJobHandler() throws IOException {
+    super();
+  }
+
+  /**
+   * Stores the connection/target settings, builds the TFJob client, opens the
+   * event watch and creates the REST client used for status updates.
+   *
+   * @param serverHost host passed to {@link RestClient}
+   * @param serverPort port passed to {@link RestClient}
+   * @param namespace namespace of the events and of the TFJob
+   * @param crName name of the custom resource
+   * @param resourceId id used both as the event field selector value and as the
+   *                   name when fetching the TFJob
+   */
+  @Override
+  public void init(String serverHost, Integer serverPort,
+                   String namespace, String crName, String resourceId) {
+    this.serverHost = serverHost;
+    this.serverPort = serverPort;
+    this.namespace = namespace;
+    this.crName = crName;
+    this.resourceId = resourceId;
+    tfJobClient =
+        new GenericKubernetesApi<>(
+            TFJob.class, TFJobList.class,
+            TFJob.CRD_TF_GROUP_V1, TFJob.CRD_TF_VERSION_V1,
+            TFJob.CRD_TF_PLURAL_V1, client);
+    try {
+      String fieldSelector = String.format("involvedObject.name=%s", resourceId);
+      LOG.info("fieldSelector:{}", fieldSelector);
+      Call call = coreV1Api.listNamespacedEventCall(namespace, null, null, null, fieldSelector,
+          null, null, null, null, null, true, null);
+
+      watcher = Watch.createWatch(client, call, new TypeToken<Response<CoreV1Event>>(){}.getType());
+    } catch (ApiException e) {
+      // Keep the failure visible in the log; run() refuses to start without a watcher.
+      LOG.error("Failed to create event watcher for TFJob:{}", resourceId, e);
+    }
+    restClient = new RestClient(serverHost, serverPort);
+  }
+
+  /**
+   * Consumes watch events until the TFJob's latest condition is
+   * {@code TFJobSucceeded} or {@code TFJobFailed}. Events that arrive while the
+   * job cannot be read, or before it has any condition, are skipped.
+   */
+  @Override
+  public void run() {
+    if (watcher == null) {
+      LOG.error("Event watcher for TFJob:{} was not created, exit", this.resourceId);
+      return;
+    }
+
+    // NOTE(review): when the watch stream ends, this outer loop iterates the same
+    // watcher again - confirm whether the watch should be re-established instead.
+    while (true) {
+      for (Response<CoreV1Event> event : watcher) {
+        TFJob job = tfJobClient.get(this.namespace, this.resourceId).getObject();
+        if (job == null || job.getStatus() == null) {
+          LOG.warn("TFJob:{} or its status is not available yet, skip event", this.resourceId);
+          continue;
+        }
+        List<V1JobCondition> conditionList = job.getStatus().getConditions();
+        if (conditionList == null || conditionList.isEmpty()) {
+          LOG.warn("TFJob:{} has no condition yet, skip event", this.resourceId);
+          continue;
+        }
+        V1JobCondition lastCondition = conditionList.get(conditionList.size() - 1);
+        Experiment experiment = MLJobConverter.toJobFromMLJob(job);
+
+        this.restClient.callStatusUpdate(CustomResourceType.TFJob, resourceId, experiment);
+        String reason = lastCondition.getReason();
+        LOG.info("receiving condition:{}", reason);
+        LOG.info("current status of tfjob:{} is {}", resourceId, experiment.getStatus());
+
+        // Constant-first equals keeps a missing reason from raising an NPE.
+        if ("TFJobSucceeded".equals(reason)) {
+          LOG.info("TfJob:{} is succeeded, exit", this.resourceId);
+          return;
+        }
+        if ("TFJobFailed".equals(reason)) {
+          LOG.info("TfJob:{} is failed, exit", this.resourceId);
+          return;
+        }
+      }
+    }
+  }
+}
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
index dee44f1..ac094ff 100644
---
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
@@ -71,6 +71,7 @@ import org.apache.submarine.server.k8s.utils.K8sUtils;
import org.apache.submarine.serve.utils.IstioConstants;
import org.apache.submarine.serve.utils.SeldonConstants;
import org.apache.submarine.server.api.Submitter;
+import org.apache.submarine.server.api.common.CustomResourceType;
import org.apache.submarine.server.api.exception.InvalidSpecException;
import org.apache.submarine.server.api.experiment.Experiment;
import org.apache.submarine.server.api.experiment.ExperimentLog;
@@ -82,6 +83,7 @@ import org.apache.submarine.server.api.notebook.Notebook;
import org.apache.submarine.server.api.spec.ExperimentMeta;
import org.apache.submarine.server.api.spec.ExperimentSpec;
import org.apache.submarine.server.api.spec.NotebookSpec;
+import org.apache.submarine.server.submitter.k8s.model.AgentPod;
import org.apache.submarine.server.submitter.k8s.model.MLJob;
import org.apache.submarine.server.submitter.k8s.model.NotebookCR;
import org.apache.submarine.server.submitter.k8s.model.NotebookCRList;
@@ -260,12 +262,19 @@ public class K8sSubmitter implements Submitter {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
mlJob.getMetadata().setOwnerReferences(OwnerReferenceUtils.getOwnerReference());
-
+ AgentPod agentPod = new AgentPod(getServerNamespace(),
spec.getMeta().getName(),
+ mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
+ ? CustomResourceType.TFJob : CustomResourceType.PyTorchJob,
+ spec.getMeta().getExperimentId());
+
+
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? tfJobClient.create(getServerNamespace(), (TFJob) mlJob,
new CreateOptions()).throwsApiException().getObject()
: pyTorchJobClient.create(getServerNamespace(), (PyTorchJob)
mlJob,
new CreateOptions()).throwsApiException().getObject();
+
+ V1Pod agentPodResult =
podClient.create(agentPod).throwsApiException().getObject();
experiment = parseExperimentResponseObject(object,
ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
LOG.error("K8s submitter: parse Job object failed by " + e.getMessage(),
e);
@@ -282,6 +291,7 @@ public class K8sSubmitter implements Submitter {
public Experiment findExperiment(ExperimentSpec spec) throws
SubmarineRuntimeException {
Experiment experiment;
try {
+
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
@@ -289,6 +299,7 @@ public class K8sSubmitter implements Submitter {
.throwsApiException().getObject()
: pyTorchJobClient.get(getServerNamespace(),
mlJob.getMetadata().getName())
.throwsApiException().getObject();
+
experiment = parseExperimentResponseObject(object,
ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
@@ -306,7 +317,6 @@ public class K8sSubmitter implements Submitter {
try {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
-
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? tfJobClient.patch(getServerNamespace(),
mlJob.getMetadata().getName(),
V1Patch.PATCH_FORMAT_JSON_PATCH,
@@ -330,7 +340,6 @@ public class K8sSubmitter implements Submitter {
try {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
-
Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
? tfJobClient.delete(getServerNamespace(),
mlJob.getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/AgentPod.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/AgentPod.java
new file mode 100644
index 0000000..7787c64
--- /dev/null
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/AgentPod.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.k8s.model;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.submarine.commons.utils.SubmarineConfiguration;
+import org.apache.submarine.server.api.common.CustomResourceType;
+
+import io.kubernetes.client.openapi.models.V1Container;
+import io.kubernetes.client.openapi.models.V1EnvVar;
+import io.kubernetes.client.openapi.models.V1ObjectMeta;
+import io.kubernetes.client.openapi.models.V1Pod;
+import io.kubernetes.client.openapi.models.V1PodSpec;
+
+/**
+ * Pod definition for the agent that accompanies a custom resource. The pod runs
+ * a single container from {@code AGENT_IMAGE} and receives the resource type,
+ * name, id, namespace and the server host/port through environment variables.
+ */
+public class AgentPod extends V1Pod {
+  private static final SubmarineConfiguration conf = SubmarineConfiguration.getInstance();
+  private static final String AGENT_IMAGE = "apache/submarine:agent-0.7.0-SNAPSHOT";
+  private static final String CONTAINER_NAME = "agent";
+
+  /**
+   * Builds the agent pod.
+   *
+   * @param namespace namespace the pod is created in; also exported as {@code NAMESPACE}
+   * @param name custom resource name; part of the pod name and exported as
+   *             {@code CUSTOM_RESOURCE_NAME}
+   * @param type custom resource type; part of the pod name (lower-cased) and exported
+   *             as {@code CUSTOM_RESOURCE_TYPE}
+   * @param resourceId resource id; part of the pod name (lower-cased) and exported as
+   *                   {@code CUSTOM_RESOURCE_ID}
+   */
+  public AgentPod(String namespace, String name,
+                  CustomResourceType type,
+                  String resourceId) {
+    super();
+    V1ObjectMeta meta = new V1ObjectMeta();
+
+    // Locale.ROOT keeps the generated pod name independent of the JVM default locale.
+    meta.setName(
+        String.format("%s-%s-%s-%s", type.toString().toLowerCase(java.util.Locale.ROOT), name,
+            resourceId.toLowerCase(java.util.Locale.ROOT), CONTAINER_NAME));
+    meta.setNamespace(namespace);
+    this.setMetadata(meta);
+
+    V1Container agentContainer = new V1Container();
+    agentContainer.setName(CONTAINER_NAME);
+    agentContainer.setImage(AGENT_IMAGE);
+
+    List<V1EnvVar> envVarList = new ArrayList<>();
+    envVarList.add(envVar("CUSTOM_RESOURCE_TYPE", type.toString()));
+    envVarList.add(envVar("CUSTOM_RESOURCE_NAME", name));
+    envVarList.add(envVar("NAMESPACE", namespace));
+    envVarList.add(envVar("SERVER_HOST", conf.getServerServiceName()));
+    envVarList.add(envVar("SERVER_PORT", String.valueOf(conf.getServerPort())));
+    envVarList.add(envVar("CUSTOM_RESOURCE_ID", resourceId));
+    agentContainer.env(envVarList);
+
+    V1PodSpec spec = new V1PodSpec();
+    // addContainersItem does not rely on getContainers() returning a non-null list.
+    spec.addContainersItem(agentContainer);
+    spec.setRestartPolicy("OnFailure");
+    this.setSpec(spec);
+  }
+
+  /** Creates an environment variable entry with the given name and value. */
+  private static V1EnvVar envVar(String name, String value) {
+    V1EnvVar variable = new V1EnvVar();
+    variable.setName(name);
+    variable.setValue(value);
+    return variable;
+  }
+}
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJobSpec.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJobSpec.java
index a840bb6..3e93d12 100644
---
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJobSpec.java
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJobSpec.java
@@ -31,7 +31,8 @@ public class PyTorchJobSpec {
*/
@SerializedName("pytorchReplicaSpecs")
private Map<PyTorchJobReplicaType, MLJobReplicaSpec> replicaSpecs;
-
+ @SerializedName("backoffLimit")
+ private Integer backoffLimit = 3;
/**
* Get the replica specs.
*
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
index 9c60207..b1282b7 100644
---
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
@@ -20,12 +20,15 @@
package org.apache.submarine.server.submitter.k8s.model.tfjob;
import com.google.gson.annotations.SerializedName;
+
+import io.kubernetes.client.common.KubernetesObject;
+
import org.apache.submarine.server.submitter.k8s.model.MLJob;
/**
* It's the tf-operator's entry model.
*/
-public class TFJob extends MLJob {
+public class TFJob extends MLJob implements KubernetesObject {
public static final String CRD_TF_KIND_V1 = "TFJob";
public static final String CRD_TF_PLURAL_V1 = "tfjobs";
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJobSpec.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJobSpec.java
index ae21b81..a7f4606 100644
---
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJobSpec.java
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJobSpec.java
@@ -33,7 +33,8 @@ public class TFJobSpec {
*/
@SerializedName("tfReplicaSpecs")
private Map<TFJobReplicaType, MLJobReplicaSpec> tfReplicaSpecs;
-
+ @SerializedName("backoffLimit")
+ private Integer backoffLimit = 3;
/**
* Get the replica specs.
diff --git
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
index 1c5edbe..710f8aa 100644
---
a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
+++
b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
@@ -88,6 +88,7 @@ public class ExperimentSpecParser {
public static PyTorchJobSpec parsePyTorchJobSpec(ExperimentSpec
experimentSpec)
throws InvalidSpecException {
PyTorchJobSpec pyTorchJobSpec = new PyTorchJobSpec();
+
Map<PyTorchJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new
HashMap<>();
for (Map.Entry<String, ExperimentTaskSpec> entry :
experimentSpec.getSpec().entrySet()) {
String replicaType = entry.getKey();
@@ -95,7 +96,9 @@ public class ExperimentSpecParser {
if (PyTorchJobReplicaType.isSupportedReplicaType(replicaType)) {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
- replicaSpec.setTemplate(parseTemplateSpec(taskSpec, experimentSpec));
+ V1PodTemplateSpec podTemplateSpec = parseTemplateSpec(taskSpec,
experimentSpec);
+
+ replicaSpec.setTemplate(podTemplateSpec);
replicaSpecMap.put(PyTorchJobReplicaType.valueOf(replicaType),
replicaSpec);
} else {
throw new InvalidSpecException("Unrecognized replica type name: " +
@@ -122,20 +125,25 @@ public class ExperimentSpecParser {
Map<String, String> labels = new HashMap<>();
labels.put(ExperimentMeta.SUBMARINE_EXPERIMENT_NAME,
experimentSpec.getMeta().getName());
meta.setLabels(labels);
- meta.setNamespace(experimentSpec.getMeta().getNamespace());
+
return meta;
}
- private static TFJobSpec parseTFJobSpec(ExperimentSpec experimentSpec)
throws InvalidSpecException {
+ private static TFJobSpec parseTFJobSpec(ExperimentSpec experimentSpec)
+ throws InvalidSpecException {
TFJobSpec tfJobSpec = new TFJobSpec();
Map<TFJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>();
+
for (Map.Entry<String, ExperimentTaskSpec> entry :
experimentSpec.getSpec().entrySet()) {
String replicaType = entry.getKey();
ExperimentTaskSpec taskSpec = entry.getValue();
+
if (TFJobReplicaType.isSupportedReplicaType(replicaType)) {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
- replicaSpec.setTemplate(parseTemplateSpec(taskSpec, experimentSpec));
+ V1PodTemplateSpec podTemplateSpec = parseTemplateSpec(taskSpec,
experimentSpec);
+
+ replicaSpec.setTemplate(podTemplateSpec);
replicaSpecMap.put(TFJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
throw new InvalidSpecException("Unrecognized replica type name: " +
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]