This is an automated email from the ASF dual-hosted git repository.

yasith pushed a commit to branch sdk-batch-jobs
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit e88f0ee3db597eedc43e1f5be9e4d8b8923dc241
Author: yasithdev <[email protected]>
AuthorDate: Thu Jul 31 05:28:08 2025 -0500

    implement missing functions for batch sdk
---
 .../airavata_experiments/airavata.py               | 67 +++++++++++++++-------
 .../airavata_experiments/plan.py                   | 16 +++---
 .../airavata_experiments/runtime.py                | 59 ++++++++++++++++---
 .../airavata_experiments/task.py                   | 10 +++-
 4 files changed, 116 insertions(+), 36 deletions(-)

diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
index 05e91ab494..362dbdcf55 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
@@ -36,6 +36,7 @@ from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, Us
 from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingModel
 from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory
 from airavata.model.appcatalog.groupresourceprofile.ttypes import GroupResourceProfile
+from airavata.model.status.ttypes import JobStatus, JobState, ExperimentStatus, ExperimentState
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 logger = logging.getLogger("airavata_sdk.clients")
@@ -446,7 +447,36 @@ class AiravataOperator:
     assert process_id is not None, f"Expected process_id, got {process_id}"
     url_path = os.path.join(process_id, remote_file)
     filemgr_svc_download_url = f"{self.filemgr_svc_url()}/download/live/{url_path}"
