This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 8abef26  Move the storage types, and align the API types to outcomes
8abef26 is described below

commit 8abef26d9fb3b1df86578c3a6ca25ab18f5c8218
Author: Sean B. Palmer <[email protected]>
AuthorDate: Sun Jul 20 09:48:16 2025 +0100

    Move the storage types, and align the API types to outcomes
---
 atr/blueprints/admin/admin.py |  11 ++--
 atr/blueprints/api/api.py     |  66 +++++++++++++--------
 atr/models/api.py             |  37 +++++++++---
 atr/storage/__init__.py       |  20 ++++---
 atr/storage/types.py          |  58 +++++++++++++++++++
 atr/storage/writers/keys.py   | 130 ++++++++++++++++++------------------------
 6 files changed, 200 insertions(+), 122 deletions(-)

diff --git a/atr/blueprints/admin/admin.py b/atr/blueprints/admin/admin.py
index 468ab65..e017000 100644
--- a/atr/blueprints/admin/admin.py
+++ b/atr/blueprints/admin/admin.py
@@ -47,6 +47,7 @@ import atr.ldap as ldap
 import atr.models.sql as sql
 import atr.routes.keys as keys
 import atr.routes.mapping as mapping
+import atr.storage.types as types
 import atr.template as template
 import atr.util as util
 import atr.validate as validate
@@ -660,18 +661,18 @@ async def admin_test() -> quart.wrappers.response.Response:
     async with storage.write(asf_uid) as write:
         wacm = write.as_committee_member("tooling").writer_or_raise()
         start = time.perf_counter_ns()
-        outcomes = await wacm.keys.upload(keys_file_text)
+        outcomes: types.KeyOutcomes = await wacm.keys.ensure_stored(keys_file_text)
         end = time.perf_counter_ns()
         logging.info(f"Upload of {outcomes.result_count} keys took {end - start} ns")
     for ocr in outcomes.results():
         logging.info(f"Uploaded key: {type(ocr)} {ocr.key_model.fingerprint}")
     for oce in outcomes.exceptions():
         logging.error(f"Error uploading key: {type(oce)} {oce}")
-    parsed_count = outcomes.result_predicate_count(lambda k: k.status == wacm.keys.KeyStatus.PARSED)
-    inserted_count = outcomes.result_predicate_count(lambda k: k.status == wacm.keys.KeyStatus.INSERTED)
-    linked_count = outcomes.result_predicate_count(lambda k: k.status == wacm.keys.KeyStatus.LINKED)
+    parsed_count = outcomes.result_predicate_count(lambda k: k.status == types.KeyStatus.PARSED)
+    inserted_count = outcomes.result_predicate_count(lambda k: k.status == types.KeyStatus.INSERTED)
+    linked_count = outcomes.result_predicate_count(lambda k: k.status == types.KeyStatus.LINKED)
     inserted_and_linked_count = outcomes.result_predicate_count(
-        lambda k: k.status == wacm.keys.KeyStatus.INSERTED_AND_LINKED
+        lambda k: k.status == types.KeyStatus.INSERTED_AND_LINKED
     )
     logging.info(f"Parsed: {parsed_count}")
     logging.info(f"Inserted: {inserted_count}")
diff --git a/atr/blueprints/api/api.py b/atr/blueprints/api/api.py
index d3ebe9d..21acfa8 100644
--- a/atr/blueprints/api/api.py
+++ b/atr/blueprints/api/api.py
@@ -43,6 +43,8 @@ import atr.routes.announce as announce
 import atr.routes.keys as keys
 import atr.routes.start as start
 import atr.routes.voting as voting
+import atr.storage as storage
+import atr.storage.types as types
 import atr.tasks.vote as tasks_vote
 import atr.user as user
 import atr.util as util
@@ -384,32 +386,48 @@ async def keys_get(fingerprint: str) -> DictResponse:
 async def keys_upload(data: models.api.KeysUploadArgs) -> DictResponse:
     asf_uid = _jwt_asf_uid()
     filetext = data.filetext
