This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new be00bf2  Record which keys were inserted and which were linked
be00bf2 is described below

commit be00bf2d86de07e2d0d74846bb58ba5bc5b06867
Author: Sean B. Palmer <[email protected]>
AuthorDate: Fri Jul 18 19:46:28 2025 +0100

    Record which keys were inserted and which were linked
---
 atr/blueprints/admin/admin.py |  20 ++++--
 atr/storage/__init__.py       |  23 +++++--
 atr/storage/writers/keys.py   | 146 ++++++++++++++++++++++++++++--------------
 3 files changed, 133 insertions(+), 56 deletions(-)

diff --git a/atr/blueprints/admin/admin.py b/atr/blueprints/admin/admin.py
index a8739b7..5deac79 100644
--- a/atr/blueprints/admin/admin.py
+++ b/atr/blueprints/admin/admin.py
@@ -646,7 +646,7 @@ async def admin_test() -> quart.wrappers.response.Response:
     import atr.storage as storage
 
     async with aiohttp.ClientSession() as aiohttp_client_session:
-        url = "https://downloads.apache.org/beam/KEYS";
+        url = "https://downloads.apache.org/zeppelin/KEYS";
         async with aiohttp_client_session.get(url) as response:
             keys_file_text = await response.text()
             # logging.info(f"Keys file text: {keys_file_text}")
@@ -662,9 +662,21 @@ async def admin_test() -> quart.wrappers.response.Response:
         start = time.perf_counter_ns()
         outcomes = await wacm.keys.upload(keys_file_text)
         end = time.perf_counter_ns()
