This is an automated email from the ASF dual-hosted git repository.

yasith pushed a commit to branch sdk-batch-jobs
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit 3b84ca6358a2b0f12aba0f6777ec3bed8853749f
Author: yasithdev <[email protected]>
AuthorDate: Fri Aug 1 06:32:07 2025 -0500

    list actual runtimes from API. reorganize auth code. require cpus, nodes, 
and walltime per run
---
 .../airavata_auth/device_auth.py                   |  9 +++
 .../airavata_experiments/airavata.py               | 48 ++++++++---
 .../airavata_experiments/base.py                   | 38 +++++----
 .../airavata_experiments/plan.py                   |  7 +-
 .../airavata_experiments/runtime.py                | 59 +++++++++-----
 .../airavata_jupyter_magic/__init__.py             | 94 +++++++++++++---------
 6 files changed, 163 insertions(+), 92 deletions(-)

diff --git a/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py 
b/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py
index 3037ab1702..944055ad26 100644
--- a/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py
+++ b/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py
@@ -10,6 +10,13 @@ from airavata_sdk import Settings
 
 
 class AuthContext:
+    
+    @staticmethod
+    def get_access_token():
+        if os.environ.get("CS_ACCESS_TOKEN", None) is None:
+            context = AuthContext()
+            context.login()
+        return os.environ["CS_ACCESS_TOKEN"]
 
     def __init__(self):
         self.settings = Settings()
@@ -21,6 +28,8 @@ class AuthContext:
         self.console = Console()
 
     def login(self):
+        if os.environ.get('CS_ACCESS_TOKEN', None) is not None:
+            return
         # Step 1: Request device and user code
         auth_device_url = 