-    selected_committee_names = data.committees
-    async with db.session() as db_data:
-        participant_of_committees = await interaction.user_committees_participant(asf_uid, caller_data=db_data)
-        participant_of_committee_names = [c.name for c in participant_of_committees]
-        for committee_name in selected_committee_names:
-            if committee_name not in participant_of_committee_names:
-                raise exceptions.BadRequest(f"You are not a participant of committee {committee_name}")
-        # TODO: Does this export KEYS files?
-        # Appearently it does not
-        # This needs fixing in keys.py too
-        results, success_count, error_count, submitted_committees = await interaction.upload_keys(
-            participant_of_committee_names, filetext, selected_committee_names
-        )
-
-    # TODO: Should push this much further upstream
-    import logging
-
-    for result in results:
-        logging.info(result)
-    results = [models.api.KeysUploadSubset(**result) for result in results]
+    selected_committee_name = data.committee
+    outcomes_list = []
+    async with storage.write(asf_uid) as write:
+        wacm = write.as_committee_member(selected_committee_name).writer_or_raise()
+        associated: types.KeyOutcomes = await wacm.keys.ensure_associated(filetext)
+        outcomes_list.append(associated)
+
+        # TODO: It would be nice to serialise the actual outcomes
+        api_outcomes = []
+        merged_outcomes = storage.outcomes_merge(*outcomes_list)
+        for outcome in merged_outcomes.outcomes():
+            match outcome:
+                case storage.OutcomeResult() as ocr:
+                    result: types.Key = ocr.result_or_raise()
+                    api_outcome = models.api.KeysUploadResult(
+                        status="success",
+                        key=result.key_model,
+                    )
+                case storage.OutcomeException() as oce:
+                    # TODO: This branch means we must improve the return type
+                    match oce.exception_or_none():
+                        case types.PublicKeyError() as pke:
+                            api_outcome = models.api.KeysUploadException(
+                                status="error",
+                                key=pke.key.key_model,
+                                error=str(pke),
+                                error_type=type(pke).__name__,
+                            )
+                        case _ as e:
+                            api_outcome = models.api.KeysUploadException(
+                                status="error",
+                                key=None,
+                                error=str(e),
+                                error_type=type(e).__name__,
+                            )
+            api_outcomes.append(api_outcome)
     return models.api.KeysUploadResults(
         endpoint="/keys/upload",
-        results=results,
-        success_count=success_count,
-        error_count=error_count,
-        submitted_committees=submitted_committees,
+        results=api_outcomes,
+        success_count=merged_outcomes.result_count,
+        error_count=merged_outcomes.exception_count,
+        submitted_committee=selected_committee_name,
     ).model_dump(), 200
 
 
diff --git a/atr/models/api.py b/atr/models/api.py
index 14cc85c..377f0ab 100644
--- a/atr/models/api.py
+++ b/atr/models/api.py
@@ -146,23 +146,42 @@ class KeysGetResults(schema.Strict):
 
 class KeysUploadArgs(schema.Strict):
     filetext: str
-    committees: list[str]
+    committee: str
 
 
-class KeysUploadSubset(schema.Lax):
-    status: Literal["success", "error"]
-    key_id: str
-    fingerprint: str
-    user_id: str
-    email: str
+class KeysUploadException(schema.Strict):
+    status: Literal["error"] = schema.Field(alias="status")
+    key: sql.PublicSigningKey | None
+    error: str
+    error_type: str
+
+
+class KeysUploadResult(schema.Strict):
+    status: Literal["success"] = schema.Field(alias="status")
+    key: sql.PublicSigningKey
+
+
+KeysUploadOutcome = Annotated[
+    KeysUploadResult | KeysUploadException,
+    schema.Field(discriminator="status"),
+]
+
+KeysUploadOutcomeAdapter = pydantic.TypeAdapter(KeysUploadOutcome)
+
+
+# def validate_keys_upload_outcome(value: Any) -> KeysUploadOutcome:
+#     obj = KeysUploadOutcomeAdapter.validate_python(value)
+#     if not isinstance(obj, KeysUploadOutcome):
+#         raise ResultsTypeError(f"Invalid API response: {value}")
+#     return obj
 
 
 class KeysUploadResults(schema.Strict):
     endpoint: Literal["/keys/upload"] = schema.Field(alias="endpoint")
-    results: Sequence[KeysUploadSubset]
+    results: Sequence[KeysUploadResult | KeysUploadException]
     success_count: int
     error_count: int
-    submitted_committees: list[str]
+    submitted_committee: str
 
 
 class KeysUserResults(schema.Strict):
diff --git a/atr/storage/__init__.py b/atr/storage/__init__.py
index 2830fb8..78035cb 100644
--- a/atr/storage/__init__.py
+++ b/atr/storage/__init__.py
@@ -274,7 +274,7 @@ class OutcomeResult[T](OutcomeCore[T]):
         return None
 
 
-class OutcomeError[T, E: Exception](OutcomeCore[T]):
+class OutcomeException[T, E: Exception](OutcomeCore[T]):
     __exception: E
 
     def __init__(self, exception: E, name: str | None = None):
@@ -305,9 +305,9 @@ class OutcomeError[T, E: Exception](OutcomeCore[T]):
 
 
 class Outcomes[T]:
-    __outcomes: list[OutcomeResult[T] | OutcomeError[T, Exception]]
+    __outcomes: list[OutcomeResult[T] | OutcomeException[T, Exception]]
 
-    def __init__(self, *outcomes: OutcomeResult[T] | OutcomeError[T, Exception]):
+    def __init__(self, *outcomes: OutcomeResult[T] | OutcomeException[T, Exception]):
         self.__outcomes = list(outcomes)
 
     @property
@@ -320,18 +320,18 @@ class Outcomes[T]:
 
     def append(self, result_or_error: T | Exception, name: str | None = None) -> None:
         if isinstance(result_or_error, Exception):
-            self.__outcomes.append(OutcomeError(result_or_error, name))
+            self.__outcomes.append(OutcomeException(result_or_error, name))
         else:
             self.__outcomes.append(OutcomeResult(result_or_error, name))
 
     @property
     def exception_count(self) -> int:
-        return sum(1 for outcome in self.__outcomes if isinstance(outcome, OutcomeError))
+        return sum(1 for outcome in self.__outcomes if isinstance(outcome, OutcomeException))
 
     def exceptions(self) -> list[Exception]:
         exceptions_list = []
         for outcome in self.__outcomes:
-            if isinstance(outcome, OutcomeError):
+            if isinstance(outcome, OutcomeException):
                 exception_or_none = outcome.exception_or_none()
                 if exception_or_none is not None:
                     exceptions_list.append(exception_or_none)
@@ -367,7 +367,7 @@ class Outcomes[T]:
     def result_count(self) -> int:
         return sum(1 for outcome in self.__outcomes if outcome.ok)
 
-    def outcomes(self) -> list[OutcomeResult[T] | OutcomeError[T, Exception]]:
+    def outcomes(self) -> list[OutcomeResult[T] | OutcomeException[T, Exception]]:
         return self.__outcomes
 
     def result_predicate_count(self, predicate: Callable[[T], bool]) -> int:
@@ -383,7 +383,7 @@ class Outcomes[T]:
                 try:
                     result = f(outcome.result_or_raise())
                 except Exception as e:
-                    self.__outcomes[i] = OutcomeError(e, outcome.name)
+                    self.__outcomes[i] = OutcomeException(e, outcome.name)
                 else:
                     self.__outcomes[i] = OutcomeResult(result, outcome.name)
 
@@ -412,6 +412,10 @@ class Outcomes[T]:
     #     self.__outcomes[:] = new_outcomes
 
 
