Copilot commented on code in PR #16998:
URL: https://github.com/apache/iotdb/pull/16998#discussion_r2670832349
##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py:
##########
@@ -87,47 +84,40 @@ def _estimate_shared_pool_size_by_total_mem(
pool_num = int(per_model_share // mem_usages[model_info.model_id])
if pool_num <= 0:
logger.warning(
- f"[Inference][Device-{device}] Not enough TOTAL memory to
guarantee at least 1 pool for model {model_info.model_id}, no pool will be
scheduled for this model. "
+ f"[Inference][{device}] Not enough TOTAL memory to guarantee
at least 1 pool for model {model_info.model_id}, no pool will be scheduled for
this model. "
f"Per-model share={per_model_share / 1024 ** 2:.2f} MB,
need>={mem_usages[model_info.model_id] / 1024 ** 2:.2f} MB"
)
allocation[model_info.model_id] = pool_num
logger.info(
- f"[Inference][Device-{device}] Shared pool allocation (by TOTAL
memory): {allocation}"
+ f"[Inference][{device}] Shared pool allocation (by TOTAL memory):
{allocation}"
)
return allocation
class BasicPoolScheduler(AbstractPoolScheduler):
"""
- A basic scheduler to init the request pools. In short, different kind of
models will equally share the available resource of the located device, and
scale down actions are always ahead of scale up.
+ A basic scheduler to init the request pools. In short,
+ different kind of models will equally share the available resource of the
located device,
+ and scale down actions are always ahead of scale up.
"""
- def __init__(self, request_pool_map: Dict[str, Dict[str, PoolGroup]]):
+ def __init__(self, request_pool_map: Dict[str, Dict[torch.device,
PoolGroup]]):
super().__init__(request_pool_map)
self._model_manager = ModelManager()
def schedule(self, model_id: str) -> List[ScaleAction]:
- """
- Schedule a scaling action for the given model_id.
- """
- if model_id not in self._request_pool_map:
- pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
- if pool_num <= 0:
- raise InferenceModelInternalException(
- f"Not enough memory to run model {model_id}."
- )
- return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
+ pass
Review Comment:
The `schedule` method in `BasicPoolScheduler` is now just `pass`,
effectively removing the previous implementation without replacement. This
breaks the abstract interface contract and could cause issues if this method is
called at runtime.
```suggestion
logger.warning(
"BasicPoolScheduler.schedule called for model_id=%s; returning
no scale actions.",
model_id,
)
return []
```
##########
iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py:
##########
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+
Review Comment:
The `nullcontext` import is unused in this file. Consider removing unused
imports to keep the code clean.
```suggestion
```
##########
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py:
##########
@@ -144,6 +142,24 @@ def forecast(self, req: TForecastReq) -> TForecastResp:
return TForecastResp(status, [])
return self._inference_manager.forecast(req)
+ # ==================== Internal API ====================
+
+ def _ensure_device_id_is_available(self, device_id_list: list[str]) ->
TSStatus:
+ """
+ Ensure that the device IDs in the provided list are available.
+ """
+ available_devices = self._backend.device_ids()
+ for device_id in device_id_list:
+ try:
+ if device_id != "cpu" and int(device_id) not in
available_devices:
+ raise ValueError(f"Invalid device ID [{device_id}]")
+ except ValueError:
+ return TSStatus(
+ code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
+ message=f"AIDevice ID [{device_id}] is not available. You
can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+ )
Review Comment:
The error handling here uses an exception for control flow and conflates two
distinct cases: the deliberately raised ValueError (device ID parses as an
integer but is not in `available_devices`) and the ValueError raised by
`int(device_id)` itself (malformed ID) are caught by the same handler and
produce an identical error status. The behavior happens to be correct — "cpu"
is skipped by the short-circuit check, and any other invalid string correctly
yields an error — but raising and catching within the same `try` obscures the
intent. Prefer validating/parsing the device ID explicitly before the
membership check, as in the suggestion below.
```suggestion
# Allow the special "cpu" device without further validation.
if device_id == "cpu":
continue
# Validate that the device ID can be interpreted as an integer.
try:
parsed_id = int(device_id)
except (TypeError, ValueError):
return TSStatus(
code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
message=f"AIDevice ID [{device_id}] is not available.
You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
)
# Check that the parsed device ID is in the list of available
devices.
if parsed_id not in available_devices:
return TSStatus(
code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
message=f"AIDevice ID [{device_id}] is not available.
You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
)
```
##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py:
##########
@@ -167,7 +168,7 @@ def show_loaded_models(
if device_id in device_map:
pool_group = device_map[device_id]
device_models[model_id] =
pool_group.get_running_pool_count()
- result[device_id] = device_models
+ result[str(device_id.index)] = device_models
Review Comment:
When the device is CPU (`device_id.index` is None for CPU devices), this
code will convert the device_id to the string "None" instead of "cpu". This
will create incorrect dictionary keys like "None" instead of "cpu" in the
result map, potentially causing confusion for API consumers.
```suggestion
device_key = device_id.type if device_id.index is None else
str(device_id.index)
result[device_key] = device_models
```
##########
iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py:
##########
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.device_utils import DeviceLike, parse_device_like
+from iotdb.ainode.core.device.env import DistEnv, read_dist_env
+from iotdb.ainode.core.util.decorator import singleton
+
+
+@dataclass(frozen=True)
+class DeviceManagerConfig:
+ use_local_rank_if_distributed: bool = True
+
+
+@singleton
+class DeviceManager:
+ """
+ Unified device entry point:
+ - Select backend (cuda/npu/cpu)
+ - Parse device expression (None/int/str/torch.device/DeviceSpec)
+ - Provide device, autocast, grad scaler, synchronize, dist backend
recommendation, etc.
+ """
+
+ def __init__(self, cfg: DeviceManagerConfig):
Review Comment:
The `DeviceManager` class is decorated with `@singleton` but is being
instantiated without providing the required `cfg` parameter. The `__init__`
method expects a `DeviceManagerConfig` argument, but all instantiations in the
code call `DeviceManager()` without arguments. This will cause a `TypeError` at
runtime.
```suggestion
def __init__(self, cfg: Optional[DeviceManagerConfig] = None):
if cfg is None:
cfg = DeviceManagerConfig()
```
##########
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py:
##########
@@ -48,25 +48,12 @@
logger = Logger()
-def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus:
- """
- Ensure that the device IDs in the provided list are available.
- """
- available_devices = get_available_devices()
- for device_id in device_id_list:
- if device_id not in available_devices:
- return TSStatus(
- code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
- message=f"AIDevice ID [{device_id}] is not available. You can
use 'SHOW AI_DEVICES' to retrieve the available devices.",
- )
- return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
-
-
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def __init__(self, ainode):
self._ainode = ainode
self._model_manager = ModelManager()
self._inference_manager = InferenceManager()
+ self._backend = DeviceManager()
Review Comment:
The `_backend` field name is somewhat misleading since this is actually a
DeviceManager, not a backend adapter. Consider renaming to `_device_manager`
for better clarity and consistency with other manager naming conventions in the
codebase.
##########
iotdb-core/ainode/pyproject.toml:
##########
@@ -76,7 +76,7 @@ exclude = [
]
[tool.poetry.dependencies]
-python = ">=3.11.0,<3.14.0"
+python = ">=3.11.0,<3.12.0"
Review Comment:
The Python version constraint has been changed from `>=3.11.0,<3.14.0` to
`>=3.11.0,<3.12.0`, which significantly restricts supported Python versions.
This means Python 3.12 and 3.13 are no longer supported, which could be
problematic for users who have upgraded to newer Python versions. Consider
whether this restriction is intentional and necessary.
```suggestion
python = ">=3.11.0,<3.14.0"
```
##########
iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py:
##########
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CPUBackend(BackendAdapter):
+ type = BackendType.CPU
+
+ def is_available(self) -> bool:
+ return True
+
+ def device_count(self) -> int:
+ return 1
+
+ def make_device(self, index: int | None) -> torch.device:
+ return torch.device("cpu")
+
+ def set_device(self, index: int) -> None:
+ return None
Review Comment:
The `BackendAdapter` Protocol defines required methods `synchronize()`,
`autocast()`, `make_grad_scaler()`, `default_dist_backend()`, and
`supports_bf16()`, but the `CPUBackend` class only implements `is_available()`,
`device_count()`, `make_device()`, and `set_device()`. This will cause
`AttributeError` at runtime when these methods are accessed.
```suggestion
def set_device(self, index: int) -> None:
# CPU backend does not require setting a specific device.
return None
def synchronize(self) -> None:
"""Synchronize CPU operations. No-op for CPU backend."""
return None
def autocast(self, enabled: bool = True):
"""
Return a context manager for autocast.
For CPU backend this is a no-op context.
"""
return nullcontext()
def make_grad_scaler(self):
"""
Return a gradient scaler instance if supported.
CPU backend does not use gradient scaling by default.
"""
return None
def default_dist_backend(self) -> str:
"""
Return the default distributed backend for CPU devices.
"""
return "gloo"
def supports_bf16(self) -> bool:
"""
Indicate whether the CPU backend supports bfloat16.
"""
return False
```
##########
iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py:
##########
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CUDABackend(BackendAdapter):
+ type = BackendType.CUDA
+
+ def is_available(self) -> bool:
+ return torch.cuda.is_available()
+
+ def device_count(self) -> int:
+ return torch.cuda.device_count()
+
+ def make_device(self, index: int | None) -> torch.device:
+ if index is None:
+ raise ValueError("CUDA backend requires a valid device index")
+ return torch.device(f"cuda:{index}")
+
+ def set_device(self, index: int) -> None:
+ torch.cuda.set_device(index)
Review Comment:
The `BackendAdapter` Protocol defines required methods `synchronize()`,
`autocast()`, `make_grad_scaler()`, `default_dist_backend()`, and
`supports_bf16()`, but the `CUDABackend` class only implements
`is_available()`, `device_count()`, `make_device()`, and `set_device()`. This
will cause `AttributeError` at runtime when these methods are accessed.
```suggestion
torch.cuda.set_device(index)
def synchronize(self) -> None:
"""Synchronize all CUDA kernels on the current device, if
available."""
if torch.cuda.is_available():
torch.cuda.synchronize()
def autocast(self, enabled: bool = True):
"""
Return an autocast context manager for CUDA.
When autocast is disabled or CUDA is not available, this falls back
to a no-op
nullcontext so callers can always use `with backend.autocast():`.
"""
if not enabled or not torch.cuda.is_available():
return nullcontext()
# Use torch.autocast with explicit device_type for CUDA
return torch.autocast(device_type="cuda", enabled=True)
def make_grad_scaler(self, enabled: bool = True):
"""
Create and return a gradient scaler for mixed precision training on
CUDA.
If CUDA AMP is not available or disabled, returns a no-op
nullcontext so
callers can still use the returned object safely.
"""
if not enabled or not torch.cuda.is_available():
return nullcontext()
amp_mod = getattr(torch.cuda, "amp", None)
GradScaler = getattr(amp_mod, "GradScaler", None) if amp_mod is not
None else None
if GradScaler is None:
# Fallback: no AMP support, return a no-op context.
return nullcontext()
return GradScaler(enabled=True)
def default_dist_backend(self) -> str:
"""
Return the default PyTorch distributed backend for CUDA devices.
"""
return "nccl"
def supports_bf16(self) -> bool:
"""
Indicate whether the current environment supports bfloat16 on CUDA.
"""
if not torch.cuda.is_available():
return False
is_bf16_supported = getattr(torch.cuda, "is_bf16_supported", None)
if callable(is_bf16_supported):
try:
return bool(is_bf16_supported())
except TypeError:
# In case of unexpected signature, fall back conservatively.
pass
# Conservative fallback: if we cannot determine support, report
False.
return False
```
##########
iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py:
##########
@@ -75,7 +75,9 @@ def __init__(
self.model_info = model_info
self.pool_kwargs = pool_kwargs
self.ready_event = ready_event
- self.device = convert_device_id_to_torch_device(device)
+ self.device = device
+
+ self._backend = DeviceManager()
Review Comment:
The field name `_backend` is somewhat misleading since this is actually a
DeviceManager, not a backend adapter. Consider renaming to `_device_manager`
for better clarity and consistency with naming conventions.
```suggestion
self._device_manager = DeviceManager()
self._backend = self._device_manager # backwards-compatible alias
```
##########
iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py:
##########
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.device_utils import DeviceLike, parse_device_like
+from iotdb.ainode.core.device.env import DistEnv, read_dist_env
+from iotdb.ainode.core.util.decorator import singleton
+
+
+@dataclass(frozen=True)
+class DeviceManagerConfig:
+ use_local_rank_if_distributed: bool = True
+
+
+@singleton
+class DeviceManager:
+ """
+ Unified device entry point:
+ - Select backend (cuda/npu/cpu)
+ - Parse device expression (None/int/str/torch.device/DeviceSpec)
+ - Provide device, autocast, grad scaler, synchronize, dist backend
recommendation, etc.
+ """
+
+ def __init__(self, cfg: DeviceManagerConfig):
+ self.cfg = cfg
+ self.env: DistEnv = read_dist_env()
+
+ self.backends: dict[BackendType, BackendAdapter] = {
+ BackendType.CUDA: CUDABackend(),
+ BackendType.CPU: CPUBackend(),
+ }
+
+ self.type: BackendType
+ self.backend: BackendAdapter = self._auto_select_backend()
+ self.default_index: Optional[int] = self._select_default_index()
+
+ # ensure process uses correct device early
+ self._set_device_for_process()
+ self.device: torch.device =
self.backend.make_device(self.default_index)
+
+ # ==================== selection ====================
+ def _auto_select_backend(self) -> BackendAdapter:
+ for name in BackendType:
+ backend = self.backends.get(name)
+ if backend is not None and backend.is_available():
+ self.type = backend.type
+ return backend
+ return self.backends[BackendType.CPU]
+
+ def _select_default_index(self) -> Optional[int]:
+ if self.backend.type == BackendType.CPU:
+ return None
+ if self.cfg.use_local_rank_if_distributed and self.env.world_size > 1:
+ return self.env.local_rank
+ return 0
+
+ def _set_device_for_process(self) -> None:
+ if self.backend.type in (BackendType.CUDA) and self.default_index is
not None:
+ self.backend.set_device(self.default_index)
+
+ # ==================== public API ====================
+ def device_ids(self) -> list[int]:
+ """
+ Returns a list of available device IDs for the current backend.
+ """
+ if self.backend.type == BackendType.CPU:
+ return []
+ return list(range(self.backend.device_count()))
+
+ def available_devices_with_cpu(self) -> list[torch.device]:
+ """
+ Returns the list of available torch.devices, including "cpu".
+ """
+ device_id_list = self.device_ids()
+ device_id_list = [self.torch_device(device_id) for device_id in
device_id_list]
+ device_id_list.append(self.torch_device("cpu"))
+ return device_id_list
+
+ def torch_device(self, device: DeviceLike) -> torch.device:
+ """
+ Convert a DeviceLike specification into a torch.device object.
+ If device is None, returns the default device of current process.
+ Args:
+ device: Could be any of the following formats:
+ an integer (e.g., 0, 1, ...),
+ a string (e.g., "0", "cuda:0", "cpu", ...),
+ a torch.device object, return itself if so.
+ """
+ if isinstance(device, torch.device):
+ return device
+ spec = parse_device_like(device)
+ if spec.type == "cpu":
+ return torch.device("cpu")
+ return self.backend.make_device(spec.index)
Review Comment:
The `torch_device` method docstring states "If device is None, returns the
default device of current process," but the implementation doesn't handle the
None case. When device is None, `parse_device_like(None)` will raise a
TypeError. Either the implementation should handle None by returning
`self.device`, or the docstring should be corrected to remove this claim.
##########
iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py:
##########
@@ -71,6 +68,7 @@ class InferenceManager:
def __init__(self):
self._model_manager = ModelManager()
+ self._backend = DeviceManager()
Review Comment:
The field name `_backend` is somewhat misleading since this is actually a
DeviceManager, not a backend adapter. Consider renaming to `_device_manager`
for better clarity and consistency with naming conventions.
```suggestion
self._device_manager = DeviceManager()
```
##########
iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py:
##########
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.device_utils import DeviceLike, parse_device_like
+from iotdb.ainode.core.device.env import DistEnv, read_dist_env
+from iotdb.ainode.core.util.decorator import singleton
+
+
+@dataclass(frozen=True)
+class DeviceManagerConfig:
+ use_local_rank_if_distributed: bool = True
+
+
+@singleton
+class DeviceManager:
+ """
+ Unified device entry point:
+ - Select backend (cuda/npu/cpu)
+ - Parse device expression (None/int/str/torch.device/DeviceSpec)
+ - Provide device, autocast, grad scaler, synchronize, dist backend
recommendation, etc.
+ """
+
+ def __init__(self, cfg: DeviceManagerConfig):
+ self.cfg = cfg
+ self.env: DistEnv = read_dist_env()
+
+ self.backends: dict[BackendType, BackendAdapter] = {
+ BackendType.CUDA: CUDABackend(),
+ BackendType.CPU: CPUBackend(),
+ }
+
+ self.type: BackendType
+ self.backend: BackendAdapter = self._auto_select_backend()
+ self.default_index: Optional[int] = self._select_default_index()
+
+ # ensure process uses correct device early
+ self._set_device_for_process()
+ self.device: torch.device =
self.backend.make_device(self.default_index)
+
+ # ==================== selection ====================
+ def _auto_select_backend(self) -> BackendAdapter:
+ for name in BackendType:
+ backend = self.backends.get(name)
+ if backend is not None and backend.is_available():
+ self.type = backend.type
+ return backend
+ return self.backends[BackendType.CPU]
+
+ def _select_default_index(self) -> Optional[int]:
+ if self.backend.type == BackendType.CPU:
+ return None
+ if self.cfg.use_local_rank_if_distributed and self.env.world_size > 1:
+ return self.env.local_rank
+ return 0
+
+ def _set_device_for_process(self) -> None:
+ if self.backend.type in (BackendType.CUDA) and self.default_index is
not None:
Review Comment:
The intended single-element tuple `(BackendType.CUDA)` should be written as
`(BackendType.CUDA,)` with a trailing comma. Without the comma, Python treats
the parentheses as simple grouping, so the expression is just the enum member
itself — and `self.backend.type in BackendType.CUDA` then performs a
containment test against a non-container. For a plain `Enum` member this
raises `TypeError` at runtime (containment on a member is only supported for
`Flag` enums), so this is almost certainly not the intended behavior.
```suggestion
if self.backend.type in (BackendType.CUDA,) and self.default_index
is not None:
```
##########
iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py:
##########
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+
Review Comment:
The `nullcontext` import is unused in this file. Consider removing unused
imports to keep the code clean.
```suggestion
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]