f"{self.settings.AUTH_SERVER_URL}/realms/{self.settings.AUTH_REALM}/protocol/openid-connect/auth/device"
         response = requests.post(auth_device_url, data={
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
index 362dbdcf55..fd83596c08 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
@@ -35,7 +35,8 @@ from airavata.model.security.ttypes import AuthzToken
 from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, 
UserConfigurationDataModel
 from airavata.model.scheduling.ttypes import 
ComputationalResourceSchedulingModel
 from airavata.model.data.replica.ttypes import DataProductModel, 
DataProductType, DataReplicaLocationModel, ReplicaLocationCategory
-from airavata.model.appcatalog.groupresourceprofile.ttypes import 
GroupResourceProfile
+from airavata.model.appcatalog.groupresourceprofile.ttypes import 
GroupResourceProfile, ResourceType
+from airavata.model.appcatalog.computeresource.ttypes import 
ComputeResourceDescription
 from airavata.model.status.ttypes import JobStatus, JobState, 
ExperimentStatus, ExperimentState
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -101,8 +102,8 @@ class AiravataOperator:
     )
   
   def get_resource_host_id(self, resource_name):
-    resources: dict = 
self.api_server_client.get_all_compute_resource_names(self.airavata_token) # 
type: ignore
-    resource_id = next((str(k) for k, v in resources.items() if v == 
resource_name), None)
+    resources = 
self.api_server_client.get_all_compute_resource_names(self.airavata_token)
+    resource_id = next((k for k in resources if k.startswith(resource_name)), 
None)
     assert resource_id is not None, f"Compute resource {resource_name} not 
found"
     return resource_id
   
@@ -814,14 +815,41 @@ class AiravataOperator:
       print(f"[av] Remote execution failed! {e}")
       return None
     
-  def get_available_runtimes(self):
+  def get_available_groups(self, gateway_id: str = "default"):
+    grps: list[GroupResourceProfile] = 
self.api_server_client.get_group_resource_list(self.airavata_token, 
gatewayId=gateway_id)
+    return grps
+  
+  def get_available_runtimes(self, group: str, gateway_id: str = "default"):
+    grps = self.get_available_groups(gateway_id)
+    grp_id, gcr_prefs, gcr_policies = next(((x.groupResourceProfileId, 
x.computePreferences, x.computeResourcePolicies) for x in grps if 
str(x.groupResourceProfileName).strip() == group.strip()), (None, None, None))
+    assert grp_id is not None, f"Group {group} was not found"
+    assert gcr_prefs is not None, f"Compute preferences for group={grp_id} 
were not found"
+    assert gcr_policies is not None, f"Compute policies for group={grp_id} 
were not found" # type: ignore
     from .runtime import Remote
-    return [
-      Remote(cluster="login.delta.ncsa.illinois.edu", category="gpu", 
queue_name="gpuA100x4", node_count=1, cpu_count=10, gpu_count=1, walltime=30, 
group="Default"),
-      Remote(cluster="login.expanse.sdsc.edu", category="gpu", 
queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, 
group="Default"),
-      Remote(cluster="login.expanse.sdsc.edu", category="cpu", 
queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, 
group="Default"),
-      Remote(cluster="anvil.rcac.purdue.edu", category="cpu", 
queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, 
group="Default"),
-    ]
+    runtimes = []
+    for pref in gcr_prefs:
+      cr = self.api_server_client.get_compute_resource(self.airavata_token, 
pref.computeResourceId)
+      assert cr is not None, "Compute resource not found"
+      assert isinstance(cr, ComputeResourceDescription), "Compute resource is 
not a ComputeResourceDescription"
+      assert cr.batchQueues is not None, "Compute resource has no batch queues"
+      for queue in cr.batchQueues:
+        if pref.resourceType == ResourceType.SLURM:
+          policy = next((p for p in gcr_policies if p.computeResourceId == 
pref.computeResourceId), None)
+          assert policy is not None, f"Compute resource policy not found for 
{pref.computeResourceId}"
+          if queue.queueName not in (policy.allowedBatchQueues or []):
+            continue
+        runtime = Remote(
+          cluster=pref.computeResourceId.split("_")[0],
+          category="GPU" if "gpu" in queue.queueName.lower() else "CPU",
+          queue_name=queue.queueName,
+          node_count=queue.maxNodes or 1,
+          cpu_count=queue.cpuPerNode or 1,
+          gpu_count=1 if "gpu" in queue.queueName.lower() else 0,
+          walltime=queue.maxRunTime or 30,
+          group=group,
+        )
+        runtimes.append(runtime)
+    return runtimes
   
   def get_task_status(self, experiment_id: str) -> tuple[str, JobState]:
     job_details: dict[str, JobStatus] = 
self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # 
type: ignore
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/base.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/base.py
index be8ec92b4e..2ba99a3bcb 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/base.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/base.py
@@ -85,53 +85,57 @@ class Experiment(Generic[T], abc.ABC):
     self.resource = resource
     return self
 
-  def add_run(self, use: list[Runtime] = [], name: str | None = None, 
**kwargs) -> None:
+  def add_run(self, use: list[Runtime], cpus: int, nodes: int, walltime: int, 
name: str | None = None, **extra_params) -> None:
     """
     Create a task to run the experiment on a given runtime.
     """
     runtime = random.choice(use) if len(use) > 0 else self.resource
     uuid_str = str(uuid.uuid4())[:4].upper()
-    # override walltime if one is provided in kwargs
+    # override runtime args with given values
     runtime = runtime.model_copy()
-    if (w := kwargs.pop("walltime", None)) is not None:
-      runtime.args["walltime"] = w
-    # override experiment inputs with inputs provided in kwargs
-    inputs = self.inputs.copy()
-    inputs.update(kwargs)
+    runtime.args["cpu_count"] = cpus
+    runtime.args["node_count"] = nodes
+    runtime.args["walltime"] = walltime
+    # add extra inputs to task inputs
+    task_inputs = {**self.inputs, **extra_params}
     # create a task with the given runtime and inputs
     self.tasks.append(
         Task(
-            name=name or f"{self.name}_{uuid_str}",
+            name=f"{name or self.name}_{uuid_str}",
             app_id=self.application.app_id,
-            inputs=inputs,
+            inputs=task_inputs,
             runtime=runtime,
         )
     )
     print(f"Task created. ({len(self.tasks)} tasks in total)")
 
-  def add_sweep(self, *allowed_runtimes: Runtime, **space: list) -> None:
+  def add_sweep(self, use: list[Runtime], cpus: int, nodes: int, walltime: 
int, name: str | None = None, **space: list) -> None:
     """
     Add a sweep to the experiment.
 
     """
     for values in product(space.values()):
-      runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 
else self.resource
+      runtime = random.choice(use) if len(use) > 0 else self.resource
       uuid_str = str(uuid.uuid4())[:4].upper()
-
+      # override runtime args with given values
+      runtime = runtime.model_copy()
+      runtime.args["cpu_count"] = cpus
+      runtime.args["node_count"] = nodes
+      runtime.args["walltime"] = walltime
+      # add sweep params to task inputs
       task_specific_params = dict(zip(space.keys(), values))
       agg_inputs = {**self.inputs, **task_specific_params}
       task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in 
self.input_mapping.items()}
-
+      # create a task with the given runtime and inputs
       self.tasks.append(Task(
-          name=f"{self.name}_{uuid_str}",
+          name=f"{name or self.name}_{uuid_str}",
           app_id=self.application.app_id,
           inputs=task_inputs,
           runtime=runtime or self.resource,
       ))
 
