This is an automated email from the ASF dual-hosted git repository. pingsutw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push: new 1fb6652 SUBMARINE-1045. Add static type parameter in submarine-sdk 1fb6652 is described below commit 1fb665277a55f76a37aa23244b150fcd42683d9d Author: rayray2002 <rayray2002.hu...@gmail.com> AuthorDate: Mon Oct 25 15:14:41 2021 +0800 SUBMARINE-1045. Add static type parameter in submarine-sdk ### What is this PR for? <!-- A few sentences describing the overall goals of the pull request's commits. First time? Check out the contributing guide - https://submarine.apache.org/contribution/contributions.html --> Add static type parameter in submarine-sdk ### What type of PR is it? [Improvement] ### Todos ### What is the Jira issue? <!-- * Open an issue on Jira https://issues.apache.org/jira/browse/SUBMARINE/ * Put link here, and add [SUBMARINE-*Jira number*] in PR title, e.g. `SUBMARINE-23. PR title` --> https://issues.apache.org/jira/projects/SUBMARINE/issues/SUBMARINE-1045 ### How should this be tested? <!-- * First time? Set up Travis CI as described on https://submarine.apache.org/contribution/contributions.html#continuous-integration * Strongly recommended: add automated unit tests for any new or changed behavior * Outline any manual steps to test the PR here. --> Type annotations verified with mypy (the static type check passes). ### Screenshots (if appropriate) ### Questions: * Do the license files need updating? Yes/No * Are there breaking changes for older versions? Yes/No * Does this need new documentation? Yes/No Author: rayray2002 <rayray2002.hu...@gmail.com> Signed-off-by: Kevin <pings...@apache.org> Closes #781 from rayray2002/SUBMARINE-1045 and squashes the following commits: a846b3cb [rayray2002] SUBMARINE-1045. Add static type parameter in submarine-sdk 53256dde [rayray2002] SUBMARINE-1045. 
Add static type parameter in submarine-sdk --- .../pysubmarine/submarine/artifacts/repository.py | 8 ++++---- .../submarine/entities/_submarine_object.py | 8 ++++---- .../submarine/entities/model_registry/model_stages.py | 2 +- .../submarine/entities/model_registry/model_version.py | 2 +- .../entities/model_registry/registered_model.py | 2 +- .../submarine/experiment/api/experiment_client.py | 6 +++--- .../submarine/experiment/models/code_spec.py | 16 +++++++++------- submarine-sdk/pysubmarine/submarine/experiment/rest.py | 4 ++-- .../pysubmarine/submarine/ml/pytorch/layers/core.py | 14 +++++++------- submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py | 2 +- .../pysubmarine/submarine/ml/pytorch/metric.py | 2 +- .../pysubmarine/submarine/ml/tensorflow/input/input.py | 8 ++++---- .../pysubmarine/submarine/ml/tensorflow/layers/core.py | 18 ++++++++++-------- .../pysubmarine/submarine/ml/tensorflow/optimizer.py | 2 +- submarine-sdk/pysubmarine/submarine/models/client.py | 18 +++++++++--------- submarine-sdk/pysubmarine/submarine/models/pytorch.py | 2 +- .../pysubmarine/submarine/models/tensorflow.py | 2 +- .../submarine/store/tracking/sqlalchemy_store.py | 2 +- submarine-sdk/pysubmarine/submarine/utils/__init__.py | 2 +- submarine-sdk/pysubmarine/submarine/utils/db_utils.py | 6 +++--- submarine-sdk/pysubmarine/submarine/utils/env.py | 12 ++++++------ .../pysubmarine/submarine/utils/rest_utils.py | 4 ++-- .../pysubmarine/submarine/utils/validation.py | 14 +++++++------- 23 files changed, 80 insertions(+), 76 deletions(-) diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py index 5a60648..9bee7ae 100644 --- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py +++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py @@ -19,7 +19,7 @@ import boto3 class Repository: - def __init__(self, experiment_id): + def __init__(self, experiment_id: str): self.client = boto3.client( 
"s3", aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), @@ -28,10 +28,10 @@ class Repository: ) self.dest_path = experiment_id - def _upload_file(self, local_file, bucket, key): + def _upload_file(self, local_file: str, bucket: str, key: str) -> None: self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key) - def _list_artifact_subfolder(self, artifact_path): + def _list_artifact_subfolder(self, artifact_path: str): response = self.client.list_objects( Bucket="submarine", Prefix=os.path.join(self.dest_path, artifact_path) + "/", @@ -39,7 +39,7 @@ class Repository: ) return response.get("CommonPrefixes") - def log_artifact(self, local_file, artifact_path): + def log_artifact(self, local_file: str, artifact_path: str) -> None: bucket = "submarine" dest_path = self.dest_path dest_path = os.path.join(dest_path, artifact_path) diff --git a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py index ffcea7f..db2ad09 100644 --- a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py +++ b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py @@ -31,11 +31,11 @@ class _SubmarineObject: filtered_dict = {key: value for key, value in the_dict.items() if key in cls._properties()} return cls(**filtered_dict) - def __repr__(self): + def __repr__(self) -> str: return to_string(self) -def to_string(obj): +def to_string(obj) -> str: return _SubmarineObjectPrinter().to_string(obj) @@ -48,10 +48,10 @@ class _SubmarineObjectPrinter: super(_SubmarineObjectPrinter, self).__init__() self.printer = pprint.PrettyPrinter() - def to_string(self, obj): + def to_string(self, obj) -> str: if isinstance(obj, _SubmarineObject): return "<%s: %s>" % (get_classname(obj), self._entity_to_string(obj)) return self.printer.pformat(obj) - def _entity_to_string(self, entity): + def _entity_to_string(self, entity) -> str: return ", ".join(["%s=%s" % (key, self.to_string(value)) for key, 
value in entity]) diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py index 4a5e565..3d3f556 100644 --- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py +++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py @@ -26,7 +26,7 @@ ALL_STAGES = [STAGE_NONE, STAGE_DEVELOPING, STAGE_PRODUCTION, STAGE_ARCHIVED] _CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES} -def get_canonical_stage(stage): +def get_canonical_stage(stage: str) -> str: key = stage.lower() if key not in _CANONICAL_MAPPING: raise SubmarineException(f"Invalid Model Version stage {stage}.") diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py index 86652b6..0b43e0a 100644 --- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py +++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py @@ -98,6 +98,6 @@ class ModelVersion(_SubmarineObject): return self._description @property - def tags(self): + def tags(self) -> list: """List of strings.""" return self._tags diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py index b88ac22..f94c5a1 100644 --- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py +++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py @@ -49,6 +49,6 @@ class RegisteredModel(_SubmarineObject): return self._description @property - def tags(self): + def tags(self) -> list: """List of strings""" return self._tags diff --git a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py 
index d659876..f54481e 100644 --- a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py +++ b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py @@ -38,7 +38,7 @@ def generate_host(): class ExperimentClient: - def __init__(self, host=generate_host()): + def __init__(self, host: str = generate_host()): """ Submarine experiment client constructor :param host: An HTTP URI like http://submarine-server:8080. @@ -59,7 +59,7 @@ class ExperimentClient: response = self.experiment_api.create_experiment(experiment_spec=experiment_spec) return response.result - def wait_for_finish(self, id, polling_interval=10): + def wait_for_finish(self, id, polling_interval: float = 10): """ Waits until experiment is finished or failed :param id: submarine experiment id @@ -75,7 +75,7 @@ class ExperimentClient: index = self._log_pod(id, index) time.sleep(polling_interval) - def _log_pod(self, id, index): + def _log_pod(self, id, index: int): response = self.experiment_api.get_log(id) log_contents = response.result["logContent"] if len(log_contents) == 0: diff --git a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py index b319f90..06e5f0f 100644 --- a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py +++ b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py @@ -52,7 +52,9 @@ class CodeSpec(object): attribute_map = {"sync_mode": "syncMode", "url": "url"} - def __init__(self, sync_mode=None, url=None, local_vars_configuration=None): # noqa: E501 + def __init__( + self, sync_mode: str = None, url: str = None, local_vars_configuration: Configuration = None + ): # noqa: E501 """CodeSpec - a model defined in OpenAPI""" # noqa: E501 if local_vars_configuration is None: local_vars_configuration = Configuration() @@ -78,7 +80,7 @@ class CodeSpec(object): return self._sync_mode @sync_mode.setter - def sync_mode(self, sync_mode): + def 
sync_mode(self, sync_mode: str) -> None: """Sets the sync_mode of this CodeSpec. @@ -99,7 +101,7 @@ class CodeSpec(object): return self._url @url.setter - def url(self, url): + def url(self, url: str) -> None: """Sets the url of this CodeSpec. @@ -135,22 +137,22 @@ class CodeSpec(object): return result - def to_str(self): + def to_str(self) -> str: """Returns the string representation of the model""" return pprint.pformat(self.to_dict()) - def __repr__(self): + def __repr__(self) -> str: """For `print` and `pprint`""" return self.to_str() - def __eq__(self, other): + def __eq__(self, other) -> bool: """Returns true if both objects are equal""" if not isinstance(other, CodeSpec): return False return self.to_dict() == other.to_dict() - def __ne__(self, other): + def __ne__(self, other) -> bool: """Returns true if both objects are not equal""" if not isinstance(other, CodeSpec): return True diff --git a/submarine-sdk/pysubmarine/submarine/experiment/rest.py b/submarine-sdk/pysubmarine/submarine/experiment/rest.py index 06b8d93..18fd44d 100644 --- a/submarine-sdk/pysubmarine/submarine/experiment/rest.py +++ b/submarine-sdk/pysubmarine/submarine/experiment/rest.py @@ -55,13 +55,13 @@ class RESTResponse(io.IOBase): """Returns a dictionary of the response headers.""" return self.urllib3_response.getheaders() - def getheader(self, name, default=None): + def getheader(self, name: str, default=None): """Returns a given response header.""" return self.urllib3_response.getheader(name, default) class RESTClientObject(object): - def __init__(self, configuration, pools_size=4, maxsize=None): + def __init__(self, configuration, pools_size: int = 4, maxsize: int = None): # urllib3.PoolManager will pass all kw parameters to connectionpool # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501 # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501 
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py index 265ea1f..fd7d9e9 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py @@ -19,7 +19,7 @@ from torch import nn # pylint: disable=W0223 class FeatureLinear(nn.Module): - def __init__(self, num_features, out_features): + def __init__(self, num_features: int, out_features: int): """ :param num_features: number of total features. :param out_features: The number of output features. @@ -28,7 +28,7 @@ class FeatureLinear(nn.Module): self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=out_features) self.bias = nn.Parameter(torch.zeros((out_features,))) - def forward(self, feature_idx, feature_value): + def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor): """ :param feature_idx: torch.LongTensor (batch_size, num_fields) :param feature_value: torch.LongTensor (batch_size, num_fields) @@ -39,11 +39,11 @@ class FeatureLinear(nn.Module): class FeatureEmbedding(nn.Module): - def __init__(self, num_features, embedding_dim): + def __init__(self, num_features: int, embedding_dim): super().__init__() self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim) - def forward(self, feature_idx, feature_value): + def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor): """ :param feature_idx: torch.LongTensor (batch_size, num_fields) :param feature_value: torch.LongTensor (batch_size, num_fields) @@ -52,7 +52,7 @@ class FeatureEmbedding(nn.Module): class PairwiseInteraction(nn.Module): - def forward(self, x): + def forward(self, x: torch.Tensor): """ :param x: torch.Tensor (batch_size, num_fields, embedding_dim) """ @@ -65,7 +65,7 @@ class PairwiseInteraction(nn.Module): class DNN(nn.Module): - def __init__(self, in_features, out_features, hidden_units, 
dropout_rates): + def __init__(self, in_features: int, out_features: int, hidden_units, dropout_rates): super().__init__() *layers, out_layer = list(zip([in_features, *hidden_units], [*hidden_units, out_features])) self.net = nn.Sequential( @@ -81,7 +81,7 @@ class DNN(nn.Module): nn.Linear(*out_layer) ) - def forward(self, x): + def forward(self, x: torch.FloatTensor): """ :param x: torch.FloatTensor (batch_size, in_features) """ diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py index 234c237..952863f 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py @@ -23,7 +23,7 @@ class LossKey: BCEWithLogitsLoss = "BCEWithLogitsLoss".lower() -def get_loss_fn(key): +def get_loss_fn(key: str): key = key.lower() if key == LossKey.BCELoss: return nn.BCELoss diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py index 43f3d26..0c838fd 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py @@ -24,7 +24,7 @@ class MetricKey: RECALL = "recall" -def get_metric_fn(key): +def get_metric_fn(key: str): key = key.lower() if key == MetricKey.F1_SCORE: return metrics.f1_score diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py index f779fb6..0cc7259 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py +++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py @@ -24,10 +24,10 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE def libsvm_input_fn( filepath, - batch_size=256, - num_epochs=3, # pylint: disable=W0613 - perform_shuffle=False, - delimiter=" ", + batch_size: int = 256, + num_epochs: int = 3, # pylint: disable=W0613 + perform_shuffle: bool = False, + delimiter: 
str = " ", **kwargs ): def _input_fn(): diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py index dbe048f..47f3698 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py +++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py @@ -43,12 +43,12 @@ def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay): def dnn_layer( inputs, - estimator_mode, - batch_norm, - deep_layers, + estimator_mode: str, + batch_norm: bool, + deep_layers: list, dropout, - batch_norm_decay=0.9, - l2_reg=0, + batch_norm_decay: float = 0.9, + l2_reg: float = 0, **kwargs ): """ @@ -100,7 +100,7 @@ def dnn_layer( return deep_out -def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs): +def linear_layer(features, feature_size, field_size, l2_reg: float = 0, **kwargs): """ Layer which represents linear function. :param features: input features @@ -131,7 +131,9 @@ def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs): return linear_out -def embedding_layer(features, feature_size, field_size, embedding_size, l2_reg=0, **kwargs): +def embedding_layer( + features, feature_size, field_size, embedding_size, l2_reg: float = 0, **kwargs +): """ Turns positive integers (indexes) into dense vectors of fixed size. eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]] @@ -199,7 +201,7 @@ class KMaxPooling(Layer): - **axis**: positive integer, the dimension to look for elements. 
""" - def __init__(self, k=1, axis=-1, **kwargs): + def __init__(self, k: int = 1, axis: int = -1, **kwargs): self.dims = 1 self.k = k diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py index dd61d6e..3ab9bbb 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py +++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py @@ -29,7 +29,7 @@ class OptimizerKey(object): FTRL = "ftrl" -def get_optimizer(optimizer_key, learning_rate): +def get_optimizer(optimizer_key: str, learning_rate: float): optimizer_key = optimizer_key.lower() if optimizer_key == OptimizerKey.ADAM: diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py index e633188..da13082 100644 --- a/submarine-sdk/pysubmarine/submarine/models/client.py +++ b/submarine-sdk/pysubmarine/submarine/models/client.py @@ -33,10 +33,10 @@ from .utils import exist_ps, get_job_id, get_worker_index class ModelsClient: def __init__( self, - tracking_uri=None, - registry_uri=None, - aws_access_key_id=None, - aws_secret_access_key=None, + tracking_uri: str = None, + registry_uri: str = None, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, ): """ Set up mlflow server connection, including: s3 endpoint, aws, tracking server @@ -69,26 +69,26 @@ class ModelsClient: experiment_id = self._get_or_create_experiment(experiment_name) return mlflow.start_run(run_name=run_name, experiment_id=experiment_id) - def log_param(self, key, value): + def log_param(self, key: str, value: str): mlflow.log_param(key, value) def log_params(self, params): mlflow.log_params(params) - def log_metric(self, key, value, step=None): + def log_metric(self, key: str, value: str, step=None): mlflow.log_metric(key, value, step) def log_metrics(self, metrics, step=None): mlflow.log_metrics(metrics, step) - def load_model(self, name, version): + def 
load_model(self, name: str, version: str): model = mlflow.pyfunc.load_model(model_uri=f"models:/{name}/{version}") return model - def update_model(self, name, new_name): + def update_model(self, name: str, new_name: str): self.client.rename_registered_model(name=name, new_name=new_name) - def delete_model(self, name, version): + def delete_model(self, name: str, version: str): self.client.delete_model_version(name=name, version=version) def save_model(self, model_type, model, artifact_path, registered_model_name=None): diff --git a/submarine-sdk/pysubmarine/submarine/models/pytorch.py b/submarine-sdk/pysubmarine/submarine/models/pytorch.py index a143aa5..38cdd57 100644 --- a/submarine-sdk/pysubmarine/submarine/models/pytorch.py +++ b/submarine-sdk/pysubmarine/submarine/models/pytorch.py @@ -18,5 +18,5 @@ import os import torch -def save_model(model, artifact_path): +def save_model(model, artifact_path: str): torch.save(model, os.path.join(artifact_path, "model.pth")) diff --git a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py index fbe5324..a947a91 100644 --- a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py +++ b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py @@ -14,5 +14,5 @@ # limitations under the License. 
-def save_model(model, artifact_path): +def save_model(model, artifact_path: str): model.save(artifact_path) diff --git a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py index 23e8a8d..e7e1e3a 100644 --- a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py +++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py @@ -103,7 +103,7 @@ class SqlAlchemyStore(AbstractStore): return make_managed_session @staticmethod - def _save_to_db(session, objs): + def _save_to_db(session, objs: object): """ Store in db """ diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/submarine/utils/__init__.py index 4908ba6..1b0f045 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py +++ b/submarine-sdk/pysubmarine/submarine/utils/__init__.py @@ -19,7 +19,7 @@ from submarine.exceptions import SubmarineException from submarine.utils.db_utils import get_db_uri, set_db_uri -def extract_db_type_from_uri(db_uri): +def extract_db_type_from_uri(db_uri: str): """ Parse the specified DB URI to extract the database type. Confirm the database type is supported. If a driver is specified, confirm it passes a plausible regex. diff --git a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py index b23ce2d..8fcc9a5 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py +++ b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py @@ -22,14 +22,14 @@ _DB_URI_ENV_VAR = "SUBMARINE_DB_URI" _db_uri = None -def is_db_uri_set(): +def is_db_uri_set() -> bool: """Returns True if the DB URI has been set, False otherwise.""" if _db_uri or env.get_env(_DB_URI_ENV_VAR): return True return False -def set_db_uri(uri): +def set_db_uri(uri: str): """ Set the DB URI. This does not affect the currently active run (if one exists), but takes effect for successive runs. 
@@ -38,7 +38,7 @@ def set_db_uri(uri): _db_uri = uri -def get_db_uri(): +def get_db_uri() -> str: """ Get the current DB URI. :return: The DB URI. diff --git a/submarine-sdk/pysubmarine/submarine/utils/env.py b/submarine-sdk/pysubmarine/submarine/utils/env.py index 3797efc..110a134 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/env.py +++ b/submarine-sdk/pysubmarine/submarine/utils/env.py @@ -19,22 +19,22 @@ import os from collections.abc import Mapping -def get_env(variable_name): +def get_env(variable_name: str): return os.environ.get(variable_name) -def unset_variable(variable_name): +def unset_variable(variable_name: str) -> None: if variable_name in os.environ: del os.environ[variable_name] -def check_env_exists(variable_name): +def check_env_exists(variable_name: str) -> bool: if variable_name not in os.environ: return False return True -def get_from_json(path, defaultParams): +def get_from_json(path: str, defaultParams: dict): """ If model parameters not specify in Json, use parameter in defaultParams :param path: The json file that specifies the model parameters. 
@@ -50,7 +50,7 @@ def get_from_json(path, defaultParams): return get_from_dicts(params, defaultParams) -def get_from_dicts(params, defaultParams): +def get_from_dicts(params: dict, defaultParams: dict): """ If model parameters not specify in params, use parameter in defaultParams :param params: parameters which will be merged @@ -71,7 +71,7 @@ def get_from_dicts(params, defaultParams): return dct -def get_from_registry(key, registry): +def get_from_registry(key: str, registry: dict): if hasattr(key, "lower"): key = key.lower() if key in registry: diff --git a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py index db71b13..5054af0 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py +++ b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py @@ -51,7 +51,7 @@ def http_request(base_url, endpoint, method, json_body, timeout=60, headers=None return result -def _can_parse_as_json(string): +def _can_parse_as_json(string: str) -> bool: try: json.loads(string) return True @@ -59,7 +59,7 @@ def _can_parse_as_json(string): return False -def verify_rest_response(response, endpoint): +def verify_rest_response(response, endpoint: str): """Verify the return code and raise exception if the request was not successful.""" if response.status_code != 200: if _can_parse_as_json(response.text): diff --git a/submarine-sdk/pysubmarine/submarine/utils/validation.py b/submarine-sdk/pysubmarine/submarine/utils/validation.py index 00e2c98..049a873 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/validation.py +++ b/submarine-sdk/pysubmarine/submarine/utils/validation.py @@ -37,7 +37,7 @@ _BAD_CHARACTERS_MESSAGE = ( _UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ", ".join(DATABASE_ENGINES) -def bad_path_message(name): +def bad_path_message(name: str): return ( "Names may be treated as files in certain cases, and must not resolve to other names" " when treated as such. 
This name would resolve to '%s'" @@ -45,12 +45,12 @@ def bad_path_message(name): ) -def path_not_unique(name): +def path_not_unique(name: str): norm = posixpath.normpath(name) return norm != name or norm == "." or norm.startswith("..") or norm.startswith("/") -def _validate_param_name(name): +def _validate_param_name(name: str): """Check that `name` is a valid parameter name and raise an exception if it isn't.""" if not _VALID_PARAM_AND_METRIC_NAMES.match(name): raise SubmarineException( @@ -63,7 +63,7 @@ def _validate_param_name(name): ) -def _validate_metric_name(name): +def _validate_metric_name(name: str): """Check that `name` is a valid metric name and raise an exception if it isn't.""" if not _VALID_PARAM_AND_METRIC_NAMES.match(name): raise SubmarineException( @@ -74,7 +74,7 @@ def _validate_metric_name(name): raise SubmarineException("Invalid metric name: '%s'. %s" % (name, bad_path_message(name))) -def _validate_length_limit(entity_name, limit, value): +def _validate_length_limit(entity_name: str, limit: int, value): if len(value) > limit: raise SubmarineException( "%s '%s' had length %s, which exceeded length limit of %s" @@ -82,7 +82,7 @@ def _validate_length_limit(entity_name, limit, value): ) -def validate_metric(key, value, timestamp, step): +def validate_metric(key, value, timestamp, step) -> None: """ Check that a param with the specified key, value, timestamp is valid and raise an exception if it isn't. @@ -107,7 +107,7 @@ def validate_metric(key, value, timestamp, step): ) -def validate_param(key, value): +def validate_param(key, value) -> None: """ Check that a param with the specified key & value is valid and raise an exception if it isn't. --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@submarine.apache.org For additional commands, e-mail: dev-h...@submarine.apache.org