+def outcomes_merge[T](*outcomes: Outcomes[T]) -> Outcomes[T]:
+    return Outcomes(*[outcome for outcome in outcomes for outcome in outcome.outcomes()])
+
+
 # Context managers
 
 
diff --git a/atr/storage/types.py b/atr/storage/types.py
new file mode 100644
index 0000000..53a4423
--- /dev/null
+++ b/atr/storage/types.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import enum
+from typing import TYPE_CHECKING
+
+import atr.models.schema as schema
+import atr.models.sql as sql
+import atr.storage as storage
+
+
+class KeyStatus(enum.Flag):
+    PARSED = 0
+    INSERTED = enum.auto()
+    LINKED = enum.auto()
+    INSERTED_AND_LINKED = INSERTED | LINKED
+
+
+class Key(schema.Strict):
+    status: KeyStatus
+    key_model: sql.PublicSigningKey
+
+
+class PublicKeyError(Exception):
+    def __init__(self, key: Key, original_error: Exception):
+        self.__key = key
+        self.__original_error = original_error
+
+    def __str__(self) -> str:
+        return f"PublicKeyError: {self.__original_error}"
+
+    @property
+    def key(self) -> Key:
+        return self.__key
+
+    @property
+    def original_error(self) -> Exception:
+        return self.__original_error
+
+
+if TYPE_CHECKING:
+    KeyOutcomes = storage.Outcomes[Key]
+    # KeyOutcomeResult = storage.OutcomeResult[Key]
+    # KeyOutcomeError = storage.OutcomeError[Key, Exception]
diff --git a/atr/storage/writers/keys.py b/atr/storage/writers/keys.py
index 80dfe61..cf876b4 100644
--- a/atr/storage/writers/keys.py
+++ b/atr/storage/writers/keys.py
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import asyncio
-import enum
 import logging
 import tempfile
 import time
@@ -30,50 +29,19 @@ import pgpy.constants as constants
 import sqlalchemy.dialects.sqlite as sqlite
 
 import atr.db as db
-import atr.models.schema as schema
 import atr.models.sql as sql
 import atr.storage as storage
+import atr.storage.types as types
 import atr.user as user
 import atr.util as util
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Coroutine
 
-    KeyOutcomes = storage.Outcomes[sql.PublicSigningKey]
-
 PERFORMANCES: Final[dict[int, tuple[str, int]]] = {}
 _MEASURE_PERFORMANCE: Final[bool] = False
 
 
