This is an automated email from the ASF dual-hosted git repository.
cdmikechen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
     new 3d2a3e67 SUBMARINE-1283. copy data for experiment before it running via distcp to minio
3d2a3e67 is described below
commit 3d2a3e67530d2836ae9be8e5f9549bb30dcc53ca
Author: FatalLin <[email protected]>
AuthorDate: Tue Sep 20 22:16:50 2022 +0800
SUBMARINE-1283. copy data for experiment before it running via distcp to minio
### What is this PR for?
This is a prototype of the experiment prehandler. Once the required arguments have been set, Submarine adds an init container to the main pod. The init container copies the source data to minio under the path /submarine/${experimentId}.
Note: the init container is added under:
TFJob: ps
PytorchJob: master
XGboostJob: master
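For illustration, the server-side wiring can be sketched in plain Python (a hypothetical helper, not the committed Java code): given the experimentHandlerSpec map from the request, MLJob.getExperimentHandlerContainer builds the init container's env vars roughly like this, with the minio endpoint and credentials hard-coded to the in-cluster defaults:

```python
# Example experimentHandlerSpec, as used in the test resources of this commit.
handler_spec = {
    "FILE_SYSTEM_TYPE": "HDFS",
    "HDFS_HOST": "127.0.0.1",
    "HDFS_PORT": "9000",
    "HDFS_SOURCE": "/tmp",
    "ENABLE_KERBEROS": "false",
}


def build_init_env(handler_spec, experiment_id):
    # Keys passed through from the spec, mirroring the V1EnvVar list
    # built in getExperimentHandlerContainer.
    env = {k: handler_spec.get(k) for k in
           ("FILE_SYSTEM_TYPE", "HDFS_HOST", "HDFS_PORT",
            "HDFS_SOURCE", "ENABLE_KERBEROS")}
    # Fixed minio settings and the experiment id, as set by the commit.
    env.update({
        "DEST_MINIO_HOST": "submarine-minio-service",
        "DEST_MINIO_PORT": "9000",
        "MINIO_ACCESS_KEY": "submarine_minio",
        "MINIO_SECRET_KEY": "submarine_minio",
        "EXPERIMENT_ID": experiment_id,
    })
    return env
```

The init container image is fixed to apache/submarine:experiment-prehandler-0.8.0-SNAPSHOT, and the prehandler reads these env vars on startup.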
### What type of PR is it?
Feature
### Todos
Add a housekeeping container to clean up the copied data.
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1283
### How should this be tested?
Another test should be added for it.
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? Yes
Author: FatalLin <[email protected]>
Signed-off-by: Xiang Chen <[email protected]>
Closes #989 from FatalLin/SUBMARINE-1283 and squashes the following commits:
ef804903 [FatalLin] fix conflict
d816b880 [FatalLin] modify test cases
171da696 [FatalLin] code polish
36297f31 [FatalLin] polish code
2c62fb7c [FatalLin] fix
17135f2a [FatalLin] fix script
314a3866 [FatalLin] fix build script
028eff7d [FatalLin] fix script
79d91c88 [FatalLin] prototype of experiment prehandler
eaa25e1d [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283
d81841b1 [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283
44473203 [FatalLin] for debugging
---
.github/scripts/build-image-locally-v3.sh | 4 +-
.github/scripts/build-image-locally.sh | 4 +-
.../docker-images/experiment-prehandler/Dockerfile | 3 +-
.../docker-images/experiment-prehandler/build.sh | 10 +++
.../fs_prehandler/hdfs_prehandler.py | 42 +++++++----
.../submarine/server/api/spec/ExperimentSpec.java | 9 +++
.../server/submitter/k8s/model/mljob/MLJob.java | 86 ++++++++++++++++++++++
.../submitter/k8s/model/pytorchjob/PyTorchJob.java | 7 ++
.../server/submitter/k8s/model/tfjob/TFJob.java | 12 +++
.../submitter/k8s/model/xgboostjob/XGBoostJob.java | 8 +-
.../submitter/k8s/ExperimentSpecParserTest.java | 44 ++++++++++-
.../src/test/resources/pytorch_job_req.json | 7 ++
.../src/test/resources/tf_mnist_req.json | 7 ++
.../src/test/resources/xgboost_job_req.json | 7 ++
14 files changed, 231 insertions(+), 19 deletions(-)
diff --git a/.github/scripts/build-image-locally-v3.sh b/.github/scripts/build-image-locally-v3.sh
index 81203af7..00fdb770 100755
--- a/.github/scripts/build-image-locally-v3.sh
+++ b/.github/scripts/build-image-locally-v3.sh
@@ -17,12 +17,14 @@
#
SUBMARINE_VERSION="0.8.0-SNAPSHOT"
-FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3")
+FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3" "agent" "experiment-prehandler")
IMAGE_LIST=(
"apache/submarine:database-${SUBMARINE_VERSION}"
"apache/submarine:mlflow-${SUBMARINE_VERSION}"
"apache/submarine:server-${SUBMARINE_VERSION}"
"apache/submarine:operator-${SUBMARINE_VERSION}"
+ "apache/submarine:agent-${SUBMARINE_VERSION}"
+ "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
)
for i in "${!IMAGE_LIST[@]}"
diff --git a/.github/scripts/build-image-locally.sh b/.github/scripts/build-image-locally.sh
index a53de5b3..4d35d690 100755
--- a/.github/scripts/build-image-locally.sh
+++ b/.github/scripts/build-image-locally.sh
@@ -17,12 +17,14 @@
#
SUBMARINE_VERSION="0.8.0-SNAPSHOT"
-FOLDER_LIST=("database" "mlflow" "submarine" "operator")
+FOLDER_LIST=("database" "mlflow" "submarine" "operator" "agent" "experiment-prehandler")
IMAGE_LIST=(
"apache/submarine:database-${SUBMARINE_VERSION}"
"apache/submarine:mlflow-${SUBMARINE_VERSION}"
"apache/submarine:server-${SUBMARINE_VERSION}"
"apache/submarine:operator-${SUBMARINE_VERSION}"
+ "apache/submarine:agent-${SUBMARINE_VERSION}"
+ "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
)
for i in "${!IMAGE_LIST[@]}"
diff --git a/dev-support/docker-images/experiment-prehandler/Dockerfile b/dev-support/docker-images/experiment-prehandler/Dockerfile
index 87307d07..7a6c7e69 100644
--- a/dev-support/docker-images/experiment-prehandler/Dockerfile
+++ b/dev-support/docker-images/experiment-prehandler/Dockerfile
@@ -21,7 +21,8 @@ RUN apt-get -y install python3 python3-pip bash tini
ADD ./tmp/hadoop-3.3.3.tar.gz /opt/
ADD ./tmp/submarine-experiment-prehandler /opt/submarine-experiment-prehandler
-
+ADD ./tmp/hadoop-aws-3.3.3.jar /opt/hadoop-3.3.3/share/hadoop/hdfs
+ADD ./tmp/aws-java-sdk-bundle-1.12.267.jar /opt/hadoop-3.3.3/share/hadoop/hdfs
ENV HADOOP_HOME=/opt/hadoop-3.3.3
ENV ARROW_LIBHDFS_DIR=/opt/hadoop-3.3.3/lib/native
diff --git a/dev-support/docker-images/experiment-prehandler/build.sh b/dev-support/docker-images/experiment-prehandler/build.sh
index fcdedb05..c1e94a37 100755
--- a/dev-support/docker-images/experiment-prehandler/build.sh
+++ b/dev-support/docker-images/experiment-prehandler/build.sh
@@ -19,6 +19,12 @@ set -euxo pipefail
SUBMARINE_VERSION=0.8.0-SNAPSHOT
SUBMARINE_IMAGE_NAME="apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
+if [ -L ${BASH_SOURCE-$0} ]; then
+ PWD=$(dirname $(readlink "${BASH_SOURCE-$0}"))
+else
+ PWD=$(dirname ${BASH_SOURCE-$0})
+fi
+
export CURRENT_PATH=$(cd "${PWD}">/dev/null; pwd)
export SUBMARINE_HOME=${CURRENT_PATH}/../../..
@@ -33,7 +39,11 @@ trap "test -f $tmpfile && rm $tmpfile" RETURN
curl -L -o $tmpfile ${HADOOP_TAR_URL}
mv $tmpfile ${CURRENT_PATH}/tmp/hadoop-3.3.3.tar.gz
+curl -L -o ${CURRENT_PATH}/tmp/hadoop-aws-3.3.3.jar https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.3/hadoop-aws-3.3.3.jar
+curl -L -o ${CURRENT_PATH}/tmp/aws-java-sdk-bundle-1.12.267.jar https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.267/aws-java-sdk-bundle-1.12.267.jar
+
echo "Start building the ${SUBMARINE_IMAGE_NAME} docker image ..."
+cd ${CURRENT_PATH}
docker build -t ${SUBMARINE_IMAGE_NAME} .
# clean temp file
diff --git a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
index 1a13f1b2..c138b61b 100644
--- a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
+++ b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
@@ -15,6 +15,7 @@
import logging
import os
+import subprocess
from fs_prehandler import FsPreHandler
from fsspec.implementations.arrow import HadoopFileSystem
@@ -22,23 +23,36 @@ from fsspec.implementations.arrow import HadoopFileSystem
 class HDFSPreHandler(FsPreHandler):
     def __init__(self):
-        self.hdfs_host = os.environ['HDFS_HOST']
-        self.hdfs_port = int(os.environ['HDFS_PORT'])
-        self.hdfs_source = os.environ['HDFS_SOURCE']
-        self.dest_path = os.environ['DEST_PATH']
-        self.enable_kerberos = os.environ['ENABLE_KERBEROS']
+        self.hdfs_host = os.environ['HDFS_HOST']
+        self.hdfs_port = os.environ['HDFS_PORT']
+        self.hdfs_source = os.environ['HDFS_SOURCE']
+        self.enable_kerberos = os.environ['ENABLE_KERBEROS']
+        self.hadoop_home = os.environ['HADOOP_HOME']
+        self.dest_minio_host = os.environ['DEST_MINIO_HOST']
+        self.dest_minio_port = os.environ['DEST_MINIO_PORT']
+        self.minio_access_key = os.environ['MINIO_ACCESS_KEY']
+        self.minio_secret_key = os.environ['MINIO_SECRET_KEY']
+        self.experiment_id = os.environ['EXPERIMENT_ID']

         logging.info('HDFS_HOST:%s' % self.hdfs_host)
-        logging.info('HDFS_PORT:%d' % self.hdfs_port)
+        logging.info('HDFS_PORT:%s' % self.hdfs_port)
         logging.info('HDFS_SOURCE:%s' % self.hdfs_source)
-        logging.info('DEST_PATH:%s' % self.dest_path)
+        logging.info('MINIO_DEST_HOST:%s' % self.dest_minio_host)
+        logging.info('MINIO_DEST_PORT:%s' % self.dest_minio_port)
         logging.info('ENABLE_KERBEROS:%s' % self.enable_kerberos)
-
-        self.fs = HadoopFileSystem(host=self.hdfs_host, port=self.hdfs_port)
+        logging.info('EXPERIMENT_ID:%s' % self.experiment_id)

     def process(self):
-        self.fs.get(self.hdfs_source, self.dest_path, recursive=True)
-        logging.info(
-            'fetch data from hdfs://%s:%d/%s to %s complete'
-            % (self.hdfs_host, self.hdfs_port, self.hdfs_source, self.dest_path)
-        )
+        dest_path = 'submarine/experiment/' + self.experiment_id
+        p = subprocess.run([self.hadoop_home + '/bin/hadoop', 'distcp',
+                            '-Dfs.s3a.endpoint=http://' + self.dest_minio_host + ':' + self.dest_minio_port + '/',
+                            '-Dfs.s3a.access.key=' + self.minio_access_key,
+                            '-Dfs.s3a.secret.key=' + self.minio_secret_key,
+                            '-Dfs.s3a.path.style.access=true',
+                            'hdfs://' + self.hdfs_host + ':' + self.hdfs_port + '/' + self.hdfs_source,
+                            's3a://' + dest_path])
+
+        if p.returncode == 0:
+            logging.info('fetch data from hdfs://%s:%s/%s to %s complete'
+                         % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path))
+        else:
+            raise Exception('error occurred when fetching data from hdfs://%s:%s/%s to %s'
+                            % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path))
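The argv assembled in process() above can be sanity-checked in isolation. The helper below is an illustrative sketch, not part of this commit; it only mirrors how the distcp command line is built (s3a options pointing at minio, HDFS source, bucket destination), with example argument values:

```python
def build_distcp_argv(hadoop_home, minio_host, minio_port, access_key, secret_key,
                      hdfs_host, hdfs_port, hdfs_source, experiment_id):
    # Mirrors HDFSPreHandler.process(): hadoop distcp with s3a options that
    # point at the minio service, copying the HDFS source into the bucket
    # path submarine/experiment/<experimentId>.
    dest_path = 'submarine/experiment/' + experiment_id
    return [hadoop_home + '/bin/hadoop', 'distcp',
            '-Dfs.s3a.endpoint=http://' + minio_host + ':' + minio_port + '/',
            '-Dfs.s3a.access.key=' + access_key,
            '-Dfs.s3a.secret.key=' + secret_key,
            '-Dfs.s3a.path.style.access=true',
            'hdfs://' + hdfs_host + ':' + hdfs_port + '/' + hdfs_source,
            's3a://' + dest_path]


argv = build_distcp_argv('/opt/hadoop-3.3.3', 'submarine-minio-service', '9000',
                         'submarine_minio', 'submarine_minio',
                         '127.0.0.1', '9000', '/tmp', 'experiment-1234')
print(argv[-1])  # the s3a destination for this experiment
```

Note that because HDFS_SOURCE in the test resources is an absolute path ('/tmp'), the concatenation yields a double slash in the hdfs:// URI, which HDFS tolerates.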
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
index b0c283ab..8c3024d6 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
@@ -28,6 +28,7 @@ public class ExperimentSpec {
private ExperimentMeta meta;
private EnvironmentSpec environment;
private Map<String, ExperimentTaskSpec> spec;
+ private Map<String, String> experimentHandlerSpec;
private CodeSpec code;
public ExperimentSpec() {}
@@ -63,6 +64,14 @@ public class ExperimentSpec {
public void setCode(CodeSpec code) {
this.code = code;
}
+
+ public Map<String, String> getExperimentHandlerSpec() {
+ return experimentHandlerSpec;
+ }
+
+  public void setExperimentHandlerSpec(Map<String, String> experimentHandlerSpec) {
+    this.experimentHandlerSpec = experimentHandlerSpec;
+  }
@Override
public String toString() {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
index 45c174cb..f026bc3d 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
@@ -22,10 +22,17 @@ package org.apache.submarine.server.submitter.k8s.model.mljob;
import com.google.gson.JsonSyntaxException;
import com.google.gson.annotations.SerializedName;
import io.kubernetes.client.common.KubernetesObject;
+import io.kubernetes.client.openapi.models.V1Container;
+import io.kubernetes.client.openapi.models.V1EnvVar;
import io.kubernetes.client.openapi.models.V1JobStatus;
import io.kubernetes.client.openapi.models.V1ObjectMeta;
import io.kubernetes.client.openapi.models.V1ObjectMetaBuilder;
import io.kubernetes.client.openapi.models.V1Status;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.api.common.CustomResourceType;
import org.apache.submarine.server.api.experiment.Experiment;
@@ -205,6 +212,85 @@ public abstract class MLJob implements KubernetesObject, K8sResource<Experiment>
this.experimentId = experimentId;
}
+ public V1Container getExperimentHandlerContainer(ExperimentSpec spec) {
+ Map<String, String> handlerSpec = spec.getExperimentHandlerSpec();
+
+ if (checkExperimentHanderArg(handlerSpec)) {
+ V1Container initContainer = new V1Container();
+      initContainer.setImage("apache/submarine:experiment-prehandler-0.8.0-SNAPSHOT");
+ initContainer.setName("ExperimentHandler".toLowerCase());
+ List<V1EnvVar> envVar = new ArrayList<>();
+
+ V1EnvVar hdfsHostVar = new V1EnvVar();
+ hdfsHostVar.setName("HDFS_HOST");
+ hdfsHostVar.setValue(handlerSpec.get("HDFS_HOST"));
+ envVar.add(hdfsHostVar);
+
+ V1EnvVar hdfsPortVar = new V1EnvVar();
+ hdfsPortVar.setName("HDFS_PORT");
+ hdfsPortVar.setValue(handlerSpec.get("HDFS_PORT"));
+ envVar.add(hdfsPortVar);
+
+ V1EnvVar hdfsSourceVar = new V1EnvVar();
+ hdfsSourceVar.setName("HDFS_SOURCE");
+ hdfsSourceVar.setValue(handlerSpec.get("HDFS_SOURCE"));
+ envVar.add(hdfsSourceVar);
+
+ V1EnvVar hdfsEnableKerberosVar = new V1EnvVar();
+ hdfsEnableKerberosVar.setName("ENABLE_KERBEROS");
+ hdfsEnableKerberosVar.setValue(handlerSpec.get("ENABLE_KERBEROS"));
+ envVar.add(hdfsEnableKerberosVar);
+
+ V1EnvVar destMinIOHostVar = new V1EnvVar();
+ destMinIOHostVar.setName("DEST_MINIO_HOST");
+ destMinIOHostVar.setValue("submarine-minio-service");
+ envVar.add(destMinIOHostVar);
+
+ V1EnvVar destMinIOPortVar = new V1EnvVar();
+ destMinIOPortVar.setName("DEST_MINIO_PORT");
+ destMinIOPortVar.setValue("9000");
+ envVar.add(destMinIOPortVar);
+
+ V1EnvVar minIOAccessKeyVar = new V1EnvVar();
+ minIOAccessKeyVar.setName("MINIO_ACCESS_KEY");
+ minIOAccessKeyVar.setValue("submarine_minio");
+ envVar.add(minIOAccessKeyVar);
+
+ V1EnvVar minIOSecretKeyVar = new V1EnvVar();
+ minIOSecretKeyVar.setName("MINIO_SECRET_KEY");
+ minIOSecretKeyVar.setValue("submarine_minio");
+ envVar.add(minIOSecretKeyVar);
+
+ V1EnvVar fileSystemTypeVar = new V1EnvVar();
+ fileSystemTypeVar.setName("FILE_SYSTEM_TYPE");
+ fileSystemTypeVar.setValue(handlerSpec.get("FILE_SYSTEM_TYPE"));
+ envVar.add(fileSystemTypeVar);
+
+ V1EnvVar experimentIdVar = new V1EnvVar();
+ experimentIdVar.setName("EXPERIMENT_ID");
+ experimentIdVar.setValue(this.experimentId);
+ envVar.add(experimentIdVar);
+
+ initContainer.setEnv(envVar);
+ return initContainer;
+ }
+ return null;
+ }
+
+  private boolean checkExperimentHanderArg(Map<String, String> handlerSpec) {
+    if (handlerSpec == null)
+      return false;
+    if (handlerSpec.get("FILE_SYSTEM_TYPE") == null)
+      return false;
+    else if ("HDFS".equals(handlerSpec.get("FILE_SYSTEM_TYPE"))) {
+      if ((handlerSpec.get("HDFS_HOST") == null) || (handlerSpec.get("HDFS_PORT") == null) ||
+          (handlerSpec.get("HDFS_SOURCE") == null) || (handlerSpec.get("ENABLE_KERBEROS") == null)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
/**
* Convert MLJob object to return Experiment object
*/
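The validation rule in checkExperimentHanderArg above can be restated as a small Python sketch (illustrative only; the function name and dict form are ours, the keys are from the commit):

```python
def handler_args_valid(spec):
    # Mirrors the Java check: the spec must exist and name a FILE_SYSTEM_TYPE;
    # for HDFS it must also supply host, port, source path, and the kerberos flag.
    if not spec or spec.get("FILE_SYSTEM_TYPE") is None:
        return False
    if spec.get("FILE_SYSTEM_TYPE") == "HDFS":
        required = ("HDFS_HOST", "HDFS_PORT", "HDFS_SOURCE", "ENABLE_KERBEROS")
        return all(spec.get(k) is not None for k in required)
    return True
```

When the check fails, getExperimentHandlerContainer returns null and no init container is attached, so experiments without a handler spec are unaffected.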
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
index 7da916c5..7d38ed59 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
@@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.pytorchjob;
import com.google.gson.annotations.SerializedName;
import io.kubernetes.client.custom.V1Patch;
import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
import io.kubernetes.client.openapi.models.V1Status;
import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -78,6 +79,7 @@ public class PyTorchJob extends MLJob {
throws InvalidSpecException {
PyTorchJobSpec pyTorchJobSpec = new PyTorchJobSpec();
+    V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
     Map<PyTorchJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>();
     for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
       String replicaType = entry.getKey();
String replicaType = entry.getKey();
@@ -86,6 +88,11 @@ public class PyTorchJob extends MLJob {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
         V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec);
+
+ if (initContainer != null && replicaType.equals("Master")) {
+ podTemplateSpec.getSpec().addInitContainersItem(initContainer);
+ }
+
replicaSpec.setTemplate(podTemplateSpec);
         replicaSpecMap.put(PyTorchJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
index ee9d0d9f..1b9fe5a3 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
@@ -23,6 +23,7 @@ import com.google.gson.annotations.SerializedName;
import io.kubernetes.client.custom.V1Patch;
import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
import io.kubernetes.client.openapi.models.V1Status;
import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -68,6 +69,17 @@ public class TFJob extends MLJob {
setGroup(CRD_TF_GROUP_V1);
// set spec
setSpec(parseTFJobSpec(experimentSpec));
+
+    V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
+ if (initContainer != null) {
+      Map<TFJobReplicaType, MLJobReplicaSpec> replicaSet = this.getSpec().getReplicaSpecs();
+ if (replicaSet.keySet().contains(TFJobReplicaType.Ps)) {
+ MLJobReplicaSpec psSpec = replicaSet.get(TFJobReplicaType.Ps);
+ psSpec.getTemplate().getSpec().addInitContainersItem(initContainer);
+ } else {
+        throw new InvalidSpecException("PreHandler only support TFJob with PS for now");
+ }
+ }
}
@Override
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
index 0bacbfde..740b868c 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
@@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
import com.google.gson.annotations.SerializedName;
import io.kubernetes.client.custom.V1Patch;
import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
import io.kubernetes.client.openapi.models.V1Status;
import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -75,11 +76,16 @@ public class XGBoostJob extends MLJob {
     for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
String replicaType = entry.getKey();
ExperimentTaskSpec taskSpec = entry.getValue();
-
+      V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
if (XGBoostJobReplicaType.isSupportedReplicaType(replicaType)) {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
         V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec);
+
+ if (initContainer != null && replicaType.equals("Master")) {
+ podTemplateSpec.getSpec().addInitContainersItem(initContainer);
+ }
+
replicaSpec.setTemplate(podTemplateSpec);
         replicaSpecMap.put(XGBoostJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
index e6f8b344..acbf48d5 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
@@ -27,6 +27,8 @@ import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
import io.kubernetes.client.openapi.models.V1ObjectMeta;
import io.kubernetes.client.openapi.models.V1Volume;
@@ -98,6 +100,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Ps);
validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Worker);
+ validateExperimentHandlerMetadata(experimentSpec, tfJob);
}
@Test
@@ -142,6 +145,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
     validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Master);
     validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Worker);
+ validateExperimentHandlerMetadata(experimentSpec, pyTorchJob);
}
@Test
@@ -183,6 +187,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
     validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Master);
     validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Worker);
+ validateExperimentHandlerMetadata(experimentSpec, xgboostJob);
}
@Test
@@ -218,7 +223,44 @@ public class ExperimentSpecParserTest extends SpecBuilder {
Assert.assertEquals(K8sUtils.getNamespace(), actualMeta.getNamespace());
Assert.assertEquals(expectedMeta.getFramework().toLowerCase(),
actualFramework);
}
-
+
+  private void validateExperimentHandlerMetadata(ExperimentSpec experimentSpec, MLJob mlJob) {
+
+ if (experimentSpec.getExperimentHandlerSpec() == null ||
+ experimentSpec.getExperimentHandlerSpec().isEmpty()) {
+ return;
+ }
+
+ V1Container initContainer = null;
+
+ MLJobReplicaSpec mlJobReplicaSpec = null;
+ if (mlJob instanceof PyTorchJob) {
+ mlJobReplicaSpec = ((PyTorchJob) mlJob).getSpec()
+ .getReplicaSpecs().get(PyTorchJobReplicaType.Master);
+ } else if (mlJob instanceof TFJob) {
+ mlJobReplicaSpec = ((TFJob) mlJob).getSpec()
+ .getReplicaSpecs().get(TFJobReplicaType.Ps);
+ } else if (mlJob instanceof XGBoostJob) {
+ mlJobReplicaSpec = ((XGBoostJob) mlJob).getSpec()
+ .getReplicaSpecs().get(XGBoostJobReplicaType.Master);
+ }
+    initContainer = mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().get(0);
+ Map<String, String> varMap = initContainer.getEnv().stream()
+ .collect(Collectors.toMap(V1EnvVar::getName, V1EnvVar::getValue));
+
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("FILE_SYSTEM_TYPE")
+        , varMap.get("FILE_SYSTEM_TYPE"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_HOST")
+        , varMap.get("HDFS_HOST"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_PORT")
+        , varMap.get("HDFS_PORT"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_SOURCE")
+        , varMap.get("HDFS_SOURCE"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("ENABLE_KERBEROS")
+        , varMap.get("ENABLE_KERBEROS"));
+ Assert.assertEquals(mlJob.getExperimentId(), varMap.get("EXPERIMENT_ID"));
+ }
+
private void validateReplicaSpec(ExperimentSpec experimentSpec,
MLJob mlJob, MLJobReplicaType type) {
MLJobReplicaSpec mlJobReplicaSpec = null;
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
index ed1828fa..69b1101d 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
@@ -22,5 +22,12 @@
"replicas": 2,
"resources": "cpu=1,memory=1024M"
}
+ },
+  "experimentHandlerSpec": {
+ "FILE_SYSTEM_TYPE": "HDFS",
+ "HDFS_HOST": "127.0.0.1",
+ "HDFS_PORT": "9000",
+ "HDFS_SOURCE": "/tmp",
+ "ENABLE_KERBEROS": "false"
}
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
index 2c806ddc..80ac9646 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
@@ -20,5 +20,12 @@
"replicas": 2,
"resources": "cpu=2,memory=1024M,nvidia.com/gpu=1"
}
+ },
+ "experimentHandlerSpec": {
+ "FILE_SYSTEM_TYPE": "HDFS",
+ "HDFS_HOST": "127.0.0.1",
+ "HDFS_PORT": "9000",
+ "HDFS_SOURCE": "/tmp",
+ "ENABLE_KERBEROS": "false"
}
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
index c4ba97f4..7498ba6c 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
@@ -22,5 +22,12 @@
"replicas": 2,
"resources": "cpu=1,memory=1024M"
}
+ },
+ "experimentHandlerSpec": {
+ "FILE_SYSTEM_TYPE": "HDFS",
+ "HDFS_HOST": "127.0.0.1",
+ "HDFS_PORT": "9000",
+ "HDFS_SOURCE": "/tmp",
+ "ENABLE_KERBEROS": "false"
}
}