This is an automated email from the ASF dual-hosted git repository. yasith pushed a commit to branch sdk-batch-jobs in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 3b84ca6358a2b0f12aba0f6777ec3bed8853749f Author: yasithdev <[email protected]> AuthorDate: Fri Aug 1 06:32:07 2025 -0500 list actual runtimes from API. reorganize auth code. require cpus, nodes, and walltime per run --- .../airavata_auth/device_auth.py | 9 +++ .../airavata_experiments/airavata.py | 48 ++++++++--- .../airavata_experiments/base.py | 38 +++++---- .../airavata_experiments/plan.py | 7 +- .../airavata_experiments/runtime.py | 59 +++++++++----- .../airavata_jupyter_magic/__init__.py | 94 +++++++++++++--------- 6 files changed, 163 insertions(+), 92 deletions(-) diff --git a/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py b/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py index 3037ab1702..944055ad26 100644 --- a/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py +++ b/dev-tools/airavata-python-sdk/airavata_auth/device_auth.py @@ -10,6 +10,13 @@ from airavata_sdk import Settings class AuthContext: + + @staticmethod + def get_access_token(): + if os.environ.get("CS_ACCESS_TOKEN", None) is None: + context = AuthContext() + context.login() + return os.environ["CS_ACCESS_TOKEN"] def __init__(self): self.settings = Settings() @@ -21,6 +28,8 @@ class AuthContext: self.console = Console() def login(self): + if os.environ.get('CS_ACCESS_TOKEN', None) is not None: + return # Step 1: Request device and user code auth_device_url = f"{self.settings.AUTH_SERVER_URL}/realms/{self.settings.AUTH_REALM}/protocol/openid-connect/auth/device" response = requests.post(auth_device_url, data={ diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py index 362dbdcf55..fd83596c08 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py @@ -35,7 +35,8 @@ from airavata.model.security.ttypes import AuthzToken from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, UserConfigurationDataModel from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingModel from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory -from airavata.model.appcatalog.groupresourceprofile.ttypes import GroupResourceProfile +from airavata.model.appcatalog.groupresourceprofile.ttypes import GroupResourceProfile, ResourceType +from airavata.model.appcatalog.computeresource.ttypes import ComputeResourceDescription from airavata.model.status.ttypes import JobStatus, JobState, ExperimentStatus, ExperimentState warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -101,8 +102,8 @@ class AiravataOperator: ) def get_resource_host_id(self, resource_name): - resources: dict = self.api_server_client.get_all_compute_resource_names(self.airavata_token) # type: ignore - resource_id = next((str(k) for k, v in resources.items() if v == resource_name), None) + resources = self.api_server_client.get_all_compute_resource_names(self.airavata_token) + resource_id = next((k for k in resources if k.startswith(resource_name)), None) assert resource_id is not None, f"Compute resource {resource_name} not found" return resource_id @@ -814,14 +815,41 @@ class AiravataOperator: print(f"[av] Remote execution failed! {e}") return None - def get_available_runtimes(self): + def get_available_groups(self, gateway_id: str = "default"): + grps: list[GroupResourceProfile] = self.api_server_client.get_group_resource_list(self.airavata_token, gatewayId=gateway_id) + return grps + + def get_available_runtimes(self, group: str, gateway_id: str = "default"): + grps = self.get_available_groups(gateway_id) + grp_id, gcr_prefs, gcr_policies = next(((x.groupResourceProfileId, x.computePreferences, x.computeResourcePolicies) for x in grps if str(x.groupResourceProfileName).strip() == group.strip()), (None, None, None)) + assert grp_id is not None, f"Group {group} was not found" + assert gcr_prefs is not None, f"Compute preferences for group={grp_id} were not found" + assert gcr_policies is not None, f"Compute policies for group={grp_id} were not found" # type: ignore from .runtime import Remote - return [ - Remote(cluster="login.delta.ncsa.illinois.edu", category="gpu", queue_name="gpuA100x4", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="Default"), - Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="Default"), - Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, group="Default"), - Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="Default"), - ] + runtimes = [] + for pref in gcr_prefs: + cr = self.api_server_client.get_compute_resource(self.airavata_token, pref.computeResourceId) + assert cr is not None, "Compute resource not found" + assert isinstance(cr, ComputeResourceDescription), "Compute resource is not a ComputeResourceDescription" + assert cr.batchQueues is not None, "Compute resource has no batch queues" + for queue in cr.batchQueues: + if pref.resourceType == ResourceType.SLURM: + policy = next((p for p in gcr_policies if p.computeResourceId == pref.computeResourceId), None) + assert policy is not None, f"Compute resource policy not found for {pref.computeResourceId}" + if queue.queueName not in (policy.allowedBatchQueues or []): + continue + runtime = Remote( + cluster=pref.computeResourceId.split("_")[0], + category="GPU" if "gpu" in queue.queueName.lower() else "CPU", + queue_name=queue.queueName, + node_count=queue.maxNodes or 1, + cpu_count=queue.cpuPerNode or 1, + gpu_count=1 if "gpu" in queue.queueName.lower() else 0, + walltime=queue.maxRunTime or 30, + group=group, + ) + runtimes.append(runtime) + return runtimes def get_task_status(self, experiment_id: str) -> tuple[str, JobState]: job_details: dict[str, JobStatus] = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/base.py b/dev-tools/airavata-python-sdk/airavata_experiments/base.py index be8ec92b4e..2ba99a3bcb 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/base.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/base.py @@ -85,53 +85,57 @@ class Experiment(Generic[T], abc.ABC): self.resource = resource return self - def add_run(self, use: list[Runtime] = [], name: str | None = None, **kwargs) -> None: + def add_run(self, use: list[Runtime], cpus: int, nodes: int, walltime: int, name: str | None = None, **extra_params) -> None: """ Create a task to run the experiment on a given runtime. """ runtime = random.choice(use) if len(use) > 0 else self.resource uuid_str = str(uuid.uuid4())[:4].upper() - # override walltime if one is provided in kwargs + # override runtime args with given values runtime = runtime.model_copy() - if (w := kwargs.pop("walltime", None)) is not None: - runtime.args["walltime"] = w - # override experiment inputs with inputs provided in kwargs - inputs = self.inputs.copy() - inputs.update(kwargs) + runtime.args["cpu_count"] = cpus + runtime.args["node_count"] = nodes + runtime.args["walltime"] = walltime + # add extra inputs to task inputs + task_inputs = {**self.inputs, **extra_params} # create a task with the given runtime and inputs self.tasks.append( Task( - name=name or f"{self.name}_{uuid_str}", + name=f"{name or self.name}_{uuid_str}", app_id=self.application.app_id, - inputs=inputs, + inputs=task_inputs, runtime=runtime, ) ) print(f"Task created. ({len(self.tasks)} tasks in total)") - def add_sweep(self, *allowed_runtimes: Runtime, **space: list) -> None: + def add_sweep(self, use: list[Runtime], cpus: int, nodes: int, walltime: int, name: str | None = None, **space: list) -> None: """ Add a sweep to the experiment. """ for values in product(space.values()): - runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource + runtime = random.choice(use) if len(use) > 0 else self.resource uuid_str = str(uuid.uuid4())[:4].upper() - + # override runtime args with given values + runtime = runtime.model_copy() + runtime.args["cpu_count"] = cpus + runtime.args["node_count"] = nodes + runtime.args["walltime"] = walltime + # add sweep params to task inputs task_specific_params = dict(zip(space.keys(), values)) agg_inputs = {**self.inputs, **task_specific_params} task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()} - + # create a task with the given runtime and inputs self.tasks.append(Task( - name=f"{self.name}_{uuid_str}", + name=f"{name or self.name}_{uuid_str}", app_id=self.application.app_id, inputs=task_inputs, runtime=runtime or self.resource, )) - def plan(self, **kwargs) -> Plan: - if len(self.tasks) == 0: - self.add_run() + def plan(self) -> Plan: + assert len(self.tasks) > 0, "add_run() must be called before plan() to define runtimes and resources." tasks = [] for t in self.tasks: agg_inputs = {**self.inputs, **t.inputs} diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py index 7f35e3c6d6..4f41135610 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py @@ -25,6 +25,7 @@ from rich.progress import Progress from .runtime import is_terminal_state from .task import Task import uuid +from airavata_auth.device_auth import AuthContext from .airavata import AiravataOperator @@ -131,7 +132,7 @@ class Plan(pydantic.BaseModel): def save(self) -> None: settings = Settings() - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) az = av.__airavata_token__(av.access_token, av.default_gateway_id()) assert az.accessToken is not None assert az.claimsMap is not None @@ -164,7 +165,7 @@ def load_json(filename: str) -> Plan: def load(id: str | None) -> Plan: settings = Settings() assert id is not None - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) az = av.__airavata_token__(av.access_token, av.default_gateway_id()) assert az.accessToken is not None assert az.claimsMap is not None @@ -185,7 +186,7 @@ def load(id: str | None) -> Plan: def query() -> list[Plan]: settings = Settings() - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) az = av.__airavata_token__(av.access_token, av.default_gateway_id()) assert az.accessToken is not None assert az.claimsMap is not None diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py index 054f94ec82..f5c40e3c86 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py @@ -18,10 +18,11 @@ from __future__ import annotations import abc from typing import Any, Literal from pathlib import Path -import os import pydantic +from airavata_auth.device_auth import AuthContext + # from .task import Task Task = Any States = Literal[ @@ -163,7 +164,7 @@ class Remote(Runtime): print(f"[Remote] Creating Experiment: name={task.name}") from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) try: launch_state = av.launch_experiment( experiment_name=task.name, @@ -195,7 +196,7 @@ class Remote(Runtime): assert task.workdir is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) try: result = av.execute_cmd(task.agent_ref, cmd) return result @@ -210,7 +211,7 @@ class Remote(Runtime): assert task.pid is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args) print(result) @@ -219,7 +220,7 @@ class Remote(Runtime): assert task.agent_ref is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) # prioritize job state, fallback to experiment state job_id, job_state = av.get_task_status(task.ref) if job_state in [AiravataOperator.JobState.UNKNOWN, AiravataOperator.JobState.NON_CRITICAL_FAIL]: @@ -232,7 +233,7 @@ class Remote(Runtime): assert task.agent_ref is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) av.stop_experiment(task.ref) def ls(self, task: Task) -> list[str]: @@ -243,7 +244,7 @@ class Remote(Runtime): assert task.workdir is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) files = av.list_files(task.pid, task.agent_ref, task.sr_host, task.workdir) return files @@ -255,7 +256,7 @@ class Remote(Runtime): assert task.workdir is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) result = av.upload_files(task.pid, task.agent_ref, task.sr_host, [file], task.workdir).pop() return result @@ -267,7 +268,7 @@ class Remote(Runtime): assert task.workdir is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) result = av.download_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir, local_dir) return result @@ -279,7 +280,7 @@ class Remote(Runtime): assert task.workdir is not None from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + av = AiravataOperator(AuthContext.get_access_token()) content = av.cat_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir) return content @@ -287,22 +288,36 @@ class Remote(Runtime): def find_runtimes( cluster: str | None = None, category: str | None = None, - group: str | None = None, node_count: int | None = None, cpu_count: int | None = None, - walltime: int | None = None, + group: str | None = None, ) -> list[Runtime]: from .airavata import AiravataOperator - av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) - all_runtimes = av.get_available_runtimes() - out_runtimes = [] - for r in all_runtimes: - if (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]) and (group in [None, r.args["group"]]): - r.args["node_count"] = node_count or r.args["node_count"] - r.args["cpu_count"] = cpu_count or r.args["cpu_count"] - r.args["walltime"] = walltime or r.args["walltime"] - out_runtimes.append(r) - return out_runtimes + av = AiravataOperator(AuthContext.get_access_token()) + grps = av.get_available_groups() + grp_names = [str(x.groupResourceProfileName) for x in grps] + if group is not None: + assert group in grp_names, f"Group {group} was not found. Available groups: {repr(grp_names)}" + groups = [g for g in grps if str(g.groupResourceProfileName) == group] + else: + groups = grps + runtimes = [] + for g in groups: + matched_runtimes = [] + assert g.groupResourceProfileName is not None, f"Group {g} has no name" + r: Runtime + for r in av.get_available_runtimes(group=g.groupResourceProfileName): + if (node_count or 1) > int(r.args["node_count"]): + continue + if (cpu_count or 1) > int(r.args["cpu_count"]): + continue + if (cluster or r.args["cluster"]) != r.args["cluster"]: + continue + if (category or r.args["category"]) != r.args["category"]: + continue + matched_runtimes.append(r) + runtimes.extend(matched_runtimes) + return runtimes def is_terminal_state(x: States) -> bool: return x in ["CANCELED", "COMPLETE", "COMPLETED", "FAILED"] \ No newline at end of file diff --git a/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py b/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py index 1528304462..ab2ecd8bf8 100644 --- a/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py +++ b/dev-tools/airavata-python-sdk/airavata_jupyter_magic/__init__.py @@ -43,6 +43,7 @@ from rich.live import Live from jupyter_client.blocking.client import BlockingKernelClient from airavata_auth.device_auth import AuthContext +from airavata_experiments.plan import Plan from airavata_sdk import Settings # ======================================================================== @@ -62,6 +63,7 @@ class RequestedRuntime: group: str file: str | None use: str | None + plan: str | None class ProcessState(IntEnum): @@ -139,23 +141,6 @@ class State: # HELPER FUNCTIONS -def get_access_token(envar_name: str = "CS_ACCESS_TOKEN", state_path: str = "/tmp/av.json") -> str | None: - """ - Get access token from environment or file - - @param None: - @returns: access token if present, None otherwise - - """ - token = os.getenv(envar_name) - if not token: - try: - token = json.load(Path(state_path).open("r")).get("access_token") - except (FileNotFoundError, json.JSONDecodeError): - pass - return token - - def is_runtime_ready(access_token: str, rt: RuntimeInfo, rt_name: str): """ Check if the runtime (i.e., agent job) is ready to receive requests @@ -470,7 +455,8 @@ def submit_agent_job( memory: int | None = None, gpus: int | None = None, gpu_memory: int | None = None, - file: str | None = None, + spec_file: str | None = None, + plan_file: str | None = None, ) -> None: """ Submit an agent job to the given runtime @@ -487,7 +473,8 @@ def submit_agent_job( @param memory: the memory for cpu (MB) @param gpus: the number of gpus (int) @param gpu_memory: the memory for gpu (MB) - @param file: environment file (path) + @param spec_file: environment file (path) + @param plan_file: experiment plan file (path) @returns: None """ @@ -506,14 +493,14 @@ def submit_agent_job( pip: list[str] = [] # if file is provided, validate it and use the given values as defaults - if file is not None: - fp = Path(file) + if spec_file is not None: + fp = Path(spec_file) # validation - assert fp.exists(), f"File {file} does not exist" + assert fp.exists(), f"File {spec_file} does not exist" with open(fp, "r") as f: - content = yaml.safe_load(f) + spec = yaml.safe_load(f) # validation: /workspace - assert (workspace := content.get("workspace", None)) is not None, "missing section: /workspace" + assert (workspace := spec.get("workspace", None)) is not None, "missing section: /workspace" assert (resources := workspace.get("resources", None)) is not None, "missing section: /workspace/resources" assert (min_cpu := resources.get("min_cpu", None)) is not None, "missing section: /workspace/resources/min_cpu" assert (min_mem := resources.get("min_mem", None)) is not None, "missing section: /workspace/resources/min_mem" @@ -523,12 +510,18 @@ def submit_agent_job( assert (datasets := workspace.get("data_collection", None)) is not None, "missing section: /workspace/data_collection" collection = models + datasets # validation: /additional_dependencies - assert (additional_dependencies := content.get("additional_dependencies", None)) is not None, "missing section: /additional_dependencies" + assert (additional_dependencies := spec.get("additional_dependencies", None)) is not None, "missing section: /additional_dependencies" assert (modules := additional_dependencies.get("modules", None)) is not None, "missing /additional_dependencies/modules section" assert (conda := additional_dependencies.get("conda", None)) is not None, "missing /additional_dependencies/conda section" assert (pip := additional_dependencies.get("pip", None)) is not None, "missing /additional_dependencies/pip section" mounts = [f"{i['identifier']}:{i['mount_point']}" for i in collection] + if plan_file is not None: + assert Path(plan_file).exists(), f"File {plan_file} does not exist" + with open(Path(plan_file), "r") as f: + plan = yaml.safe_load(f) + assert plan.get("experimentId") is not None, "missing experimentId in state file" + # payload data = { 'experimentName': app_name, @@ -1114,7 +1107,7 @@ def request_runtime(line: str): Request a runtime with given capabilities """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None [rt_name, *cmd_args] = line.strip().split() @@ -1151,12 +1144,32 @@ def request_runtime(line: str): p.add_argument("--group", type=str, help="resource group", required=False, default="Default") p.add_argument("--file", type=str, help="yml file", required=False) p.add_argument("--use", type=str, help="allowed resources", required=False) + p.add_argument("--plan", type=str, help="experiment plan file", required=False) args = p.parse_args(cmd_args, namespace=RequestedRuntime()) if args.file is not None: - assert args.use is not None - cluster, queue = meta_scheduler(args.use.split(",")) + assert (args.use or args.plan) is not None + if args.use: + cluster, queue = meta_scheduler(args.use.split(",")) + else: + assert args.plan is not None, "--plan is required when --use is not provided" + assert os.path.exists(args.plan), f"--plan={args.plan} file does not exist" + assert os.path.isfile(args.plan), f"--plan={args.plan} is not a file" + with open(args.plan, "r") as f: + plan: Plan = Plan(**json.load(f)) + clusters = [] + queues = [] + for task in plan.tasks: + c, q = task.runtime.args.get("cluster"), task.runtime.args.get("queue_name") + clusters.append(c) + queues.append(q) + assert len(set(clusters)) == 1, "all tasks must be on the same cluster" + assert len(set(queues)) == 1, "all tasks must be on the same queue" + cluster, queue = clusters[0], queues[0] + assert cluster is not None, "cluster is required" + assert queue is not None, "queue is required" + submit_agent_job( rt_name=rt_name, access_token=access_token, @@ -1166,7 +1179,8 @@ def request_runtime(line: str): cluster=cluster, queue=queue, group=args.group, - file=args.file, + spec_file=args.file, + plan_file=args.plan, ) else: assert args.cluster is not None @@ -1194,7 +1208,7 @@ def stat_runtime(line: str): Show the status of the runtime """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_name = line.strip() @@ -1226,7 +1240,7 @@ def wait_for_runtime(line: str): rt_name, render_live_logs = parts[0], True else: raise ValueError("Usage: %wait_for_runtime <rt> [--live]") - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt = state.all_runtimes.get(rt_name, None) @@ -1247,7 +1261,7 @@ def run_subprocess(line: str): Run a subprocess asynchronously """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_name = state.current_runtime @@ -1279,7 +1293,7 @@ def kill_subprocess(line: str): Kill a running subprocess asynchronously """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_name = state.current_runtime @@ -1308,7 +1322,7 @@ def open_tunnels(line: str): Open tunnels to the runtime """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_name = state.current_runtime @@ -1348,7 +1362,7 @@ def close_tunnels(line: str): Close tunnels to the runtime """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_name = state.current_runtime @@ -1374,7 +1388,7 @@ def restart_runtime(rt_name: str): Restart the runtime """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt = state.all_runtimes.get(rt_name, None) @@ -1389,7 +1403,7 @@ def stop_runtime(rt_name: str): Stop the runtime """ - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt = state.all_runtimes.get(rt_name, None) @@ -1457,7 +1471,7 @@ def launch_remote_kernel(rt_name: str, base_port: int, hostname: str): Launch a remote Jupyter kernel, open tunnels, and connect a local Jupyter client. """ assert ipython is not None - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None # launch kernel and tunnel ports @@ -1535,7 +1549,7 @@ def open_web_terminal(line: str): cmd = f"ttyd -p {random_port} -i 0.0.0.0 --writable bash" # Get access token - access_token = get_access_token() + access_token = AuthContext.get_access_token() if access_token is None: print("Not authenticated. Please run %authenticate first.") return @@ -1667,7 +1681,7 @@ async def run_cell_async( return await orig_run_code(raw_cell, store_history, silent, shell_futures, transformed_cell=transformed_cell, preprocessing_exc_tuple=preprocessing_exc_tuple, cell_id=cell_id) else: # Validation: check runtime is ready and kernel is started - access_token = get_access_token() + access_token = AuthContext.get_access_token() assert access_token is not None rt_info = state.all_runtimes.get(rt, None) if rt_info is None:
