This is an automated email from the ASF dual-hosted git repository.

yasith pushed a commit to branch sdk-batch-jobs
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit a451521f84458f6573a50c478d80181a6e9d8ccc
Author: yasithdev <[email protected]>
AuthorDate: Thu Jul 31 04:16:33 2025 -0500

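    improve sdk functions

    - Rename list_runtimes -> find_runtimes, Experiment.create_task -> add_run
      (now accepts `use=[...]` runtimes plus walltime/input overrides via
      kwargs), and Plan.save_json -> export; Experiment.plan() now saves the plan.
    - Fail with descriptive assertion messages when a named compute resource,
      storage resource, group resource profile, application, or project is
      not found, instead of a bare StopIteration.
    - Log each experiment input as it is configured; add a Delta GPU
      (gpuA100x4) runtime; default STORAGE_RESOURCE_HOST to
      gateway.cybershuttle.org.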
---
 .../airavata_experiments/__init__.py               |  4 +-
 .../airavata_experiments/airavata.py               | 46 ++++++++++++----------
 .../airavata_experiments/base.py                   | 25 ++++++++----
 .../airavata_experiments/plan.py                   |  4 +-
 .../airavata_experiments/runtime.py                |  2 +-
 .../airavata-python-sdk/airavata_sdk/__init__.py   |  2 +-
 6 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py
index dd391c07c4..f6f50fef20 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from . import base, plan
 from airavata_auth.device_auth import AuthContext
-from .runtime import list_runtimes, Runtime
+from .runtime import find_runtimes, Runtime
 from typing import Any
 
 
@@ -27,7 +27,7 @@ context = AuthContext()
 def login():
   context.login()
 
-__all__ = ["list_runtimes", "base", "plan", "login"]
+__all__ = ["find_runtimes", "base", "plan", "login"]
 
 
 def display_runtimes(runtimes: list[Runtime]) -> None:
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
index e6d7b10385..05e91ab494 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
@@ -100,8 +100,10 @@ class AiravataOperator:
     )
   
   def get_resource_host_id(self, resource_name):
-        resources: dict = 
self.api_server_client.get_all_compute_resource_names(self.airavata_token) # 
type: ignore
-        return next((str(k) for k, v in resources.items() if v == 
resource_name))
+    resources: dict = 
self.api_server_client.get_all_compute_resource_names(self.airavata_token) # 
type: ignore
+    resource_id = next((str(k) for k, v in resources.items() if v == 
resource_name), None)
+    assert resource_id is not None, f"Compute resource {resource_name} not 
found"
+    return resource_id
   
   def configure_computation_resource_scheduling(
       self,
@@ -188,7 +190,7 @@ class AiravataOperator:
     """
     tree = 
self.api_server_client.get_detailed_experiment_tree(self.airavata_token, 
experiment_id) # type: ignore
     processModels = tree.processes
-    assert processModels is not None
+    assert processModels is not None, f"No process models found for experiment 
{experiment_id}"
     assert len(processModels) == 1, f"Expected 1 process model, got 
{len(processModels)}"
     return processModels[0].processId
 
@@ -213,7 +215,8 @@ class AiravataOperator:
     sr_hostname = sr_hostname or self.default_sr_hostname()
     # logic
     sr_names: dict[str, str] = 
self.api_server_client.get_all_storage_resource_names(self.airavata_token) # 
type: ignore
-    sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname))
+    sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname), 
None)
+    assert sr_id is not None, f"Storage resource {sr_hostname} not found"
     return 
self.api_server_client.get_gateway_storage_preference(self.airavata_token, 
gateway_id, sr_id)
 
   def get_storage(self, storage_name: str | None = None) -> any:  # type: 
ignore
@@ -225,7 +228,8 @@ class AiravataOperator:
     storage_name = storage_name or self.default_sr_hostname()
     # logic
     sr_names: dict[str, str] = 
self.api_server_client.get_all_storage_resource_names(self.airavata_token) # 
type: ignore
-    sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name))
+    sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name), 
None)
+    assert sr_id is not None, f"Storage resource {storage_name} not found"
     storage = self.api_server_client.get_storage_resource(self.airavata_token, 
sr_id)
     return storage
 
@@ -236,11 +240,9 @@ class AiravataOperator:
     """
     # logic
     grps: list[GroupResourceProfile] = 
self.api_server_client.get_group_resource_list(self.airavata_token, 
self.default_gateway_id()) # type: ignore
-    try:
-      grp_id = next((grp.groupResourceProfileId for grp in grps if 
grp.groupResourceProfileName == group))
-      return str(grp_id)
-    except StopIteration:
-      raise Exception(f"Group resource profile {group} not found")
+    grp_id = next((grp.groupResourceProfileId for grp in grps if 
grp.groupResourceProfileName == group), None)
+    assert grp_id is not None, f"Group resource profile {group} not found"
+    return str(grp_id)
   
   def get_group_resource_profile(self, group_id: str):
     grp = 
self.api_server_client.get_group_resource_profile(self.airavata_token, 
group_id) # type: ignore
@@ -253,7 +255,8 @@ class AiravataOperator:
     """
     # logic
     grps: list = 
self.api_server_client.get_group_resource_list(self.airavata_token, 
self.default_gateway_id()) # type: ignore
-    grp_id = next((grp.groupResourceProfileId for grp in grps if 
grp.groupResourceProfileName == group))
+    grp_id = next((grp.groupResourceProfileId for grp in grps if 
grp.groupResourceProfileName == group), None)
+    assert grp_id is not None, f"Group resource profile {group} not found"
     deployments = 
self.api_server_client.get_application_deployments_for_app_module_and_group_resource_profile(self.airavata_token,
 app_interface_id, grp_id)
     return deployments
 
@@ -264,13 +267,15 @@ class AiravataOperator:
     """
     gateway_id = str(gateway_id or self.default_gateway_id())
     apps: list = 
self.api_server_client.get_all_application_interfaces(self.airavata_token, 
gateway_id) # type: ignore
-    app_id = next((app.applicationInterfaceId for app in apps if 
app.applicationName == app_name))
+    app_id = next((app.applicationInterfaceId for app in apps if 
app.applicationName == app_name), None)
+    assert app_id is not None, f"Application interface {app_name} not found"
     return str(app_id)
 
   def get_project_id(self, project_name: str, gateway_id: str | None = None):
     gateway_id = str(gateway_id or self.default_gateway_id())
     projects: list = 
self.api_server_client.get_user_projects(self.airavata_token, gateway_id, 
self.user_id, 10, 0) # type: ignore
-    project_id = next((p.projectID for p in projects if p.name == project_name 
and p.owner == self.user_id))
+    project_id = next((p.projectID for p in projects if p.name == project_name 
and p.owner == self.user_id), None)
+    assert project_id is not None, f"Project {project_name} not found"
     return str(project_id)
 
   def get_application_inputs(self, app_interface_id: str) -> list:
@@ -596,8 +601,8 @@ class AiravataOperator:
     def register_input_file(file: Path) -> str:
       return str(self.register_input_file(file.name, sr_host, sr_id, 
gateway_id, file.name, abs_path))
     
-    # set up file inputs
-    print("[AV] Setting up file inputs...")
+    # set up experiment inputs
+    print("[AV] Setting up experiment inputs...")
     files_to_upload = list[Path]()
     file_refs = dict[str, str | list[str]]()
     for key, value in file_inputs.items():
@@ -610,11 +615,9 @@ class AiravataOperator:
         file_refs[key] = [*map(register_input_file, value)]
       else:
         raise ValueError("Invalid file input type")
-
-    # configure experiment inputs
     experiment_inputs = []
     for exp_input in 
self.api_server_client.get_application_inputs(self.airavata_token, 
app_interface_id):  # type: ignore
-      assert exp_input.type is not None
+      assert exp_input.type is not None, f"Invalid exp_input type for 
{exp_input.name}: {exp_input.type}"
       if exp_input.type < 3 and exp_input.name in data_inputs:
         value = data_inputs[exp_input.name]
         if exp_input.type == 0:
@@ -623,11 +626,12 @@ class AiravataOperator:
           exp_input.value = repr(value)
       elif exp_input.type == 3 and exp_input.name in file_refs:
         ref = file_refs[exp_input.name]
-        assert isinstance(ref, str)
+        assert isinstance(ref, str), f"Invalid file ref: {ref}"
         exp_input.value = ref
       elif exp_input.type == 4 and exp_input.name in file_refs:
         exp_input.value = ','.join(file_refs[exp_input.name])
       experiment_inputs.append(exp_input)
+      print(f"[AV] * {exp_input.name}={exp_input.value}")
     experiment.experimentInputs = experiment_inputs
 
     # configure experiment outputs
@@ -784,6 +788,7 @@ class AiravataOperator:
   def get_available_runtimes(self):
     from .runtime import Remote
     return [
+      Remote(cluster="login.delta.ncsa.illinois.edu", category="gpu", 
queue_name="gpuA100x4", node_count=1, cpu_count=10, gpu_count=1, walltime=30, 
group="Default"),
       Remote(cluster="login.expanse.sdsc.edu", category="gpu", 
queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, 
group="Default"),
       Remote(cluster="login.expanse.sdsc.edu", category="cpu", 
queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, 
group="Default"),
       Remote(cluster="anvil.rcac.purdue.edu", category="cpu", 
queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, 
group="Default"),
@@ -792,9 +797,10 @@ class AiravataOperator:
   def get_task_status(self, experiment_id: str) -> tuple[str, 
Literal["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", 
"CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] | None]:
     states = ["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", 
"COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"]
     job_details: dict = 
self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # 
type: ignore
+    print(f"[av] Job details: {job_details}")
     job_id = job_state = None
     # get the most recent job id and state
-    for job_id, v in job_details.items():
+    for job_id, v in job_details.items(): # type: ignore
       if v.reason in states:
         job_state = v.reason
       else:
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/base.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/base.py
index e9ad36b68e..be8ec92b4e 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/base.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/base.py
@@ -85,18 +85,25 @@ class Experiment(Generic[T], abc.ABC):
     self.resource = resource
     return self
 
-  def create_task(self, *allowed_runtimes: Runtime, name: str | None = None) 
-> None:
+  def add_run(self, use: list[Runtime] = [], name: str | None = None, 
**kwargs) -> None:
     """
     Create a task to run the experiment on a given runtime.
     """
-    runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 
else self.resource
+    runtime = random.choice(use) if len(use) > 0 else self.resource
     uuid_str = str(uuid.uuid4())[:4].upper()
-
+    # override walltime if one is provided in kwargs
+    runtime = runtime.model_copy()
+    if (w := kwargs.pop("walltime", None)) is not None:
+      runtime.args["walltime"] = w
+    # override experiment inputs with inputs provided in kwargs
+    inputs = self.inputs.copy()
+    inputs.update(kwargs)
+    # create a task with the given runtime and inputs
     self.tasks.append(
         Task(
             name=name or f"{self.name}_{uuid_str}",
             app_id=self.application.app_id,
-            inputs={**self.inputs},
+            inputs=inputs,
             runtime=runtime,
         )
     )
@@ -124,10 +131,14 @@ class Experiment(Generic[T], abc.ABC):
 
   def plan(self, **kwargs) -> Plan:
     if len(self.tasks) == 0:
-      self.create_task(self.resource)
+      self.add_run()
     tasks = []
     for t in self.tasks:
       agg_inputs = {**self.inputs, **t.inputs}
       task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in 
self.input_mapping.items()}
-      tasks.append(Task(name=t.name, app_id=self.application.app_id, 
inputs=task_inputs, runtime=t.runtime))
-    return Plan(tasks=tasks)
+      task = Task(name=t.name, app_id=self.application.app_id, 
inputs=task_inputs, runtime=t.runtime)
+      # task.freeze()  # TODO upload the task-related data and freeze the task
+      tasks.append(task)
+    plan = Plan(tasks=tasks)
+    plan.save()
+    return plan
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
index f231e2583d..8560620616 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
@@ -78,7 +78,7 @@ class Plan(pydantic.BaseModel):
     for task in self.tasks:
       fps.append(task.download_all(local_dir))
     print("Results fetched.")
-    self.save_json(os.path.join(local_dir, "plan.json"))
+    self.export(os.path.join(local_dir, "plan.json"))
     return fps
 
   def launch(self, silent: bool = True) -> None:
@@ -123,7 +123,7 @@ class Plan(pydantic.BaseModel):
     self.__stage_stop__()
     self.save()
 
-  def save_json(self, filename: str) -> None:
+  def export(self, filename: str) -> None:
     with open(filename, "w") as f:
       json.dump(self.model_dump(), f, indent=2)
 
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
index e843135843..00f039f459 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
@@ -239,7 +239,7 @@ class Remote(Runtime):
     return content
 
 
-def list_runtimes(
+def find_runtimes(
     cluster: str | None = None,
     category: str | None = None,
     group: str | None = None,
diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py 
b/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py
index 546c3bcd67..0a559918e8 100644
--- a/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py
+++ b/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py
@@ -164,7 +164,7 @@ class Settings:
 
     @property
     def STORAGE_RESOURCE_HOST(self):
-        return str(os.getenv("STORAGE_RESOURCE_HOST", "cybershuttle.org"))
+        return str(os.getenv("STORAGE_RESOURCE_HOST", 
"gateway.cybershuttle.org"))
 
     @property
     def SFTP_PORT(self):

Reply via email to