-class KeyStatus(enum.Flag):
-    PARSED = 0
-    INSERTED = enum.auto()
-    LINKED = enum.auto()
-    INSERTED_AND_LINKED = INSERTED | LINKED
-
-
-class Key(schema.Strict):
-    status: KeyStatus
-    key_model: sql.PublicSigningKey
-
-
-class PublicKeyError(Exception):
-    def __init__(self, key: Key, original_error: Exception):
-        self.__key = key
-        self.__original_error = original_error
-
-    def __str__(self) -> str:
-        return f"PublicKeyError: {self.__original_error}"
-
-    @property
-    def key(self) -> Key:
-        return self.__key
-
-    @property
-    def original_error(self) -> Exception:
-        return self.__original_error
-
-
 def performance(func: Callable[..., Any]) -> Callable[..., Any]:
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         if not _MEASURE_PERFORMANCE:
@@ -102,10 +70,6 @@ def performance_async(func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable
 
 
 class CommitteeMember:
-    Key = Key
-    KeyStatus = KeyStatus
-    PublicKeyError = PublicKeyError
-
     def __init__(
         self, credentials: storage.WriteAsCommitteeMember, data: db.Session, asf_uid: str, committee_name: str
     ):
@@ -118,6 +82,10 @@ class CommitteeMember:
         self.__committee_name = committee_name
         self.__key_block_models_cache = {}
 
+    @performance_async
+    async def associate(self, outcomes: storage.Outcomes[types.Key]) -> storage.Outcomes[types.Key]:
+        raise NotImplementedError("Not implemented")
+
     @performance_async
     async def committee(self) -> sql.Committee:
         return await self.__data.committee(name=self.__committee_name, _public_signing_keys=True).demand(
@@ -125,30 +93,15 @@ class CommitteeMember:
         )
 
     @performance_async
-    async def upload(self, keys_file_text: str) -> storage.Outcomes[CommitteeMember.Key]:
-        outcomes = storage.Outcomes[CommitteeMember.Key]()
-        try:
-            ldap_data = await util.email_to_uid_map()
-            key_blocks = util.parse_key_blocks(keys_file_text)
-        except Exception as e:
-            outcomes.append(e)
-            return outcomes
-        for key_block in key_blocks:
-            try:
-                key_models = await asyncio.to_thread(self.__block_models, key_block, ldap_data)
-                outcomes.extend(key_models)
-            except Exception as e:
-                outcomes.append(e)
-        # Try adding the keys to the database
-        # If not, all keys will be replaced with a PostParseError
-        outcomes = await self.__database_add_models(outcomes)
-        if _MEASURE_PERFORMANCE:
-            for key, value in PERFORMANCES.items():
-                logging.info(f"{key}: {value}")
-        return outcomes
+    async def ensure_associated(self, keys_file_text: str) -> storage.Outcomes[types.Key]:
+        return await self.__ensure(keys_file_text, associate=True)
+
+    @performance_async
+    async def ensure_stored(self, keys_file_text: str) -> storage.Outcomes[types.Key]:
+        return await self.__ensure(keys_file_text, associate=False)
 
     @performance
-    def __block_models(self, key_block: str, ldap_data: dict[str, str]) -> list[CommitteeMember.Key | Exception]:
+    def __block_models(self, key_block: str, ldap_data: dict[str, str]) -> list[types.Key | Exception]:
         # This cache is only held for the session
         if key_block in self.__key_block_models_cache:
             return self.__key_block_models_cache[key_block]
@@ -165,7 +118,7 @@ class CommitteeMember:
                     if key_model is None:
                         # Was not a primary key, so skip it
                         continue
-                    key = CommitteeMember.Key(status=CommitteeMember.KeyStatus.PARSED, key_model=key_model)
+                    key = types.Key(status=types.KeyStatus.PARSED, key_model=key_model)
                     key_list.append(key)
                 except Exception as e:
                     key_list.append(e)
@@ -174,28 +127,30 @@ class CommitteeMember:
 
     @performance_async
     async def __database_add_models(
-        self, outcomes: storage.Outcomes[CommitteeMember.Key]
-    ) -> storage.Outcomes[CommitteeMember.Key]:
+        self, outcomes: storage.Outcomes[types.Key], associate: bool = True
+    ) -> storage.Outcomes[types.Key]:
         # Try to upsert all models and link to the committee in one transaction
         try:
-            outcomes = await self.__database_add_models_core(outcomes)
+            outcomes = await self.__database_add_models_core(outcomes, associate=associate)
         except Exception as e:
             # This logging is just so that ruff does not erase e
             logging.info(f"Post-parse error: {e}")
 
-            def raise_post_parse_error(key: CommitteeMember.Key) -> NoReturn:
+            def raise_post_parse_error(key: types.Key) -> NoReturn:
                 nonlocal e
                 # We assume here that the transaction was rolled back correctly
-                key = CommitteeMember.Key(status=CommitteeMember.KeyStatus.PARSED, key_model=key.key_model)
-                raise PublicKeyError(key, e)
+                key = types.Key(status=types.KeyStatus.PARSED, key_model=key.key_model)
+                raise types.PublicKeyError(key, e)
 
             outcomes.update_results(raise_post_parse_error)
         return outcomes
 
     @performance_async
     async def __database_add_models_core(
-        self, outcomes: storage.Outcomes[CommitteeMember.Key]
-    ) -> storage.Outcomes[CommitteeMember.Key]:
+        self,
+        outcomes: storage.Outcomes[types.Key],
+        associate: bool = True,
+    ) -> storage.Outcomes[types.Key]:
         via = sql.validate_instrumented_attribute
         key_list = outcomes.results()
 