-  def plan(self, **kwargs) -> Plan:
-    if len(self.tasks) == 0:
-      self.add_run()
+  def plan(self) -> Plan:
+    assert len(self.tasks) > 0, "add_run() must be called before plan() to 
define runtimes and resources."
     tasks = []
     for t in self.tasks:
       agg_inputs = {**self.inputs, **t.inputs}
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
index 7f35e3c6d6..4f41135610 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py
@@ -25,6 +25,7 @@ from rich.progress import Progress
 from .runtime import is_terminal_state
 from .task import Task
 import uuid
+from airavata_auth.device_auth import AuthContext
 
 from .airavata import AiravataOperator
 
@@ -131,7 +132,7 @@ class Plan(pydantic.BaseModel):
 
   def save(self) -> None:
     settings = Settings()
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     az = av.__airavata_token__(av.access_token, av.default_gateway_id())
     assert az.accessToken is not None
     assert az.claimsMap is not None
@@ -164,7 +165,7 @@ def load_json(filename: str) -> Plan:
 def load(id: str | None) -> Plan:
     settings = Settings()
     assert id is not None
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     az = av.__airavata_token__(av.access_token, av.default_gateway_id())
     assert az.accessToken is not None
     assert az.claimsMap is not None
@@ -185,7 +186,7 @@ def load(id: str | None) -> Plan:
     
 def query() -> list[Plan]:
     settings = Settings()
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     az = av.__airavata_token__(av.access_token, av.default_gateway_id())
     assert az.accessToken is not None
     assert az.claimsMap is not None
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py 
b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
index 054f94ec82..f5c40e3c86 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py
@@ -18,10 +18,11 @@ from __future__ import annotations
 import abc
 from typing import Any, Literal
 from pathlib import Path
-import os
 
 import pydantic
 
