This is an automated email from the ASF dual-hosted git repository. yasith pushed a commit to branch sdk-batch-jobs in repository https://gitbox.apache.org/repos/asf/airavata.git
commit a451521f84458f6573a50c478d80181a6e9d8ccc Author: yasithdev <[email protected]> AuthorDate: Thu Jul 31 04:16:33 2025 -0500 improve sdk functions --- .../airavata_experiments/__init__.py | 4 +- .../airavata_experiments/airavata.py | 46 ++++++++++++---------- .../airavata_experiments/base.py | 25 ++++++++---- .../airavata_experiments/plan.py | 4 +- .../airavata_experiments/runtime.py | 2 +- .../airavata-python-sdk/airavata_sdk/__init__.py | 2 +- 6 files changed, 50 insertions(+), 33 deletions(-) diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py b/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py index dd391c07c4..f6f50fef20 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/__init__.py @@ -18,7 +18,7 @@ from __future__ import annotations from . import base, plan from airavata_auth.device_auth import AuthContext -from .runtime import list_runtimes, Runtime +from .runtime import find_runtimes, Runtime from typing import Any @@ -27,7 +27,7 @@ context = AuthContext() def login(): context.login() -__all__ = ["list_runtimes", "base", "plan", "login"] +__all__ = ["find_runtimes", "base", "plan", "login"] def display_runtimes(runtimes: list[Runtime]) -> None: diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py index e6d7b10385..05e91ab494 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py @@ -100,8 +100,10 @@ class AiravataOperator: ) def get_resource_host_id(self, resource_name): - resources: dict = self.api_server_client.get_all_compute_resource_names(self.airavata_token) # type: ignore - return next((str(k) for k, v in resources.items() if v == resource_name)) + resources: dict = self.api_server_client.get_all_compute_resource_names(self.airavata_token) # type: ignore + resource_id = next((str(k) for k, v in resources.items() if v == resource_name), None) + assert resource_id is not None, f"Compute resource {resource_name} not found" + return resource_id def configure_computation_resource_scheduling( self, @@ -188,7 +190,7 @@ class AiravataOperator: """ tree = self.api_server_client.get_detailed_experiment_tree(self.airavata_token, experiment_id) # type: ignore processModels = tree.processes - assert processModels is not None + assert processModels is not None, f"No process models found for experiment {experiment_id}" assert len(processModels) == 1, f"Expected 1 process model, got {len(processModels)}" return processModels[0].processId @@ -213,7 +215,8 @@ class AiravataOperator: sr_hostname = sr_hostname or self.default_sr_hostname() # logic sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore - sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname)) + sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname), None) + assert sr_id is not None, f"Storage resource {sr_hostname} not found" return self.api_server_client.get_gateway_storage_preference(self.airavata_token, gateway_id, sr_id) def get_storage(self, storage_name: str | None = None) -> any: # type: ignore @@ -225,7 +228,8 @@ class AiravataOperator: storage_name = storage_name or self.default_sr_hostname() # logic sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore - sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name)) + sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name), None) + assert sr_id is not None, f"Storage resource {storage_name} not found" storage = self.api_server_client.get_storage_resource(self.airavata_token, sr_id) return storage @@ -236,11 +240,9 @@ class AiravataOperator: """ # logic grps: list[GroupResourceProfile] = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore - try: - grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group)) - return str(grp_id) - except StopIteration: - raise Exception(f"Group resource profile {group} not found") + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group), None) + assert grp_id is not None, f"Group resource profile {group} not found" + return str(grp_id) def get_group_resource_profile(self, group_id: str): grp = self.api_server_client.get_group_resource_profile(self.airavata_token, group_id) # type: ignore @@ -253,7 +255,8 @@ class AiravataOperator: """ # logic grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore - grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group)) + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group), None) + assert grp_id is not None, f"Group resource profile {group} not found" deployments = self.api_server_client.get_application_deployments_for_app_module_and_group_resource_profile(self.airavata_token, app_interface_id, grp_id) return deployments @@ -264,13 +267,15 @@ class AiravataOperator: """ gateway_id = str(gateway_id or self.default_gateway_id()) apps: list = self.api_server_client.get_all_application_interfaces(self.airavata_token, gateway_id) # type: ignore - app_id = next((app.applicationInterfaceId for app in apps if app.applicationName == app_name)) + app_id = next((app.applicationInterfaceId for app in apps if app.applicationName == app_name), None) + assert app_id is not None, f"Application interface {app_name} not found" return str(app_id) def get_project_id(self, project_name: str, gateway_id: str | None = None): gateway_id = str(gateway_id or self.default_gateway_id()) projects: list = self.api_server_client.get_user_projects(self.airavata_token, gateway_id, self.user_id, 10, 0) # type: ignore - project_id = next((p.projectID for p in projects if p.name == project_name and p.owner == self.user_id)) + project_id = next((p.projectID for p in projects if p.name == project_name and p.owner == self.user_id), None) + assert project_id is not None, f"Project {project_name} not found" return str(project_id) def get_application_inputs(self, app_interface_id: str) -> list: @@ -596,8 +601,8 @@ class AiravataOperator: def register_input_file(file: Path) -> str: return str(self.register_input_file(file.name, sr_host, sr_id, gateway_id, file.name, abs_path)) - # set up file inputs - print("[AV] Setting up file inputs...") + # set up experiment inputs + print("[AV] Setting up experiment inputs...") files_to_upload = list[Path]() file_refs = dict[str, str | list[str]]() for key, value in file_inputs.items(): @@ -610,11 +615,9 @@ class AiravataOperator: file_refs[key] = [*map(register_input_file, value)] else: raise ValueError("Invalid file input type") - - # configure experiment inputs experiment_inputs = [] for exp_input in self.api_server_client.get_application_inputs(self.airavata_token, app_interface_id): # type: ignore - assert exp_input.type is not None + assert exp_input.type is not None, f"Invalid exp_input type for {exp_input.name}: {exp_input.type}" if exp_input.type < 3 and exp_input.name in data_inputs: value = data_inputs[exp_input.name] if exp_input.type == 0: @@ -623,11 +626,12 @@ class AiravataOperator: exp_input.value = repr(value) elif exp_input.type == 3 and exp_input.name in file_refs: ref = file_refs[exp_input.name] - assert isinstance(ref, str) + assert isinstance(ref, str), f"Invalid file ref: {ref}" exp_input.value = ref elif exp_input.type == 4 and exp_input.name in file_refs: exp_input.value = ','.join(file_refs[exp_input.name]) experiment_inputs.append(exp_input) + print(f"[AV] * {exp_input.name}={exp_input.value}") experiment.experimentInputs = experiment_inputs # configure experiment outputs @@ -784,6 +788,7 @@ class AiravataOperator: def get_available_runtimes(self): from .runtime import Remote return [ + Remote(cluster="login.delta.ncsa.illinois.edu", category="gpu", queue_name="gpuA100x4", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="Default"), Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="Default"), Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, group="Default"), Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="Default"), @@ -792,9 +797,10 @@ class AiravataOperator: def get_task_status(self, experiment_id: str) -> tuple[str, Literal["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] | None]: states = ["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] job_details: dict = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore + print(f"[av] Job details: {job_details}") job_id = job_state = None # get the most recent job id and state - for job_id, v in job_details.items(): + for job_id, v in job_details.items(): # type: ignore if v.reason in states: job_state = v.reason else: diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/base.py b/dev-tools/airavata-python-sdk/airavata_experiments/base.py index e9ad36b68e..be8ec92b4e 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/base.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/base.py @@ -85,18 +85,25 @@ class Experiment(Generic[T], abc.ABC): self.resource = resource return self - def create_task(self, *allowed_runtimes: Runtime, name: str | None = None) -> None: + def add_run(self, use: list[Runtime] = [], name: str | None = None, **kwargs) -> None: """ Create a task to run the experiment on a given runtime. """ - runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource + runtime = random.choice(use) if len(use) > 0 else self.resource uuid_str = str(uuid.uuid4())[:4].upper() - + # override walltime if one is provided in kwargs + runtime = runtime.model_copy() + if (w := kwargs.pop("walltime", None)) is not None: + runtime.args["walltime"] = w + # override experiment inputs with inputs provided in kwargs + inputs = self.inputs.copy() + inputs.update(kwargs) + # create a task with the given runtime and inputs self.tasks.append( Task( name=name or f"{self.name}_{uuid_str}", app_id=self.application.app_id, - inputs={**self.inputs}, + inputs=inputs, runtime=runtime, ) ) @@ -124,10 +131,14 @@ class Experiment(Generic[T], abc.ABC): def plan(self, **kwargs) -> Plan: if len(self.tasks) == 0: - self.create_task(self.resource) + self.add_run() tasks = [] for t in self.tasks: agg_inputs = {**self.inputs, **t.inputs} task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()} - tasks.append(Task(name=t.name, app_id=self.application.app_id, inputs=task_inputs, runtime=t.runtime)) - return Plan(tasks=tasks) + task = Task(name=t.name, app_id=self.application.app_id, inputs=task_inputs, runtime=t.runtime) + # task.freeze() # TODO upload the task-related data and freeze the task + tasks.append(task) + plan = Plan(tasks=tasks) + plan.save() + return plan diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py index f231e2583d..8560620616 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py @@ -78,7 +78,7 @@ class Plan(pydantic.BaseModel): for task in self.tasks: fps.append(task.download_all(local_dir)) print("Results fetched.") - self.save_json(os.path.join(local_dir, "plan.json")) + self.export(os.path.join(local_dir, "plan.json")) return fps def launch(self, silent: bool = True) -> None: @@ -123,7 +123,7 @@ class Plan(pydantic.BaseModel): self.__stage_stop__() self.save() - def save_json(self, filename: str) -> None: + def export(self, filename: str) -> None: with open(filename, "w") as f: json.dump(self.model_dump(), f, indent=2) diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py index e843135843..00f039f459 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py @@ -239,7 +239,7 @@ class Remote(Runtime): return content -def list_runtimes( +def find_runtimes( cluster: str | None = None, category: str | None = None, group: str | None = None, diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py b/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py index 546c3bcd67..0a559918e8 100644 --- a/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py +++ b/dev-tools/airavata-python-sdk/airavata_sdk/__init__.py @@ -164,7 +164,7 @@ class Settings: @property def STORAGE_RESOURCE_HOST(self): - return str(os.getenv("STORAGE_RESOURCE_HOST", "cybershuttle.org")) + return str(os.getenv("STORAGE_RESOURCE_HOST", "gateway.cybershuttle.org")) @property def SFTP_PORT(self):