-        logging.info(f"Upload of {outcomes.ok_count} keys took {end - start} 
ns")
-    for oe in outcomes.exceptions():
-        logging.error(f"Error uploading key: {oe}")
+        logging.info(f"Upload of {outcomes.result_count} keys took {end - 
start} ns")
+    for ocr in outcomes.results():
+        logging.info(f"Uploaded key: {type(ocr)} {ocr.key_model.fingerprint}")
+    for oce in outcomes.exceptions():
+        logging.error(f"Error uploading key: {type(oce)} {oce}")
+    parsed_count = outcomes.result_predicate_count(lambda k: k.status == 
wacm.keys.key_status.PARSED)
+    inserted_count = outcomes.result_predicate_count(lambda k: k.status == 
wacm.keys.key_status.INSERTED)
+    linked_count = outcomes.result_predicate_count(lambda k: k.status == 
wacm.keys.key_status.LINKED)
+    inserted_and_linked_count = outcomes.result_predicate_count(
+        lambda k: k.status == wacm.keys.key_status.INSERTED_AND_LINKED
+    )
+    logging.info(f"Parsed: {parsed_count}")
+    logging.info(f"Inserted: {inserted_count}")
+    logging.info(f"Linked: {linked_count}")
+    logging.info(f"InsertedAndLinked: {inserted_and_linked_count}")
     return quart.Response(str(wacm), mimetype="text/plain")
 
 
diff --git a/atr/storage/__init__.py b/atr/storage/__init__.py
index 36fed44..2830fb8 100644
--- a/atr/storage/__init__.py
+++ b/atr/storage/__init__.py
@@ -324,10 +324,6 @@ class Outcomes[T]:
         else:
             self.__outcomes.append(OutcomeResult(result_or_error, name))
 
-    def extend(self, result_or_error_list: Sequence[T | Exception]) -> None:
-        for result_or_error in result_or_error_list:
-            self.append(result_or_error)
-
     @property
     def exception_count(self) -> int:
         return sum(1 for outcome in self.__outcomes if isinstance(outcome, OutcomeError))
@@ -341,6 +337,10 @@ class Outcomes[T]:
                     exceptions_list.append(exception_or_none)
         return exceptions_list
 
+    def extend(self, result_or_error_list: Sequence[T | Exception]) -> None:
+        for result_or_error in result_or_error_list:
+            self.append(result_or_error)
+
     def named_results(self) -> dict[str, T]:
         named = {}
         for outcome in self.__outcomes:
@@ -351,6 +351,12 @@ class Outcomes[T]:
     def names(self) -> list[str | None]:
         return [outcome.name for outcome in self.__outcomes if (outcome.name is not None)]
 
+    # def replace(self, a: T, b: T) -> None:
+    #     for i, outcome in enumerate(self.__outcomes):
+    #         if isinstance(outcome, OutcomeResult):
+    #             if outcome.result_or_raise() == a:
+    #                 self.__outcomes[i] = OutcomeResult(b, outcome.name)
+
     def results_or_raise(self, exception_class: type[Exception] | None = None) -> list[T]:
         return [outcome.result_or_raise(exception_class) for outcome in self.__outcomes]
 
@@ -358,12 +364,19 @@ class Outcomes[T]:
         return [outcome.result_or_raise() for outcome in self.__outcomes if outcome.ok]
 
     @property
-    def ok_count(self) -> int:
+    def result_count(self) -> int:
         return sum(1 for outcome in self.__outcomes if outcome.ok)
 
     def outcomes(self) -> list[OutcomeResult[T] | OutcomeError[T, Exception]]:
         return self.__outcomes
 
+    def result_predicate_count(self, predicate: Callable[[T], bool]) -> int:
+        return sum(
+            1
+            for outcome in self.__outcomes
+            if isinstance(outcome, OutcomeResult) and predicate(outcome.result_or_raise())
+        )
+
     def update_results(self, f: Callable[[T], T]) -> None:
         for i, outcome in enumerate(self.__outcomes):
             if isinstance(outcome, OutcomeResult):
diff --git a/atr/storage/writers/keys.py b/atr/storage/writers/keys.py
index 8602145..e35115b 100644
--- a/atr/storage/writers/keys.py
+++ b/atr/storage/writers/keys.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import asyncio
+import enum
 import logging
 import tempfile
 import time
@@ -29,6 +30,7 @@ import pgpy.constants as constants
 import sqlalchemy.dialects.sqlite as sqlite
 
 import atr.db as db
+import atr.models.schema as schema
 import atr.models.sql as sql
 import atr.storage as storage
 import atr.user as user
@@ -87,6 +89,18 @@ def performance_async(func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable
     return wrapper
 
 
+class KeyStatus(enum.Flag):
+    PARSED = 0
+    INSERTED = enum.auto()
+    LINKED = enum.auto()
+    INSERTED_AND_LINKED = INSERTED | LINKED
+
+
+class Key(schema.Strict):
+    status: KeyStatus
+    key_model: sql.PublicSigningKey
+
+
 class CommitteeMember:
     def __init__(
         self, credentials: storage.WriteAsCommitteeMember, data: db.Session, asf_uid: str, committee_name: str
@@ -106,9 +120,17 @@ class CommitteeMember:
             storage.AccessError(f"Committee not found: 
{self.__committee_name}")
         )
 
+    # @property
+    # def key_type(self) -> type[Key]:
+    #     return Key
+
+    @property
+    def key_status(self) -> type[KeyStatus]:
+        return KeyStatus
+
     @performance_async
-    async def upload(self, keys_file_text: str) -> KeyOutcomes:
-        outcomes = storage.Outcomes[sql.PublicSigningKey]()
+    async def upload(self, keys_file_text: str) -> storage.Outcomes[Key]:
+        outcomes = storage.Outcomes[Key]()
         try:
             ldap_data = await util.email_to_uid_map()
             key_blocks = util.parse_key_blocks(keys_file_text)
@@ -130,7 +152,7 @@ class CommitteeMember:
         return outcomes
 
     @performance
-    def __block_models(self, key_block: str, ldap_data: dict[str, str]) -> list[sql.PublicSigningKey | Exception]:
+    def __block_models(self, key_block: str, ldap_data: dict[str, str]) -> list[Key | Exception]:
         # This cache is only held for the session
         if key_block in self.__key_block_models_cache:
             return self.__key_block_models_cache[key_block]
@@ -140,68 +162,95 @@ class CommitteeMember:
             tmpfile.flush()
             keyring = pgpy.PGPKeyring()
             fingerprints = keyring.load(tmpfile.name)
-            models = []
+            key_list = []
             for fingerprint in fingerprints:
                 try:
-                    model = self.__keyring_fingerprint_model(keyring, fingerprint, ldap_data)
-                    if model is None:
+                    key_model = self.__keyring_fingerprint_model(keyring, fingerprint, ldap_data)
+                    if key_model is None:
                         # Was not a primary key, so skip it
                         continue
-                    models.append(model)
+                    key = Key(status=KeyStatus.PARSED, key_model=key_model)
+                    key_list.append(key)
                 except Exception as e:
-                    models.append(e)
-            self.__key_block_models_cache[key_block] = models
-            return models
+                    key_list.append(e)
+            self.__key_block_models_cache[key_block] = key_list
+            return key_list
 
     @performance_async
-    async def __database_add_models(self, outcomes: KeyOutcomes) -> KeyOutcomes:
+    async def __database_add_models(self, outcomes: storage.Outcomes[Key]) -> storage.Outcomes[Key]:
         # Try to upsert all models and link to the committee in one transaction
         try:
-            key_models = outcomes.results()
-
-            await self.__data.begin_immediate()
-            committee = await self.committee()
-
-            key_values = [m.model_dump(exclude={"committees"}) for m in key_models]
-            key_insert_result = await self.__data.execute(
-                sqlite.insert(sql.PublicSigningKey)
-                .values(key_values)
-                .on_conflict_do_nothing(index_elements=["fingerprint"])
-            )
-            key_insert_count = key_insert_result.rowcount
-            logging.info(f"Inserted {key_insert_count} keys")
-
-            persisted_fingerprints = {v["fingerprint"] for v in key_values}
-            await self.__data.flush()
-
-            existing_fingerprints = {k.fingerprint for k in committee.public_signing_keys}
-            new_fingerprints = persisted_fingerprints - existing_fingerprints
-            if new_fingerprints:
-                link_values = [
-                    {"committee_name": self.__committee_name, 
"key_fingerprint": fp} for fp in new_fingerprints
-                ]
-                link_insert_result = await self.__data.execute(
-                    sqlite.insert(sql.KeyLink)
-                    .values(link_values)
-                    .on_conflict_do_nothing(index_elements=["committee_name", "key_fingerprint"])
-                )
-                link_insert_count = link_insert_result.rowcount
-            else:
-                link_insert_count = 0
-            logging.info(f"Inserted {link_insert_count} key links")
-
-            await self.__data.commit()
+            outcomes = await self.__database_add_models_core(outcomes)
         except Exception as e:
             # This logging is just so that ruff does not erase e
             logging.info(f"Post-parse error: {e}")
 
-            def raise_post_parse_error(model: sql.PublicSigningKey) -> NoReturn:
+            def raise_post_parse_error(key: Key) -> NoReturn:
                 nonlocal e
-                raise PostParseError(model, e)
+                raise PostParseError(key.key_model, e)
 
             outcomes.update_results(raise_post_parse_error)
         return outcomes
 
+    @performance_async
+    async def __database_add_models_core(self, outcomes: storage.Outcomes[Key]) -> storage.Outcomes[Key]:
+        via = sql.validate_instrumented_attribute
+        key_list = outcomes.results()
+
+        await self.__data.begin_immediate()
+        committee = await self.committee()
+
+        key_values = [key.key_model.model_dump(exclude={"committees"}) for key in key_list]
+        key_insert_result = await self.__data.execute(
+            sqlite.insert(sql.PublicSigningKey)
+            .values(key_values)
+            .on_conflict_do_nothing(index_elements=["fingerprint"])
+            .returning(via(sql.PublicSigningKey.fingerprint))
+        )
+        key_inserts = {row.fingerprint for row in key_insert_result}
+        logging.info(f"Inserted {len(key_inserts)} keys")
+
+        def replace_with_inserted(key: Key) -> Key:
+            if key.key_model.fingerprint in key_inserts:
+                key.status = KeyStatus.INSERTED
+            return key
+
+        outcomes.update_results(replace_with_inserted)
+
+        persisted_fingerprints = {v["fingerprint"] for v in key_values}
+        await self.__data.flush()
+
+        existing_fingerprints = {k.fingerprint for k in committee.public_signing_keys}
+        new_fingerprints = persisted_fingerprints - existing_fingerprints
+        if new_fingerprints:
+            link_values = [{"committee_name": self.__committee_name, 
"key_fingerprint": fp} for fp in new_fingerprints]
+            link_insert_result = await self.__data.execute(
+                sqlite.insert(sql.KeyLink)
+                .values(link_values)
+                .on_conflict_do_nothing(index_elements=["committee_name", "key_fingerprint"])
+                .returning(via(sql.KeyLink.key_fingerprint))
+            )
+            link_inserts = {row.key_fingerprint for row in link_insert_result}
+            logging.info(f"Inserted {len(link_inserts)} key links")
+
+            def replace_with_linked(key: Key) -> Key:
+                nonlocal link_inserts
+                match key:
+                    case Key(status=KeyStatus.INSERTED):
+                        if key.key_model.fingerprint in link_inserts:
+                            key.status = KeyStatus.INSERTED_AND_LINKED
+                    case Key(status=KeyStatus.PARSED):
+                        if key.key_model.fingerprint in link_inserts:
+                            key.status = KeyStatus.LINKED
+                return key
+
+            outcomes.update_results(replace_with_linked)
+        else:
+            logging.info("Inserted 0 key links (none to insert)")
+
+        await self.__data.commit()
+        return outcomes
+
     @performance
     def __keyring_fingerprint_model(
         self, keyring: pgpy.PGPKeyring, fingerprint: str, ldap_data: dict[str, str]
@@ -211,6 +260,8 @@ class CommitteeMember:
                 return None
             uids = [uid.userid for uid in key.userids]
             asf_uid = self.__uids_asf_uid(uids, ldap_data)
+
+            # TODO: Improve this
             key_size = key.key_size
             length = 0
             if isinstance(key_size, constants.EllipticCurveOID):
@@ -222,6 +273,7 @@ class CommitteeMember:
                 length = key_size
             else:
                 raise ValueError(f"Key size is not an integer: 
{type(key_size)}, {key_size}")
+
             return sql.PublicSigningKey(
                 fingerprint=str(key.fingerprint).lower(),
                 algorithm=key.key_algorithm.value,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to