+from airavata_auth.device_auth import AuthContext
+
 # from .task import Task
 Task = Any
 States = Literal[
@@ -163,7 +164,7 @@ class Remote(Runtime):
     print(f"[Remote] Creating Experiment: name={task.name}")
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     try:
       launch_state = av.launch_experiment(
           experiment_name=task.name,
@@ -195,7 +196,7 @@ class Remote(Runtime):
     assert task.workdir is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     try:
       result = av.execute_cmd(task.agent_ref, cmd)
       return result
@@ -210,7 +211,7 @@ class Remote(Runtime):
     assert task.pid is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     result = av.execute_py(task.project, libraries, code, task.agent_ref, 
task.pid, task.runtime.args)
     print(result)
 
@@ -219,7 +220,7 @@ class Remote(Runtime):
     assert task.agent_ref is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     # prioritize job state, fallback to experiment state
     job_id, job_state = av.get_task_status(task.ref)
     if job_state in [AiravataOperator.JobState.UNKNOWN, 
AiravataOperator.JobState.NON_CRITICAL_FAIL]:
@@ -232,7 +233,7 @@ class Remote(Runtime):
     assert task.agent_ref is not None
     
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     av.stop_experiment(task.ref)
 
   def ls(self, task: Task) -> list[str]:
@@ -243,7 +244,7 @@ class Remote(Runtime):
     assert task.workdir is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     files = av.list_files(task.pid, task.agent_ref, task.sr_host, task.workdir)
     return files
 
@@ -255,7 +256,7 @@ class Remote(Runtime):
     assert task.workdir is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     result = av.upload_files(task.pid, task.agent_ref, task.sr_host, [file], 
task.workdir).pop()
     return result
 
@@ -267,7 +268,7 @@ class Remote(Runtime):
     assert task.workdir is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     result = av.download_file(task.pid, task.agent_ref, task.sr_host, file, 
task.workdir, local_dir)
     return result
 
@@ -279,7 +280,7 @@ class Remote(Runtime):
     assert task.workdir is not None
 
     from .airavata import AiravataOperator
-    av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
+    av = AiravataOperator(AuthContext.get_access_token())
     content = av.cat_file(task.pid, task.agent_ref, task.sr_host, file, 
task.workdir)
     return content
 
@@ -287,22 +288,36 @@ class Remote(Runtime):
 def find_runtimes(
     cluster: str | None = None,
     category: str | None = None,
-    group: str | None = None,
     node_count: int | None = None,
     cpu_count: int | None = None,
-    walltime: int | None = None,
+    group: str | None = None,
 ) -> list[Runtime]:
   from .airavata import AiravataOperator
-  av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
-  all_runtimes = av.get_available_runtimes()
-  out_runtimes = []
-  for r in all_runtimes:
-    if (cluster in [None, r.args["cluster"]]) and (category in [None, 
r.args["category"]]) and (group in [None, r.args["group"]]):
-      r.args["node_count"] = node_count or r.args["node_count"]
-      r.args["cpu_count"] = cpu_count or r.args["cpu_count"]
-      r.args["walltime"] = walltime or r.args["walltime"]
-      out_runtimes.append(r)
-  return out_runtimes
+  av = AiravataOperator(AuthContext.get_access_token())
+  grps = av.get_available_groups()
+  grp_names = [str(x.groupResourceProfileName) for x in grps]
+  if group is not None:
+    assert group in grp_names, f"Group {group} was not found. Available 
groups: {repr(grp_names)}"
+    groups = [g for g in grps if str(g.groupResourceProfileName) == group]
+  else:
+    groups = grps
+  runtimes = []
+  for g in groups:
+    matched_runtimes = []
+    assert g.groupResourceProfileName is not None, f"Group {g} has no name"
+    r: Runtime
+    for r in av.get_available_runtimes(group=g.groupResourceProfileName):
+      if (node_count or 1) > int(r.args["node_count"]):
+        continue
+      if (cpu_count or 1) > int(r.args["cpu_count"]):
+        continue
+      if (cluster or r.args["cluster"]) != r.args["cluster"]:
+        continue
+      if (category or r.args["category"]) != r.args["category"]:
+        continue
+      matched_runtimes.append(r)
+    runtimes.extend(matched_runtimes)
+  return runtimes
 
 def is_terminal_state(x: States) -> bool:
   return x in ["CANCELED", "COMPLETE", "COMPLETED", "FAILED"]
\ No newline at end of file
diff --git a/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py 
b/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py
index 1528304462..ab2ecd8bf8 100644
--- a/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py
+++ b/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py
@@ -43,6 +43,7 @@ from rich.live import Live
 from jupyter_client.blocking.client import BlockingKernelClient
 
 from airavata_auth.device_auth import AuthContext
+from airavata_experiments.plan import Plan
 from airavata_sdk import Settings
 
 # ========================================================================
@@ -62,6 +63,7 @@ class RequestedRuntime:
     group: str
     file: str | None
     use: str | None
+    plan: str | None
 
 
 class ProcessState(IntEnum):
@@ -139,23 +141,6 @@ class State:
 # HELPER FUNCTIONS
 
 
-def get_access_token(envar_name: str = "CS_ACCESS_TOKEN", state_path: str = 
"/tmp/av.json") -> str | None:
-    """
-    Get access token from environment or file
-
-    @param None:
-    @returns: access token if present, None otherwise
-
-    """
-    token = os.getenv(envar_name)
-    if not token:
-        try:
-            token = json.load(Path(state_path).open("r")).get("access_token")
-        except (FileNotFoundError, json.JSONDecodeError):
-            pass
-    return token
-
-
 def is_runtime_ready(access_token: str, rt: RuntimeInfo, rt_name: str):
     """
     Check if the runtime (i.e., agent job) is ready to receive requests
@@ -470,7 +455,8 @@ def submit_agent_job(
     memory: int | None = None,
     gpus: int | None = None,
     gpu_memory: int | None = None,
-    file: str | None = None,
+    spec_file: str | None = None,
+    plan_file: str | None = None,
 ) -> None:
     """
     Submit an agent job to the given runtime
@@ -487,7 +473,8 @@ def submit_agent_job(
     @param memory: the memory for cpu (MB)
     @param gpus: the number of gpus (int)
     @param gpu_memory: the memory for gpu (MB)
-    @param file: environment file (path)
+    @param spec_file: environment file (path)
+    @param plan_file: experiment plan file (path)
     @returns: None
 
     """
@@ -506,14 +493,14 @@ def submit_agent_job(
     pip: list[str] = []
 
     # if file is provided, validate it and use the given values as defaults
-    if file is not None:
-        fp = Path(file)
+    if spec_file is not None:
+        fp = Path(spec_file)
         # validation
-        assert fp.exists(), f"File {file} does not exist"
+        assert fp.exists(), f"File {spec_file} does not exist"
         with open(fp, "r") as f:
-            content = yaml.safe_load(f)
+            spec = yaml.safe_load(f)
         # validation: /workspace
-        assert (workspace := content.get("workspace", None)) is not None, 
"missing section: /workspace"
+        assert (workspace := spec.get("workspace", None)) is not None, 
"missing section: /workspace"
         assert (resources := workspace.get("resources", None)) is not None, 
"missing section: /workspace/resources"
         assert (min_cpu := resources.get("min_cpu", None)) is not None, 
"missing section: /workspace/resources/min_cpu"
         assert (min_mem := resources.get("min_mem", None)) is not None, 
"missing section: /workspace/resources/min_mem"
@@ -523,12 +510,18 @@ def submit_agent_job(
         assert (datasets := workspace.get("data_collection", None)) is not 
None, "missing section: /workspace/data_collection"
         collection = models + datasets
         # validation: /additional_dependencies
-        assert (additional_dependencies := 
content.get("additional_dependencies", None)) is not None, "missing section: 
/additional_dependencies"
+        assert (additional_dependencies := spec.get("additional_dependencies", 
None)) is not None, "missing section: /additional_dependencies"
         assert (modules := additional_dependencies.get("modules", None)) is 
not None, "missing /additional_dependencies/modules section"
         assert (conda := additional_dependencies.get("conda", None)) is not 
None, "missing /additional_dependencies/conda section"
         assert (pip := additional_dependencies.get("pip", None)) is not None, 
"missing /additional_dependencies/pip section"
         mounts = [f"{i['identifier']}:{i['mount_point']}" for i in collection]
 
+    if plan_file is not None:
+        assert Path(plan_file).exists(), f"File {plan_file} does not exist"
+        with open(Path(plan_file), "r") as f:
+            plan = yaml.safe_load(f)
+        assert plan.get("experimentId") is not None, "missing experimentId in 
state file"
+
     # payload
     data = {
         'experimentName': app_name,
@@ -1114,7 +1107,7 @@ def request_runtime(line: str):
     Request a runtime with given capabilities
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     [rt_name, *cmd_args] = line.strip().split()
@@ -1151,12 +1144,32 @@ def request_runtime(line: str):
     p.add_argument("--group", type=str, help="resource group", required=False, 
default="Default")
     p.add_argument("--file", type=str, help="yml file", required=False)
     p.add_argument("--use", type=str, help="allowed resources", required=False)
+    p.add_argument("--plan", type=str, help="experiment plan file", 
required=False)
 
     args = p.parse_args(cmd_args, namespace=RequestedRuntime())
 
     if args.file is not None:
-        assert args.use is not None
-        cluster, queue  = meta_scheduler(args.use.split(","))
+        assert (args.use or args.plan) is not None
+        if args.use:
+          cluster, queue  = meta_scheduler(args.use.split(","))
+        else:
+          assert args.plan is not None, "--plan is required when --use is not 
provided"
+          assert os.path.exists(args.plan), f"--plan={args.plan} file does not 
exist"
+          assert os.path.isfile(args.plan), f"--plan={args.plan} is not a file"
+          with open(args.plan, "r") as f:
+            plan: Plan = Plan(**json.load(f))
+            clusters = []
+            queues = []
+            for task in plan.tasks:
+              c, q = task.runtime.args.get("cluster"), 
task.runtime.args.get("queue_name")
+              clusters.append(c)
+              queues.append(q)
+            assert len(set(clusters)) == 1, "all tasks must be on the same 
cluster"
+            assert len(set(queues)) == 1, "all tasks must be on the same queue"
+            cluster, queue = clusters[0], queues[0]
+            assert cluster is not None, "cluster is required"
+            assert queue is not None, "queue is required"
+
         submit_agent_job(
             rt_name=rt_name,
             access_token=access_token,
@@ -1166,7 +1179,8 @@ def request_runtime(line: str):
             cluster=cluster,
             queue=queue,
             group=args.group,
-            file=args.file,
+            spec_file=args.file,
+            plan_file=args.plan,
         )
     else:
         assert args.cluster is not None
@@ -1194,7 +1208,7 @@ def stat_runtime(line: str):
     Show the status of the runtime
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt_name = line.strip()
@@ -1226,7 +1240,7 @@ def wait_for_runtime(line: str):
         rt_name, render_live_logs = parts[0], True
     else:
         raise ValueError("Usage: %wait_for_runtime <rt> [--live]")
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt = state.all_runtimes.get(rt_name, None)
@@ -1247,7 +1261,7 @@ def run_subprocess(line: str):
     Run a subprocess asynchronously
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt_name = state.current_runtime
@@ -1279,7 +1293,7 @@ def kill_subprocess(line: str):
     Kill a running subprocess asynchronously
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt_name = state.current_runtime
@@ -1308,7 +1322,7 @@ def open_tunnels(line: str):
     Open tunnels to the runtime
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt_name = state.current_runtime
@@ -1348,7 +1362,7 @@ def close_tunnels(line: str):
     Close tunnels to the runtime
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt_name = state.current_runtime
@@ -1374,7 +1388,7 @@ def restart_runtime(rt_name: str):
     Restart the runtime
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt = state.all_runtimes.get(rt_name, None)
@@ -1389,7 +1403,7 @@ def stop_runtime(rt_name: str):
     Stop the runtime
 
     """
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     rt = state.all_runtimes.get(rt_name, None)
@@ -1457,7 +1471,7 @@ def launch_remote_kernel(rt_name: str, base_port: int, 
hostname: str):
     Launch a remote Jupyter kernel, open tunnels, and connect a local Jupyter 
client.
     """
     assert ipython is not None
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     assert access_token is not None
 
     # launch kernel and tunnel ports
@@ -1535,7 +1549,7 @@ def open_web_terminal(line: str):
     cmd = f"ttyd -p {random_port} -i 0.0.0.0 --writable bash"
 
     # Get access token
-    access_token = get_access_token()
+    access_token = AuthContext.get_access_token()
     if access_token is None:
         print("Not authenticated. Please run %authenticate first.")
         return
@@ -1667,7 +1681,7 @@ async def run_cell_async(
         return await orig_run_code(raw_cell, store_history, silent, 
shell_futures, transformed_cell=transformed_cell, 
preprocessing_exc_tuple=preprocessing_exc_tuple, cell_id=cell_id)
     else:
         # Validation: check runtime is ready and kernel is started
-        access_token = get_access_token()
+        access_token = AuthContext.get_access_token()
         assert access_token is not None
         rt_info = state.all_runtimes.get(rt, None)
         if rt_info is None:

Reply via email to