@@ -212,9 +167,9 @@ class CommitteeMember:
         key_inserts = {row.fingerprint for row in key_insert_result}
         logging.info(f"Inserted {len(key_inserts)} keys")
 
-        def replace_with_inserted(key: CommitteeMember.Key) -> CommitteeMember.Key:
+        def replace_with_inserted(key: types.Key) -> types.Key:
             if key.key_model.fingerprint in key_inserts:
-                key.status = CommitteeMember.KeyStatus.INSERTED
+                key.status = types.KeyStatus.INSERTED
             return key
 
         outcomes.update_results(replace_with_inserted)
@@ -224,7 +179,7 @@ class CommitteeMember:
 
         existing_fingerprints = {k.fingerprint for k in committee.public_signing_keys}
         new_fingerprints = persisted_fingerprints - existing_fingerprints
-        if new_fingerprints:
+        if new_fingerprints and associate:
             link_values = [{"committee_name": self.__committee_name, "key_fingerprint": fp} for fp in new_fingerprints]
             link_insert_result = await self.__data.execute(
                 sqlite.insert(sql.KeyLink)
@@ -235,15 +190,15 @@ class CommitteeMember:
             link_inserts = {row.key_fingerprint for row in link_insert_result}
             logging.info(f"Inserted {len(link_inserts)} key links")
 
-            def replace_with_linked(key: CommitteeMember.Key) -> CommitteeMember.Key:
+            def replace_with_linked(key: types.Key) -> types.Key:
                 nonlocal link_inserts
                 match key:
-                    case CommitteeMember.Key(status=CommitteeMember.KeyStatus.INSERTED):
+                    case types.Key(status=types.KeyStatus.INSERTED):
                         if key.key_model.fingerprint in link_inserts:
-                            key.status = CommitteeMember.KeyStatus.INSERTED_AND_LINKED
-                    case CommitteeMember.Key(status=CommitteeMember.KeyStatus.PARSED):
+                            key.status = types.KeyStatus.INSERTED_AND_LINKED
+                    case types.Key(status=types.KeyStatus.PARSED):
                         if key.key_model.fingerprint in link_inserts:
-                            key.status = CommitteeMember.KeyStatus.LINKED
+                            key.status = types.KeyStatus.LINKED
                 return key
 
             outcomes.update_results(replace_with_linked)
@@ -253,6 +208,29 @@ class CommitteeMember:
         await self.__data.commit()
         return outcomes
 
+    @performance_async
+    async def __ensure(self, keys_file_text: str, associate: bool = True) -> storage.Outcomes[types.Key]:
+        outcomes = storage.Outcomes[types.Key]()
+        try:
+            ldap_data = await util.email_to_uid_map()
+            key_blocks = util.parse_key_blocks(keys_file_text)
+        except Exception as e:
+            outcomes.append(e)
+            return outcomes
+        for key_block in key_blocks:
+            try:
+                key_models = await asyncio.to_thread(self.__block_models, key_block, ldap_data)
+                outcomes.extend(key_models)
+            except Exception as e:
+                outcomes.append(e)
+        # Try adding the keys to the database
+        # If not, all keys will be replaced with a PostParseError
+        outcomes = await self.__database_add_models(outcomes, associate=associate)
+        if _MEASURE_PERFORMANCE:
+            for key, value in PERFORMANCES.items():
+                logging.info(f"{key}: {value}")
+        return outcomes
+
     @performance
     def __keyring_fingerprint_model(
         self, keyring: pgpy.PGPKeyring, fingerprint: str, ldap_data: dict[str, str]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to