-  
+
+  def execute_cmd(self, agent_ref: str, cmd: str) -> bytes:
+    """
+    Execute a command on a remote directory of a storage resource
+    TODO add data_svc fallback
+
+    Return Path: /{project_name}/{experiment_name}
+
+    """
+    res = requests.post(f"{self.connection_svc_url()}/agent/execute/shell", json={
+        "agentId": agent_ref,
+        "envName": agent_ref,
+        "workingDir": ".",
+        "arguments": ["sh", "-c", f"{cmd} | base64 -w0"]
+    })
+    data = res.json()
+    if data["error"] is not None:
+      raise Exception(data["error"])
+    else:
+      exc_id = data["executionId"]
+      while True:
+        res = requests.get(f"{self.connection_svc_url()}/agent/execute/shell/{exc_id}")
+        data = res.json()
+        if data["executed"]:
+          content = data["responseString"]
+          import base64
+          content = base64.b64decode(content)
+          return content
+        time.sleep(1)
+
   def cat_file(self, process_id: str, agent_ref: str, sr_host: str, remote_file: str, remote_dir: str) -> bytes:
     """
     Download files from a remote directory of a storage resource to a local directory
@@ -674,14 +704,15 @@ class AiravataOperator:
     # wait until task begins, then get job id
     print(f"[AV] Experiment {experiment_name} WAITING until task begins...")
     job_id = job_state = None
-    while job_state is None:
+    while job_id in [None, "N/A"]:
       try:
         job_id, job_state = self.get_task_status(ex_id)
       except:
         time.sleep(2)
       else:
         time.sleep(2)
-    print(f"[AV] Experiment {experiment_name} - Task {job_state} with id: {job_id}")
+    assert job_state is not None, f"Job state is None for job id: {job_id}"
+    print(f"[AV] Experiment {experiment_name} - Task {job_state.name} with id: {job_id}")
 
     return LaunchState(
       experiment_id=ex_id,
@@ -692,14 +723,12 @@ class AiravataOperator:
       sr_host=storage.hostName,
     )
 
-  def get_experiment_status(self, experiment_id: str) -> Literal['CREATED', 'VALIDATED', 'SCHEDULED', 'LAUNCHED', 'EXECUTING', 'CANCELING', 'CANCELED', 'COMPLETED', 'FAILED']:
-    states = ["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"]
+  def get_experiment_status(self, experiment_id: str) -> ExperimentState:
     status = self.api_server_client.get_experiment_status(self.airavata_token, experiment_id)
-    state = status.state.name
-    if state in states:
-      return state
-    else:
-      return "FAILED"
+    if status is None:
+      return ExperimentState.CREATED
+    assert isinstance(status, ExperimentStatus)
+    return status.state
 
   def stop_experiment(self, experiment_id: str):
     status = self.api_server_client.terminate_experiment(
@@ -794,16 +823,12 @@ class AiravataOperator:
       Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="Default"),
     ]
   
-  def get_task_status(self, experiment_id: str) -> tuple[str, Literal["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] | None]:
-    states = ["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"]
-    job_details: dict = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore
-    print(f"[av] Job details: {job_details}")
+  def get_task_status(self, experiment_id: str) -> tuple[str, JobState]:
+    job_details: dict[str, JobStatus] = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore
     job_id = job_state = None
-    # get the most recent job id and state
-    for job_id, v in job_details.items(): # type: ignore
-      if v.reason in states:
-        job_state = v.reason
-      else:
-        job_state = states[int(v.jobState)]
-    return job_id or "N/A", job_state # type: ignore
+    for job_id, v in job_details.items():
+      job_state = v.jobState
+    return job_id or "N/A", job_state or JobState.UNKNOWN
+  
+  JobState = JobState
 
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
index 8560620616..7f35e3c6d6 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
@@ -66,11 +66,13 @@ class Plan(pydantic.BaseModel):
       statuses.append(task.status())
     return statuses
 
-  def __stage_stop__(self) -> None:
-    print("Stopping task(s)...")
-    for task in self.tasks:
-      task.stop()
-    print("Task(s) stopped.")
+  def __stage_stop__(self, runs: list[int] = []) -> None:
+    runs = runs if len(runs) > 0 else list(range(len(self.tasks)))
+    print(f"Stopping task(s): {runs}")
+    for i, task in enumerate(self.tasks):
+      if i in runs:
+        task.stop()
+    print(f"Task(s) stopped: {runs}")
 
   def __stage_fetch__(self, local_dir: str) -> list[list[str]]:
     print("Fetching results...")
@@ -119,8 +121,8 @@ class Plan(pydantic.BaseModel):
     assert os.path.isdir(local_dir)
     self.__stage_fetch__(local_dir)
 
-  def stop(self) -> None:
-    self.__stage_stop__()
+  def stop(self, runs: list[int] = []) -> None:
+    self.__stage_stop__(runs)
     self.save()
 
   def export(self, filename: str) -> None:
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
index 00f039f459..054f94ec82 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
@@ -16,7 +16,7 @@
 
 from __future__ import annotations
 import abc
-from typing import Any
+from typing import Any, Literal
 from pathlib import Path
 import os
 
@@ -24,6 +24,28 @@ import pydantic
 
 # from .task import Task
 Task = Any
+States = Literal[
+  # Experiment States
+  'CREATED',
+  'VALIDATED',
+  'SCHEDULED',
+  'LAUNCHED',
+  'EXECUTING',
+  'CANCELING',
+  'CANCELED',
+  'COMPLETED',
+  'FAILED',
+  # Job States
+  'SUBMITTED',
+  'QUEUED',
+  'ACTIVE',
+  'COMPLETE',
+  'CANCELED',
+  'FAILED',
+  'SUSPENDED',
+  'UNKNOWN',
+  'NON_CRITICAL_FAIL',
+]
 
 class Runtime(abc.ABC, pydantic.BaseModel):
 
@@ -36,6 +58,9 @@ class Runtime(abc.ABC, pydantic.BaseModel):
   @abc.abstractmethod
   def execute_py(self, libraries: list[str], code: str, task: Task) -> None: ...
 
+  @abc.abstractmethod
+  def execute_cmd(self, cmd: str, task: Task) -> bytes: ...
+
   @abc.abstractmethod
   def status(self, task: Task) -> tuple[str, str]: ...
 
@@ -87,6 +112,9 @@ class Mock(Runtime):
     task.agent_ref = str(uuid.uuid4())
     task.ref = str(uuid.uuid4())
 
+  def execute_cmd(self, cmd: str, task: Task) -> bytes:
+    return b""
+
   def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
     pass
 
@@ -158,6 +186,23 @@ class Remote(Runtime):
     except Exception as e:
       print(f"[Remote] Failed to launch experiment: {repr(e)}")
       raise e
+    
+  def execute_cmd(self, cmd: str, task: Task) -> bytes:
+    assert task.ref is not None
+    assert task.agent_ref is not None
+    assert task.pid is not None
+    assert task.sr_host is not None
+    assert task.workdir is not None
+
+    from .airavata import AiravataOperator
+    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    try:
+      result = av.execute_cmd(task.agent_ref, cmd)
+      return result
+    except Exception as e:
+      print(f"[Remote] Failed to execute command: {repr(e)}")
+      return b""
+
 
   def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
     assert task.ref is not None
@@ -169,7 +214,7 @@ class Remote(Runtime):
     result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args)
     print(result)
 
-  def status(self, task: Task) -> tuple[str, str]:
+  def status(self, task: Task) -> tuple[str, States]:
     assert task.ref is not None
     assert task.agent_ref is not None
 
@@ -177,10 +222,10 @@ class Remote(Runtime):
     av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
     # prioritize job state, fallback to experiment state
     job_id, job_state = av.get_task_status(task.ref)
-    if not job_state or job_state == "UN_SUBMITTED":
-      return job_id, av.get_experiment_status(task.ref)
+    if job_state in [AiravataOperator.JobState.UNKNOWN, AiravataOperator.JobState.NON_CRITICAL_FAIL]:
+      return job_id, av.get_experiment_status(task.ref).name
     else:
-      return job_id, job_state
+      return job_id, job_state.name
 
   def signal(self, signal: str, task: Task) -> None:
     assert task.ref is not None
@@ -259,5 +304,5 @@ def find_runtimes(
       out_runtimes.append(r)
   return out_runtimes
 
-def is_terminal_state(x):
-  return x in ["CANCELED", "COMPLETED", "FAILED"]
\ No newline at end of file
+def is_terminal_state(x: States) -> bool:
+  return x in ["CANCELED", "COMPLETE", "COMPLETED", "FAILED"]
\ No newline at end of file
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/task.py b/dev-tools/airavata-python-sdk/airavata_experiments/task.py
index bcda796518..3700bf2f66 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/task.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/task.py
@@ -72,7 +72,11 @@ class Task(pydantic.BaseModel):
     assert self.ref is not None
     from pathlib import Path
     Path(local_dir).mkdir(parents=True, exist_ok=True)
-    return self.runtime.download(file, local_dir, self)
+    try:
+      return self.runtime.download(file, local_dir, self)
+    except Exception as e:
+      print(f"[Remote] Failed to download file: {repr(e)}")
+      return ""
   
   def download_all(self, local_dir: str) -> list[str]:
     assert self.ref is not None
@@ -92,6 +96,10 @@ class Task(pydantic.BaseModel):
   def cat(self, file: str) -> bytes:
     assert self.ref is not None
     return self.runtime.cat(file, self)
+  
+  def exec(self, cmd: str) -> bytes:
+    assert self.ref is not None
+    return self.runtime.execute_cmd(cmd, self)
 
   def stop(self) -> None:
     assert self.ref is not None

Reply via email to