This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch access-integration-v2 in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit ab5f6921390beae7bdd413cdba45eb3b984013c4 Author: lahiruj <[email protected]> AuthorDate: Wed May 20 17:06:26 2026 -0400 Add identity, lifecycle status, and user-merge models and changes to core --- .gitignore | 9 +- .../db/migrations/000014_users_status.down.sql | 20 ++ internal/db/migrations/000014_users_status.up.sql | 20 ++ .../db/migrations/000015_projects_status.down.sql | 20 ++ .../db/migrations/000015_projects_status.up.sql | 20 ++ .../000016_compute_cluster_users_status.down.sql | 20 ++ .../000016_compute_cluster_users_status.up.sql | 20 ++ ...00017_external_identities_and_user_dns.down.sql | 19 ++ .../000017_external_identities_and_user_dns.up.sql | 50 +++++ internal/db/migrations/000018_user_merges.down.sql | 18 ++ internal/db/migrations/000018_user_merges.up.sql | 30 +++ internal/server/server.go | 230 +++++++++++++++++++++ .../store/compute_allocation_membership_store.go | 18 ++ internal/store/compute_cluster_user_store.go | 46 ++++- internal/store/external_identity_store.go | 132 ++++++++++++ internal/store/project_store.go | 35 +++- internal/store/store.go | 63 ++++++ internal/store/user_dn_store.go | 103 +++++++++ internal/store/user_merge_store.go | 77 +++++++ internal/store/user_store.go | 28 ++- pkg/events/external_identity_subscribe.go | 64 ++++++ pkg/events/types.go | 14 ++ pkg/events/user_dn_subscribe.go | 57 +++++ pkg/models/allocation.go | 9 +- pkg/models/identity.go | 44 ++++ pkg/models/project.go | 56 +++-- pkg/service/compute_cluster_user.go | 49 ++++- pkg/service/external_identity.go | 191 +++++++++++++++++ pkg/service/project.go | 58 +++++- pkg/service/service.go | 108 +++++----- pkg/service/user.go | 102 ++++++++- pkg/service/user_dn.go | 127 ++++++++++++ pkg/service/user_merge.go | 144 +++++++++++++ 33 files changed, 1899 insertions(+), 102 deletions(-) diff --git a/.gitignore b/.gitignore index f7898df4e..d24da884a 100644 --- a/.gitignore +++ b/.gitignore @@ -56,4 +56,11 @@ venv/ # Vault vault.db -raft.db \ No newline at end of 
file +raft.db + +# Configurations +.env +config.yaml + +# Binaries +bin/ \ No newline at end of file diff --git a/internal/db/migrations/000014_users_status.down.sql b/internal/db/migrations/000014_users_status.down.sql new file mode 100644 index 000000000..8c5d8445d --- /dev/null +++ b/internal/db/migrations/000014_users_status.down.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +ALTER TABLE users + DROP KEY idx_users_status, + DROP COLUMN status; diff --git a/internal/db/migrations/000014_users_status.up.sql b/internal/db/migrations/000014_users_status.up.sql new file mode 100644 index 000000000..5b88416a6 --- /dev/null +++ b/internal/db/migrations/000014_users_status.up.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +ALTER TABLE users + ADD COLUMN status VARCHAR(32) NOT NULL DEFAULT 'ACTIVE' AFTER email, + ADD KEY idx_users_status (status); diff --git a/internal/db/migrations/000015_projects_status.down.sql b/internal/db/migrations/000015_projects_status.down.sql new file mode 100644 index 000000000..136999f6f --- /dev/null +++ b/internal/db/migrations/000015_projects_status.down.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +ALTER TABLE projects + DROP KEY idx_projects_status, + DROP COLUMN status; diff --git a/internal/db/migrations/000015_projects_status.up.sql b/internal/db/migrations/000015_projects_status.up.sql new file mode 100644 index 000000000..1de30ed61 --- /dev/null +++ b/internal/db/migrations/000015_projects_status.up.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +ALTER TABLE projects + ADD COLUMN status VARCHAR(32) NOT NULL DEFAULT 'ACTIVE' AFTER project_pi_id, + ADD KEY idx_projects_status (status); diff --git a/internal/db/migrations/000016_compute_cluster_users_status.down.sql b/internal/db/migrations/000016_compute_cluster_users_status.down.sql new file mode 100644 index 000000000..fe9eea446 --- /dev/null +++ b/internal/db/migrations/000016_compute_cluster_users_status.down.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +ALTER TABLE compute_cluster_users + DROP KEY idx_compute_cluster_users_status, + DROP COLUMN status; diff --git a/internal/db/migrations/000016_compute_cluster_users_status.up.sql b/internal/db/migrations/000016_compute_cluster_users_status.up.sql new file mode 100644 index 000000000..5cbeae23e --- /dev/null +++ b/internal/db/migrations/000016_compute_cluster_users_status.up.sql @@ -0,0 +1,20 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +ALTER TABLE compute_cluster_users + ADD COLUMN status VARCHAR(32) NOT NULL DEFAULT 'ACTIVE' AFTER local_username, + ADD KEY idx_compute_cluster_users_status (status); diff --git a/internal/db/migrations/000017_external_identities_and_user_dns.down.sql b/internal/db/migrations/000017_external_identities_and_user_dns.down.sql new file mode 100644 index 000000000..88dae2ca8 --- /dev/null +++ b/internal/db/migrations/000017_external_identities_and_user_dns.down.sql @@ -0,0 +1,19 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +DROP TABLE IF EXISTS user_dns; +DROP TABLE IF EXISTS external_identities; diff --git a/internal/db/migrations/000017_external_identities_and_user_dns.up.sql b/internal/db/migrations/000017_external_identities_and_user_dns.up.sql new file mode 100644 index 000000000..b093a2ff8 --- /dev/null +++ b/internal/db/migrations/000017_external_identities_and_user_dns.up.sql @@ -0,0 +1,50 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. 
The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- oidc_sub is nullable: not every external identity issues an OIDC subject +-- (AMIE binds by external_id only). UNIQUE permits multiple NULLs but blocks +-- collisions on real values across IdPs. +CREATE TABLE IF NOT EXISTS external_identities +( + id VARCHAR(255) NOT NULL, + user_id VARCHAR(255) NOT NULL, + source VARCHAR(64) NOT NULL, + external_id VARCHAR(255) NOT NULL, + oidc_sub VARCHAR(255) NULL DEFAULT NULL, + metadata TEXT NULL DEFAULT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (id), + UNIQUE KEY uq_external_identities_source_external (source, external_id), + UNIQUE KEY uq_external_identities_oidc_sub (oidc_sub), + KEY idx_external_identities_user (user_id), + CONSTRAINT fk_external_identities_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci; + +-- A DN is a globally-unique credential. UNIQUE on dn alone subsumes +-- (user_id, dn), so the composite index is omitted. 
+CREATE TABLE IF NOT EXISTS user_dns +( + id VARCHAR(255) NOT NULL, + user_id VARCHAR(255) NOT NULL, + dn VARCHAR(512) NOT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (id), + UNIQUE KEY uq_user_dns_dn (dn), + KEY idx_user_dns_user (user_id), + CONSTRAINT fk_user_dns_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci; diff --git a/internal/db/migrations/000018_user_merges.down.sql b/internal/db/migrations/000018_user_merges.down.sql new file mode 100644 index 000000000..bc65c15a1 --- /dev/null +++ b/internal/db/migrations/000018_user_merges.down.sql @@ -0,0 +1,18 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +DROP TABLE IF EXISTS user_merges; diff --git a/internal/db/migrations/000018_user_merges.up.sql b/internal/db/migrations/000018_user_merges.up.sql new file mode 100644 index 000000000..7aeeb8844 --- /dev/null +++ b/internal/db/migrations/000018_user_merges.up.sql @@ -0,0 +1,30 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +CREATE TABLE IF NOT EXISTS user_merges +( + id BIGINT NOT NULL AUTO_INCREMENT, + retiring_user_id VARCHAR(255) NOT NULL, + surviving_user_id VARCHAR(255) NOT NULL, + reason TEXT NULL, + merged_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (id), + UNIQUE KEY uq_user_merges_retiring (retiring_user_id), + KEY idx_user_merges_surviving (surviving_user_id), + CONSTRAINT fk_user_merges_retiring FOREIGN KEY (retiring_user_id) REFERENCES users (id) ON DELETE RESTRICT, + CONSTRAINT fk_user_merges_surviving FOREIGN KEY (surviving_user_id) REFERENCES users (id) ON DELETE RESTRICT +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci; diff --git a/internal/server/server.go b/internal/server/server.go index 6af71d2df..05685d59a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -56,9 +56,14 @@ func (s *Server) routes() { s.mux.HandleFunc("POST /users", s.createUser) s.mux.HandleFunc("GET /users/{id}", s.getUser) + s.mux.HandleFunc("PUT /users/{id}/status", s.updateUserStatus) + s.mux.HandleFunc("POST /users/merge", s.mergeUsers) + s.mux.HandleFunc("GET /users/{id}/merge", s.getUserMergeByRetiringUser) + s.mux.HandleFunc("GET /users/{id}/merged-users", s.listUserMergesBySurvivingUser) s.mux.HandleFunc("POST /projects", 
s.createProject) s.mux.HandleFunc("GET /projects/{id}", s.getProject) + s.mux.HandleFunc("PUT /projects/{id}/status", s.updateProjectStatus) s.mux.HandleFunc("POST /compute-clusters", s.createComputeCluster) s.mux.HandleFunc("GET /compute-clusters", s.listComputeClusters) @@ -67,6 +72,7 @@ func (s *Server) routes() { s.mux.HandleFunc("POST /compute-cluster-users", s.createComputeClusterUser) s.mux.HandleFunc("GET /compute-cluster-users/{id}", s.getComputeClusterUser) s.mux.HandleFunc("PUT /compute-cluster-users/{id}", s.updateComputeClusterUser) + s.mux.HandleFunc("PUT /compute-cluster-users/{id}/status", s.updateComputeClusterUserStatus) s.mux.HandleFunc("DELETE /compute-cluster-users/{id}", s.deleteComputeClusterUser) s.mux.HandleFunc("GET /compute-clusters/{id}/users", s.listComputeClusterUsersByCluster) s.mux.HandleFunc("GET /compute-clusters/{id}/users/{userId}", s.getComputeClusterUserByPair) @@ -131,6 +137,20 @@ func (s *Server) routes() { s.mux.HandleFunc("GET /compute-allocations/{id}/usages/total", s.getTotalSUUsageForAllocation) s.mux.HandleFunc("GET /compute-allocations/{id}/users/{userId}/usages/total", s.getTotalSUUsageForUserInAllocation) s.mux.HandleFunc("GET /users/{id}/compute-allocation-usages", s.listUsagesByUser) + + s.mux.HandleFunc("POST /external-identities", s.createExternalIdentity) + s.mux.HandleFunc("GET /external-identities/{id}", s.getExternalIdentity) + s.mux.HandleFunc("PUT /external-identities/{id}", s.updateExternalIdentity) + s.mux.HandleFunc("DELETE /external-identities/{id}", s.deleteExternalIdentity) + s.mux.HandleFunc("GET /external-identities/sources/{source}/external/{externalId}", s.getExternalIdentityBySourceAndExternalID) + s.mux.HandleFunc("GET /external-identities/oidc-subjects/{oidcSub}", s.getExternalIdentityByOIDCSub) + s.mux.HandleFunc("GET /users/{id}/external-identities", s.listExternalIdentitiesForUser) + + s.mux.HandleFunc("POST /user-dns", s.addUserDN) + s.mux.HandleFunc("GET /user-dns/{id}", s.getUserDN) + 
s.mux.HandleFunc("DELETE /user-dns/{id}", s.removeUserDN) + s.mux.HandleFunc("GET /user-dns/lookup", s.getUserDNByDN) + s.mux.HandleFunc("GET /users/{id}/user-dns", s.listUserDNs) } func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) { @@ -862,6 +882,216 @@ func (s *Server) getTotalSUUsageForUserInAllocation(w http.ResponseWriter, r *ht }) } +type statusUpdateRequest struct { + Status string `json:"status"` +} + +func (s *Server) updateUserStatus(w http.ResponseWriter, r *http.Request) { + var req statusUpdateRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + u, err := s.svc.UpdateUserStatus(r.Context(), r.PathValue("id"), models.UserStatus(req.Status)) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, u) +} + +func (s *Server) updateProjectStatus(w http.ResponseWriter, r *http.Request) { + var req statusUpdateRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + p, err := s.svc.UpdateProjectStatus(r.Context(), r.PathValue("id"), models.ProjectStatus(req.Status)) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, p) +} + +func (s *Server) updateComputeClusterUserStatus(w http.ResponseWriter, r *http.Request) { + var req statusUpdateRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + cu, err := s.svc.UpdateComputeClusterUserStatus(r.Context(), r.PathValue("id"), models.AllocationStatus(req.Status)) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, cu) +} + +func (s *Server) createExternalIdentity(w http.ResponseWriter, r *http.Request) { + var e models.ExternalIdentity + if err := decodeJSON(r, &e); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + created, err := s.svc.CreateExternalIdentity(r.Context(), &e) + if err != nil { + 
writeServiceError(w, err) + return + } + writeJSON(w, http.StatusCreated, created) +} + +func (s *Server) getExternalIdentity(w http.ResponseWriter, r *http.Request) { + e, err := s.svc.GetExternalIdentity(r.Context(), r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, e) +} + +func (s *Server) getExternalIdentityBySourceAndExternalID(w http.ResponseWriter, r *http.Request) { + e, err := s.svc.GetExternalIdentityBySourceAndExternalID(r.Context(), r.PathValue("source"), r.PathValue("externalId")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, e) +} + +func (s *Server) getExternalIdentityByOIDCSub(w http.ResponseWriter, r *http.Request) { + e, err := s.svc.GetExternalIdentityByOIDCSub(r.Context(), r.PathValue("oidcSub")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, e) +} + +func (s *Server) listExternalIdentitiesForUser(w http.ResponseWriter, r *http.Request) { + out, err := s.svc.ListExternalIdentitiesForUser(r.Context(), r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, out) +} + +func (s *Server) updateExternalIdentity(w http.ResponseWriter, r *http.Request) { + var e models.ExternalIdentity + if err := decodeJSON(r, &e); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + e.ID = r.PathValue("id") + if err := s.svc.UpdateExternalIdentity(r.Context(), &e); err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, &e) +} + +func (s *Server) deleteExternalIdentity(w http.ResponseWriter, r *http.Request) { + if err := s.svc.DeleteExternalIdentity(r.Context(), r.PathValue("id")); err != nil { + writeServiceError(w, err) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) addUserDN(w http.ResponseWriter, r *http.Request) { + var d models.UserDN + if err := decodeJSON(r, &d); err != nil { + writeError(w, 
http.StatusBadRequest, err) + return + } + created, err := s.svc.AddUserDN(r.Context(), &d) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusCreated, created) +} + +func (s *Server) getUserDN(w http.ResponseWriter, r *http.Request) { + d, err := s.svc.GetUserDN(r.Context(), r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, d) +} + +func (s *Server) getUserDNByDN(w http.ResponseWriter, r *http.Request) { + dn := r.URL.Query().Get("dn") + if dn == "" { + writeError(w, http.StatusBadRequest, errors.New("dn query parameter is required")) + return + } + d, err := s.svc.GetUserDNByDN(r.Context(), dn) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, d) +} + +func (s *Server) listUserDNs(w http.ResponseWriter, r *http.Request) { + out, err := s.svc.ListUserDNs(r.Context(), r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, out) +} + +func (s *Server) removeUserDN(w http.ResponseWriter, r *http.Request) { + if err := s.svc.RemoveUserDN(r.Context(), r.PathValue("id")); err != nil { + writeServiceError(w, err) + return + } + w.WriteHeader(http.StatusNoContent) +} + +type mergeUsersRequest struct { + SurvivingUserID string `json:"surviving_user_id"` + RetiringUserID string `json:"retiring_user_id"` + Reason string `json:"reason,omitempty"` +} + +func (s *Server) mergeUsers(w http.ResponseWriter, r *http.Request) { + var req mergeUsersRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + survivor, err := s.svc.MergeUsers(r.Context(), req.SurvivingUserID, req.RetiringUserID, req.Reason) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, survivor) +} + +func (s *Server) getUserMergeByRetiringUser(w http.ResponseWriter, r *http.Request) { + m, err := s.svc.GetUserMergeByRetiringUser(r.Context(), 
r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, m) +} + +func (s *Server) listUserMergesBySurvivingUser(w http.ResponseWriter, r *http.Request) { + out, err := s.svc.ListUserMergesBySurvivingUser(r.Context(), r.PathValue("id")) + if err != nil { + writeServiceError(w, err) + return + } + writeJSON(w, http.StatusOK, out) +} + // LoggingMiddleware logs every request once it completes. func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/store/compute_allocation_membership_store.go b/internal/store/compute_allocation_membership_store.go index baf217782..e79c842fb 100644 --- a/internal/store/compute_allocation_membership_store.go +++ b/internal/store/compute_allocation_membership_store.go @@ -115,6 +115,24 @@ func (s *mysqlComputeAllocationMembershipStore) Update(ctx context.Context, tx * return err } +func (s *mysqlComputeAllocationMembershipStore) ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error { + if _, err := tx.ExecContext(ctx, + `DELETE FROM compute_allocation_memberships + WHERE user_id = ? + AND compute_allocation_id IN ( + SELECT compute_allocation_id FROM ( + SELECT compute_allocation_id FROM compute_allocation_memberships WHERE user_id = ? + ) AS s + )`, + fromUserID, toUserID); err != nil { + return err + } + _, err := tx.ExecContext(ctx, + `UPDATE compute_allocation_memberships SET user_id = ? 
WHERE user_id = ?`, + toUserID, fromUserID) + return err +} + func (s *mysqlComputeAllocationMembershipStore) Delete(ctx context.Context, tx *sql.Tx, id string) error { _, err := tx.ExecContext(ctx, `DELETE FROM compute_allocation_memberships WHERE id = ?`, id) return err diff --git a/internal/store/compute_cluster_user_store.go b/internal/store/compute_cluster_user_store.go index f5926ad1e..c345d0330 100644 --- a/internal/store/compute_cluster_user_store.go +++ b/internal/store/compute_cluster_user_store.go @@ -36,10 +36,12 @@ func NewComputeClusterUserStore(db *sqlx.DB) ComputeClusterUserStore { return &mysqlComputeClusterUserStore{db: db} } +const computeClusterUserColumns = `id, compute_cluster_id, user_id, local_username, status` + func (s *mysqlComputeClusterUserStore) FindByID(ctx context.Context, id string) (*models.ComputeClusterUser, error) { var c models.ComputeClusterUser err := s.db.GetContext(ctx, &c, - `SELECT id, compute_cluster_id, user_id, local_username + `SELECT `+computeClusterUserColumns+` FROM compute_cluster_users WHERE id = ?`, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -53,7 +55,7 @@ func (s *mysqlComputeClusterUserStore) FindByID(ctx context.Context, id string) func (s *mysqlComputeClusterUserStore) FindByPair(ctx context.Context, clusterID, userID string) (*models.ComputeClusterUser, error) { var c models.ComputeClusterUser err := s.db.GetContext(ctx, &c, - `SELECT id, compute_cluster_id, user_id, local_username + `SELECT `+computeClusterUserColumns+` FROM compute_cluster_users WHERE compute_cluster_id = ? 
AND user_id = ?`, clusterID, userID) if err != nil { @@ -68,7 +70,7 @@ func (s *mysqlComputeClusterUserStore) FindByPair(ctx context.Context, clusterID func (s *mysqlComputeClusterUserStore) FindByCluster(ctx context.Context, clusterID string) ([]models.ComputeClusterUser, error) { var users []models.ComputeClusterUser err := s.db.SelectContext(ctx, &users, - `SELECT id, compute_cluster_id, user_id, local_username + `SELECT `+computeClusterUserColumns+` FROM compute_cluster_users WHERE compute_cluster_id = ? ORDER BY local_username`, clusterID) @@ -81,7 +83,7 @@ func (s *mysqlComputeClusterUserStore) FindByCluster(ctx context.Context, cluste func (s *mysqlComputeClusterUserStore) FindByUser(ctx context.Context, userID string) ([]models.ComputeClusterUser, error) { var users []models.ComputeClusterUser err := s.db.SelectContext(ctx, &users, - `SELECT id, compute_cluster_id, user_id, local_username + `SELECT `+computeClusterUserColumns+` FROM compute_cluster_users WHERE user_id = ? ORDER BY compute_cluster_id`, userID) @@ -93,9 +95,9 @@ func (s *mysqlComputeClusterUserStore) FindByUser(ctx context.Context, userID st func (s *mysqlComputeClusterUserStore) Create(ctx context.Context, tx *sql.Tx, c *models.ComputeClusterUser) error { _, err := tx.ExecContext(ctx, - `INSERT INTO compute_cluster_users (id, compute_cluster_id, user_id, local_username) - VALUES (?, ?, ?, ?)`, - c.ID, c.ComputeClusterID, c.UserID, c.LocalUsername) + `INSERT INTO compute_cluster_users (id, compute_cluster_id, user_id, local_username, status) + VALUES (?, ?, ?, ?, ?)`, + c.ID, c.ComputeClusterID, c.UserID, c.LocalUsername, c.Status) return err } @@ -104,9 +106,35 @@ func (s *mysqlComputeClusterUserStore) Update(ctx context.Context, tx *sql.Tx, c `UPDATE compute_cluster_users SET compute_cluster_id = ?, user_id = ?, - local_username = ? + local_username = ?, + status = ? 
WHERE id = ?`, - c.ComputeClusterID, c.UserID, c.LocalUsername, c.ID) + c.ComputeClusterID, c.UserID, c.LocalUsername, c.Status, c.ID) + return err +} + +func (s *mysqlComputeClusterUserStore) UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.AllocationStatus) error { + _, err := tx.ExecContext(ctx, + `UPDATE compute_cluster_users SET status = ? WHERE id = ?`, + status, id) + return err +} + +func (s *mysqlComputeClusterUserStore) ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error { + if _, err := tx.ExecContext(ctx, + `DELETE FROM compute_cluster_users + WHERE user_id = ? + AND compute_cluster_id IN ( + SELECT compute_cluster_id FROM ( + SELECT compute_cluster_id FROM compute_cluster_users WHERE user_id = ? + ) AS s + )`, + fromUserID, toUserID); err != nil { + return err + } + _, err := tx.ExecContext(ctx, + `UPDATE compute_cluster_users SET user_id = ? WHERE user_id = ?`, + toUserID, fromUserID) return err } diff --git a/internal/store/external_identity_store.go b/internal/store/external_identity_store.go new file mode 100644 index 000000000..4728f2807 --- /dev/null +++ b/internal/store/external_identity_store.go @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package store + +import ( + "context" + "database/sql" + "errors" + + "github.com/jmoiron/sqlx" + + "github.com/apache/airavata-custos/pkg/models" +) + +type mysqlExternalIdentityStore struct { + db *sqlx.DB +} + +// NewExternalIdentityStore returns a MySQL-backed ExternalIdentityStore. +func NewExternalIdentityStore(db *sqlx.DB) ExternalIdentityStore { + return &mysqlExternalIdentityStore{db: db} +} + +// oidc_sub and metadata are nullable; project NULL to '' for the model. +const externalIdentityColumns = `id, user_id, source, external_id, COALESCE(oidc_sub, '') AS oidc_sub, COALESCE(metadata, '') AS metadata, created_at` + +// nullableString returns nil when s is empty so the column stores SQL NULL +// rather than ''. NULL is the only value MySQL UNIQUE allows to repeat. +func nullableString(s string) any { + if s == "" { + return nil + } + return s +} + +func (s *mysqlExternalIdentityStore) FindByID(ctx context.Context, id string) (*models.ExternalIdentity, error) { + var e models.ExternalIdentity + err := s.db.GetContext(ctx, &e, + `SELECT `+externalIdentityColumns+` FROM external_identities WHERE id = ?`, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &e, nil +} + +func (s *mysqlExternalIdentityStore) FindBySourceAndExternalID(ctx context.Context, source, externalID string) (*models.ExternalIdentity, error) { + var e models.ExternalIdentity + err := s.db.GetContext(ctx, &e, + `SELECT `+externalIdentityColumns+` FROM external_identities WHERE source = ? 
AND external_id = ?`, + source, externalID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &e, nil +} + +func (s *mysqlExternalIdentityStore) FindByOIDCSub(ctx context.Context, oidcSub string) (*models.ExternalIdentity, error) { + if oidcSub == "" { + return nil, nil + } + var e models.ExternalIdentity + err := s.db.GetContext(ctx, &e, + `SELECT `+externalIdentityColumns+` FROM external_identities WHERE oidc_sub = ?`, oidcSub) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &e, nil +} + +func (s *mysqlExternalIdentityStore) FindByUser(ctx context.Context, userID string) ([]models.ExternalIdentity, error) { + var out []models.ExternalIdentity + err := s.db.SelectContext(ctx, &out, + `SELECT `+externalIdentityColumns+` FROM external_identities WHERE user_id = ? ORDER BY created_at ASC`, + userID) + if err != nil { + return nil, err + } + return out, nil +} + +func (s *mysqlExternalIdentityStore) Create(ctx context.Context, tx *sql.Tx, e *models.ExternalIdentity) error { + _, err := tx.ExecContext(ctx, + `INSERT INTO external_identities (id, user_id, source, external_id, oidc_sub, metadata) + VALUES (?, ?, ?, ?, ?, ?)`, + e.ID, e.UserID, e.Source, e.ExternalID, nullableString(e.OIDCSub), nullableString(e.Metadata)) + return err +} + +func (s *mysqlExternalIdentityStore) Update(ctx context.Context, tx *sql.Tx, e *models.ExternalIdentity) error { + _, err := tx.ExecContext(ctx, + `UPDATE external_identities + SET user_id = ?, source = ?, external_id = ?, oidc_sub = ?, metadata = ? + WHERE id = ?`, + e.UserID, e.Source, e.ExternalID, nullableString(e.OIDCSub), nullableString(e.Metadata), e.ID) + return err +} + +func (s *mysqlExternalIdentityStore) ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error { + _, err := tx.ExecContext(ctx, + `UPDATE external_identities SET user_id = ? 
WHERE user_id = ?`, + toUserID, fromUserID) + return err +} + +func (s *mysqlExternalIdentityStore) Delete(ctx context.Context, tx *sql.Tx, id string) error { + _, err := tx.ExecContext(ctx, `DELETE FROM external_identities WHERE id = ?`, id) + return err +} diff --git a/internal/store/project_store.go b/internal/store/project_store.go index fc447500a..13fe71fbd 100644 --- a/internal/store/project_store.go +++ b/internal/store/project_store.go @@ -36,11 +36,12 @@ func NewProjectStore(db *sqlx.DB) ProjectStore { return &mysqlProjectStore{db: db} } +const projectColumns = `id, originated_id, title, origination, project_pi_id, status, created_time` + func (s *mysqlProjectStore) FindByID(ctx context.Context, id string) (*models.Project, error) { var p models.Project err := s.db.GetContext(ctx, &p, - `SELECT id, originated_id, title, origination, project_pi_id, created_time - FROM projects WHERE id = ?`, id) + `SELECT `+projectColumns+` FROM projects WHERE id = ?`, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -53,8 +54,7 @@ func (s *mysqlProjectStore) FindByID(ctx context.Context, id string) (*models.Pr func (s *mysqlProjectStore) FindByOriginatedID(ctx context.Context, originatedID string) (*models.Project, error) { var p models.Project err := s.db.GetContext(ctx, &p, - `SELECT id, originated_id, title, origination, project_pi_id, created_time - FROM projects WHERE originated_id = ?`, originatedID) + `SELECT `+projectColumns+` FROM projects WHERE originated_id = ?`, originatedID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -67,8 +67,7 @@ func (s *mysqlProjectStore) FindByOriginatedID(ctx context.Context, originatedID func (s *mysqlProjectStore) FindByPI(ctx context.Context, piUserID string) ([]models.Project, error) { var projects []models.Project err := s.db.SelectContext(ctx, &projects, - `SELECT id, originated_id, title, origination, project_pi_id, created_time - FROM projects WHERE project_pi_id = ?`, piUserID) 
+ `SELECT `+projectColumns+` FROM projects WHERE project_pi_id = ?`, piUserID) if err != nil { return nil, err } @@ -77,17 +76,31 @@ func (s *mysqlProjectStore) FindByPI(ctx context.Context, piUserID string) ([]mo func (s *mysqlProjectStore) Create(ctx context.Context, tx *sql.Tx, p *models.Project) error { _, err := tx.ExecContext(ctx, - `INSERT INTO projects (id, originated_id, title, origination, project_pi_id, created_time) - VALUES (?, ?, ?, ?, ?, ?)`, - p.ID, p.OriginatedID, p.Title, p.Origination, p.ProjectPIID, p.CreatedTime) + `INSERT INTO projects (id, originated_id, title, origination, project_pi_id, status, created_time) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + p.ID, p.OriginatedID, p.Title, p.Origination, p.ProjectPIID, p.Status, p.CreatedTime) return err } func (s *mysqlProjectStore) Update(ctx context.Context, tx *sql.Tx, p *models.Project) error { _, err := tx.ExecContext(ctx, - `UPDATE projects SET originated_id = ?, title = ?, origination = ?, project_pi_id = ? + `UPDATE projects SET originated_id = ?, title = ?, origination = ?, project_pi_id = ?, status = ? WHERE id = ?`, - p.OriginatedID, p.Title, p.Origination, p.ProjectPIID, p.ID) + p.OriginatedID, p.Title, p.Origination, p.ProjectPIID, p.Status, p.ID) + return err +} + +func (s *mysqlProjectStore) UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.ProjectStatus) error { + _, err := tx.ExecContext(ctx, + `UPDATE projects SET status = ? WHERE id = ?`, + status, id) + return err +} + +func (s *mysqlProjectStore) ReassignPI(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error { + _, err := tx.ExecContext(ctx, + `UPDATE projects SET project_pi_id = ? 
WHERE project_pi_id = ?`, + toUserID, fromUserID) return err } diff --git a/internal/store/store.go b/internal/store/store.go index b066df428..32824f7ff 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -37,10 +37,23 @@ type UserStore interface { Create(ctx context.Context, tx *sql.Tx, u *models.User) error // Update replaces mutable fields of an existing user within the provided transaction. Update(ctx context.Context, tx *sql.Tx, u *models.User) error + // UpdateStatus sets the lifecycle status of an existing user within the provided transaction. + UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.UserStatus) error // Delete removes a user by ID within the provided transaction. Delete(ctx context.Context, tx *sql.Tx, id string) error } +// UserMergeStore records when one user is consolidated into another. Rows are +// append-only; each retiring user can be merged at most once. +type UserMergeStore interface { + // Record inserts a new merge record within the provided transaction. + Record(ctx context.Context, tx *sql.Tx, retiringUserID, survivingUserID, reason string) error + // FindByRetiringUser returns the merge record whose retiring user matches, or nil if absent. + FindByRetiringUser(ctx context.Context, retiringUserID string) (*models.UserMerge, error) + // FindBySurvivingUser returns every merge record whose survivor matches, oldest first. + FindBySurvivingUser(ctx context.Context, survivingUserID string) ([]models.UserMerge, error) +} + // OrganizationStore defines persistence operations for organizations. type OrganizationStore interface { // FindByID returns the organization with the given ID, or nil if not found. @@ -86,10 +99,53 @@ type ComputeClusterUserStore interface { Create(ctx context.Context, tx *sql.Tx, c *models.ComputeClusterUser) error // Update replaces mutable fields of an existing mapping within the provided transaction. 
Update(ctx context.Context, tx *sql.Tx, c *models.ComputeClusterUser) error + // UpdateStatus sets the lifecycle status of an existing mapping within the provided transaction. + UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.AllocationStatus) error + // ReassignUser moves every mapping owned by fromUserID over to toUserID, + // dropping fromUserID's rows on clusters where toUserID already has one. + ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error // Delete removes a mapping by ID within the provided transaction. Delete(ctx context.Context, tx *sql.Tx, id string) error } +// ExternalIdentityStore defines persistence operations for external-identity +// bindings between Custos users and identifiers issued by external systems. +type ExternalIdentityStore interface { + // FindByID returns the external identity with the given ID, or nil if not found. + FindByID(ctx context.Context, id string) (*models.ExternalIdentity, error) + // FindBySourceAndExternalID returns the binding for the given (source, external_id) pair, or nil if absent. + FindBySourceAndExternalID(ctx context.Context, source, externalID string) (*models.ExternalIdentity, error) + // FindByOIDCSub returns the first binding matching the given OIDC subject, or nil if none. + FindByOIDCSub(ctx context.Context, oidcSub string) (*models.ExternalIdentity, error) + // FindByUser returns every external identity bound to the given user, ordered by created_at. + FindByUser(ctx context.Context, userID string) ([]models.ExternalIdentity, error) + // Create inserts a new external identity within the provided transaction. + Create(ctx context.Context, tx *sql.Tx, e *models.ExternalIdentity) error + // Update replaces mutable fields of an existing external identity within the provided transaction. + Update(ctx context.Context, tx *sql.Tx, e *models.ExternalIdentity) error + // ReassignUser moves every external identity owned by fromUserID over to toUserID. 
+ ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error + // Delete removes an external identity by ID within the provided transaction. + Delete(ctx context.Context, tx *sql.Tx, id string) error +} + +// UserDNStore defines persistence operations for X.509 distinguished-name +// bindings against a Custos user. +type UserDNStore interface { + // FindByID returns the DN binding with the given ID, or nil if not found. + FindByID(ctx context.Context, id string) (*models.UserDN, error) + // FindByDN returns the binding matching the given DN, or nil if absent. + FindByDN(ctx context.Context, dn string) (*models.UserDN, error) + // FindByUser returns every DN bound to the given user, ordered by created_at. + FindByUser(ctx context.Context, userID string) ([]models.UserDN, error) + // Create inserts a new DN binding within the provided transaction. + Create(ctx context.Context, tx *sql.Tx, d *models.UserDN) error + // ReassignUser moves every DN owned by fromUserID over to toUserID, dropping duplicates. + ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error + // Delete removes a DN binding by ID within the provided transaction. + Delete(ctx context.Context, tx *sql.Tx, id string) error +} + // ProjectStore defines persistence operations for projects. type ProjectStore interface { // FindByID returns the project with the given ID, or nil if not found. @@ -102,6 +158,10 @@ type ProjectStore interface { Create(ctx context.Context, tx *sql.Tx, p *models.Project) error // Update replaces mutable fields of an existing project within the provided transaction. Update(ctx context.Context, tx *sql.Tx, p *models.Project) error + // UpdateStatus sets the lifecycle status of an existing project within the provided transaction. + UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.ProjectStatus) error + // ReassignPI changes project_pi_id from fromUserID to toUserID for every project. 
+ ReassignPI(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error // Delete removes a project by ID within the provided transaction. Delete(ctx context.Context, tx *sql.Tx, id string) error } @@ -250,6 +310,9 @@ type ComputeAllocationMembershipStore interface { Create(ctx context.Context, tx *sql.Tx, m *models.ComputeAllocationMembership) error // Update replaces mutable fields of an existing membership within the provided transaction. Update(ctx context.Context, tx *sql.Tx, m *models.ComputeAllocationMembership) error + // ReassignUser moves every membership owned by fromUserID over to toUserID, + // dropping fromUserID's rows on allocations where toUserID already has one. + ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error // Delete removes a membership by ID within the provided transaction. Delete(ctx context.Context, tx *sql.Tx, id string) error } diff --git a/internal/store/user_dn_store.go b/internal/store/user_dn_store.go new file mode 100644 index 000000000..bfe6f69fe --- /dev/null +++ b/internal/store/user_dn_store.go @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package store + +import ( + "context" + "database/sql" + "errors" + + "github.com/jmoiron/sqlx" + + "github.com/apache/airavata-custos/pkg/models" +) + +type mysqlUserDNStore struct { + db *sqlx.DB +} + +// NewUserDNStore returns a MySQL-backed UserDNStore. +func NewUserDNStore(db *sqlx.DB) UserDNStore { + return &mysqlUserDNStore{db: db} +} + +const userDNColumns = `id, user_id, dn, created_at` + +func (s *mysqlUserDNStore) FindByID(ctx context.Context, id string) (*models.UserDN, error) { + var d models.UserDN + err := s.db.GetContext(ctx, &d, + `SELECT `+userDNColumns+` FROM user_dns WHERE id = ?`, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &d, nil +} + +func (s *mysqlUserDNStore) FindByDN(ctx context.Context, dn string) (*models.UserDN, error) { + var d models.UserDN + err := s.db.GetContext(ctx, &d, + `SELECT `+userDNColumns+` FROM user_dns WHERE dn = ?`, dn) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &d, nil +} + +func (s *mysqlUserDNStore) FindByUser(ctx context.Context, userID string) ([]models.UserDN, error) { + var out []models.UserDN + err := s.db.SelectContext(ctx, &out, + `SELECT `+userDNColumns+` FROM user_dns WHERE user_id = ? ORDER BY created_at ASC`, + userID) + if err != nil { + return nil, err + } + return out, nil +} + +func (s *mysqlUserDNStore) Create(ctx context.Context, tx *sql.Tx, d *models.UserDN) error { + _, err := tx.ExecContext(ctx, + `INSERT INTO user_dns (id, user_id, dn) VALUES (?, ?, ?)`, + d.ID, d.UserID, d.DN) + return err +} + +func (s *mysqlUserDNStore) ReassignUser(ctx context.Context, tx *sql.Tx, fromUserID, toUserID string) error { + // Drop fromUserID's DNs already held by the survivor, then move the rest. + if _, err := tx.ExecContext(ctx, + `DELETE FROM user_dns + WHERE user_id = ? + AND dn IN (SELECT dn FROM (SELECT dn FROM user_dns WHERE user_id = ?) 
AS s)`, + fromUserID, toUserID); err != nil { + return err + } + _, err := tx.ExecContext(ctx, + `UPDATE user_dns SET user_id = ? WHERE user_id = ?`, + toUserID, fromUserID) + return err +} + +func (s *mysqlUserDNStore) Delete(ctx context.Context, tx *sql.Tx, id string) error { + _, err := tx.ExecContext(ctx, `DELETE FROM user_dns WHERE id = ?`, id) + return err +} diff --git a/internal/store/user_merge_store.go b/internal/store/user_merge_store.go new file mode 100644 index 000000000..827462e55 --- /dev/null +++ b/internal/store/user_merge_store.go @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package store + +import ( + "context" + "database/sql" + "errors" + + "github.com/jmoiron/sqlx" + + "github.com/apache/airavata-custos/pkg/models" +) + +type mysqlUserMergeStore struct { + db *sqlx.DB +} + +// NewUserMergeStore returns a MySQL-backed UserMergeStore. 
+func NewUserMergeStore(db *sqlx.DB) UserMergeStore { + return &mysqlUserMergeStore{db: db} +} + +const userMergeColumns = `id, retiring_user_id, surviving_user_id, COALESCE(reason, '') AS reason, merged_at` + +func (s *mysqlUserMergeStore) Record(ctx context.Context, tx *sql.Tx, retiringUserID, survivingUserID, reason string) error { + var reasonArg any + if reason == "" { + reasonArg = nil + } else { + reasonArg = reason + } + _, err := tx.ExecContext(ctx, + `INSERT INTO user_merges (retiring_user_id, surviving_user_id, reason) + VALUES (?, ?, ?)`, + retiringUserID, survivingUserID, reasonArg) + return err +} + +func (s *mysqlUserMergeStore) FindByRetiringUser(ctx context.Context, retiringUserID string) (*models.UserMerge, error) { + var m models.UserMerge + err := s.db.GetContext(ctx, &m, + `SELECT `+userMergeColumns+` FROM user_merges WHERE retiring_user_id = ?`, retiringUserID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &m, nil +} + +func (s *mysqlUserMergeStore) FindBySurvivingUser(ctx context.Context, survivingUserID string) ([]models.UserMerge, error) { + var out []models.UserMerge + err := s.db.SelectContext(ctx, &out, + `SELECT `+userMergeColumns+` FROM user_merges WHERE surviving_user_id = ? 
ORDER BY merged_at ASC`, + survivingUserID) + if err != nil { + return nil, err + } + return out, nil +} diff --git a/internal/store/user_store.go b/internal/store/user_store.go index eda2b749f..6c88f3e3f 100644 --- a/internal/store/user_store.go +++ b/internal/store/user_store.go @@ -36,11 +36,12 @@ func NewUserStore(db *sqlx.DB) UserStore { return &mysqlUserStore{db: db} } +const userColumns = `id, organization_id, first_name, last_name, middle_name, email, status` + func (s *mysqlUserStore) FindByID(ctx context.Context, id string) (*models.User, error) { var u models.User err := s.db.GetContext(ctx, &u, - `SELECT id, organization_id, first_name, last_name, middle_name, email - FROM users WHERE id = ?`, id) + `SELECT `+userColumns+` FROM users WHERE id = ?`, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -53,8 +54,7 @@ func (s *mysqlUserStore) FindByID(ctx context.Context, id string) (*models.User, func (s *mysqlUserStore) FindByEmail(ctx context.Context, email string) (*models.User, error) { var u models.User err := s.db.GetContext(ctx, &u, - `SELECT id, organization_id, first_name, last_name, middle_name, email - FROM users WHERE email = ?`, email) + `SELECT `+userColumns+` FROM users WHERE email = ?`, email) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -67,8 +67,7 @@ func (s *mysqlUserStore) FindByEmail(ctx context.Context, email string) (*models func (s *mysqlUserStore) FindByOrganization(ctx context.Context, organizationID string) ([]models.User, error) { var users []models.User err := s.db.SelectContext(ctx, &users, - `SELECT id, organization_id, first_name, last_name, middle_name, email - FROM users WHERE organization_id = ?`, organizationID) + `SELECT `+userColumns+` FROM users WHERE organization_id = ?`, organizationID) if err != nil { return nil, err } @@ -77,17 +76,24 @@ func (s *mysqlUserStore) FindByOrganization(ctx context.Context, organizationID func (s *mysqlUserStore) Create(ctx context.Context, 
tx *sql.Tx, u *models.User) error { _, err := tx.ExecContext(ctx, - `INSERT INTO users (id, organization_id, first_name, last_name, middle_name, email) - VALUES (?, ?, ?, ?, ?, ?)`, - u.ID, u.OrganizationID, u.FirstName, u.LastName, u.MiddleName, u.Email) + `INSERT INTO users (id, organization_id, first_name, last_name, middle_name, email, status) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + u.ID, u.OrganizationID, u.FirstName, u.LastName, u.MiddleName, u.Email, u.Status) return err } func (s *mysqlUserStore) Update(ctx context.Context, tx *sql.Tx, u *models.User) error { _, err := tx.ExecContext(ctx, - `UPDATE users SET organization_id = ?, first_name = ?, last_name = ?, middle_name = ?, email = ? + `UPDATE users SET organization_id = ?, first_name = ?, last_name = ?, middle_name = ?, email = ?, status = ? WHERE id = ?`, - u.OrganizationID, u.FirstName, u.LastName, u.MiddleName, u.Email, u.ID) + u.OrganizationID, u.FirstName, u.LastName, u.MiddleName, u.Email, u.Status, u.ID) + return err +} + +func (s *mysqlUserStore) UpdateStatus(ctx context.Context, tx *sql.Tx, id string, status models.UserStatus) error { + _, err := tx.ExecContext(ctx, + `UPDATE users SET status = ? WHERE id = ?`, + status, id) return err } diff --git a/pkg/events/external_identity_subscribe.go b/pkg/events/external_identity_subscribe.go new file mode 100644 index 000000000..05e2722ff --- /dev/null +++ b/pkg/events/external_identity_subscribe.go @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package events + +import ( + "log/slog" + + "github.com/apache/airavata-custos/pkg/models" +) + +// ExternalIdentityHandler handles external-identity lifecycle events with a +// typed payload. +type ExternalIdentityHandler func(identity models.ExternalIdentity) + +// SubscribeExternalIdentityCreated registers a typed handler invoked whenever +// an external_identity::create event is published. +func (b *Bus) SubscribeExternalIdentityCreated(handler ExternalIdentityHandler) { + b.subscribeExternalIdentity(ExternalIdentityCreateEvent, handler) +} + +// SubscribeExternalIdentityUpdated registers a typed handler invoked whenever +// an external_identity::update event is published. +func (b *Bus) SubscribeExternalIdentityUpdated(handler ExternalIdentityHandler) { + b.subscribeExternalIdentity(ExternalIdentityUpdateEvent, handler) +} + +// SubscribeExternalIdentityDeleted registers a typed handler invoked whenever +// an external_identity::delete event is published. 
+func (b *Bus) SubscribeExternalIdentityDeleted(handler ExternalIdentityHandler) { + b.subscribeExternalIdentity(ExternalIdentityDeleteEvent, handler) +} + +func (b *Bus) subscribeExternalIdentity(topic EventType, handler ExternalIdentityHandler) { + b.Subscribe(topic, func(event Event, value interface{}) { + switch e := value.(type) { + case models.ExternalIdentity: + handler(e) + case *models.ExternalIdentity: + if e != nil { + handler(*e) + } + default: + slog.Warn("external identity event payload has unexpected type", + "type", event.Type, + "got", value, + ) + } + }) +} diff --git a/pkg/events/types.go b/pkg/events/types.go index 8ae8de802..4e191e27b 100644 --- a/pkg/events/types.go +++ b/pkg/events/types.go @@ -110,6 +110,20 @@ const ( ComputeAllocationResourceMappingDeleteEvent EventType = "compute_allocation_resource_mapping::delete" ) +// ExternalIdentity lifecycle message types. +const ( + ExternalIdentityCreateEvent EventType = "external_identity::create" + ExternalIdentityUpdateEvent EventType = "external_identity::update" + ExternalIdentityDeleteEvent EventType = "external_identity::delete" +) + +// UserDN lifecycle message types. DN bindings are append-only credentials, so +// no update topic. +const ( + UserDNCreateEvent EventType = "user_dn::create" + UserDNDeleteEvent EventType = "user_dn::delete" +) + // Event represents a change in the system that downstream consumers may be interested in. // The payload is the full record after the change (e.g. the // new state of a project after an update). diff --git a/pkg/events/user_dn_subscribe.go b/pkg/events/user_dn_subscribe.go new file mode 100644 index 000000000..9190976e6 --- /dev/null +++ b/pkg/events/user_dn_subscribe.go @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package events + +import ( + "log/slog" + + "github.com/apache/airavata-custos/pkg/models" +) + +// UserDNHandler handles DN-binding lifecycle events with a typed payload. +type UserDNHandler func(dn models.UserDN) + +// SubscribeUserDNCreated registers a typed handler invoked whenever a +// user_dn::create event is published. +func (b *Bus) SubscribeUserDNCreated(handler UserDNHandler) { + b.subscribeUserDN(UserDNCreateEvent, handler) +} + +// SubscribeUserDNDeleted registers a typed handler invoked whenever a +// user_dn::delete event is published. 
+func (b *Bus) SubscribeUserDNDeleted(handler UserDNHandler) { + b.subscribeUserDN(UserDNDeleteEvent, handler) +} + +func (b *Bus) subscribeUserDN(topic EventType, handler UserDNHandler) { + b.Subscribe(topic, func(event Event, value interface{}) { + switch d := value.(type) { + case models.UserDN: + handler(d) + case *models.UserDN: + if d != nil { + handler(*d) + } + default: + slog.Warn("user dn event payload has unexpected type", + "type", event.Type, + "got", value, + ) + } + }) +} diff --git a/pkg/models/allocation.go b/pkg/models/allocation.go index a8fcd3976..9fb9e94f5 100644 --- a/pkg/models/allocation.go +++ b/pkg/models/allocation.go @@ -16,10 +16,11 @@ type ComputeCluster struct { } type ComputeClusterUser struct { - ID string `json:"id" db:"id"` - ComputeClusterID string `json:"compute_cluster_id" db:"compute_cluster_id"` - UserID string `json:"user_id" db:"user_id"` - LocalUsername string `json:"local_username" db:"local_username"` // The username of the user on the compute cluster, which may be different from their Airavata Custos username. + ID string `json:"id" db:"id"` + ComputeClusterID string `json:"compute_cluster_id" db:"compute_cluster_id"` + UserID string `json:"user_id" db:"user_id"` + LocalUsername string `json:"local_username" db:"local_username"` // The username of the user on the compute cluster, which may be different from their Airavata Custos username. + Status AllocationStatus `json:"status" db:"status"` } type ComputeAllocation struct { diff --git a/pkg/models/identity.go b/pkg/models/identity.go new file mode 100644 index 000000000..dc3cf5d95 --- /dev/null +++ b/pkg/models/identity.go @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package models + +import "time" + +// ExternalIdentity links a User to its identifier in an external system +// (ACCESS, NAIRR, CILogon, etc.). One user may have many external identities. +// Source-specific attributes (e.g. NSF status code, ACCESS org code) belong +// in Metadata as a JSON-encoded blob. +type ExternalIdentity struct { + ID string `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id"` + Source string `json:"source" db:"source"` // e.g. access, nairr, cilogon + ExternalID string `json:"external_id" db:"external_id"` // the source's native identifier + OIDCSub string `json:"oidc_sub,omitempty" db:"oidc_sub"` // OIDC subject when the source issues one + Metadata string `json:"metadata,omitempty" db:"metadata"` // JSON-encoded source-specific fields + CreatedAt time.Time `json:"created_at" db:"created_at"` +} + +// UserDN binds an X.509 distinguished name (e.g. mTLS client cert subject) to +// a User. Append-only: DNs are credentials and are added or removed, never +// edited. 
+type UserDN struct { + ID string `json:"id" db:"id"` + UserID string `json:"user_id" db:"user_id"` + DN string `json:"dn" db:"dn"` + CreatedAt time.Time `json:"created_at" db:"created_at"` +} diff --git a/pkg/models/project.go b/pkg/models/project.go index d6325959d..38de362a0 100644 --- a/pkg/models/project.go +++ b/pkg/models/project.go @@ -2,13 +2,33 @@ package models import "time" +// UserStatus enumerates the lifecycle states a User may occupy. +type UserStatus string + +const ( + UserActive UserStatus = "ACTIVE" + UserInactive UserStatus = "INACTIVE" + UserSuspended UserStatus = "SUSPENDED" + UserMerged UserStatus = "MERGED" +) + +// ProjectStatus enumerates the lifecycle states a Project may occupy. +type ProjectStatus string + +const ( + ProjectActive ProjectStatus = "ACTIVE" + ProjectInactive ProjectStatus = "INACTIVE" + ProjectDeleted ProjectStatus = "DELETED" +) + type Project struct { - ID string `json:"id" db:"id"` - OriginatedID string `json:"originated_id" db:"originated_id"` // The ID of the project in origination. For example: ACCESS Record ID. - Title string `json:"title" db:"title"` - Origination string `json:"origination" db:"origination"` // ACCESS, NAIRR, XRASS, etc. - ProjectPIID string `json:"project_pi_id" db:"project_pi_id"` - CreatedTime time.Time `json:"created_time" db:"created_time"` + ID string `json:"id" db:"id"` + OriginatedID string `json:"originated_id" db:"originated_id"` // The ID of the project in origination. For example: ACCESS Record ID. + Title string `json:"title" db:"title"` + Origination string `json:"origination" db:"origination"` // ACCESS, NAIRR, XRASS, etc. 
+ ProjectPIID string `json:"project_pi_id" db:"project_pi_id"` + Status ProjectStatus `json:"status" db:"status"` + CreatedTime time.Time `json:"created_time" db:"created_time"` } type Organization struct { @@ -18,10 +38,22 @@ type Organization struct { } type User struct { - ID string `json:"id" db:"id"` - OrganizationID string `json:"organization_id" db:"organization_id"` - FirstName string `json:"first_name" db:"first_name"` - LastName string `json:"last_name" db:"last_name"` - MiddleName string `json:"middle_name,omitempty" db:"middle_name"` - Email string `json:"email" db:"email"` + ID string `json:"id" db:"id"` + OrganizationID string `json:"organization_id" db:"organization_id"` + FirstName string `json:"first_name" db:"first_name"` + LastName string `json:"last_name" db:"last_name"` + MiddleName string `json:"middle_name,omitempty" db:"middle_name"` + Email string `json:"email" db:"email"` + Status UserStatus `json:"status" db:"status"` +} + +// UserMerge is the audit record that links a retiring user to the surviving +// user that absorbed its identity-forward state. Each retiring user can be +// merged at most once; merges are not reversed in-place. 
+type UserMerge struct { + ID int64 `json:"id" db:"id"` + RetiringUserID string `json:"retiring_user_id" db:"retiring_user_id"` + SurvivingUserID string `json:"surviving_user_id" db:"surviving_user_id"` + Reason string `json:"reason,omitempty" db:"reason"` + MergedAt time.Time `json:"merged_at" db:"merged_at"` } diff --git a/pkg/service/compute_cluster_user.go b/pkg/service/compute_cluster_user.go index 3e8ed9868..b3b268ae4 100644 --- a/pkg/service/compute_cluster_user.go +++ b/pkg/service/compute_cluster_user.go @@ -45,6 +45,9 @@ func (s *Service) CreateComputeClusterUser(ctx context.Context, cu *models.Compu if cu.ID == "" { cu.ID = newID() } + if cu.Status == "" { + cu.Status = models.ACTIVE + } if cluster, err := s.clusters.FindByID(ctx, cu.ComputeClusterID); err != nil { return nil, fmt.Errorf("lookup compute cluster: %w", err) @@ -132,19 +135,30 @@ func (s *Service) ListComputeClusterUsersByUser(ctx context.Context, userID stri } // UpdateComputeClusterUser persists changes to an existing compute-cluster -// user mapping. +// user mapping. Fields left blank/zero on the supplied record fall back to +// the stored value. 
func (s *Service) UpdateComputeClusterUser(ctx context.Context, cu *models.ComputeClusterUser) error { if cu == nil || cu.ID == "" { return fmt.Errorf("%w: compute cluster user id is required", ErrInvalidInput) } + existing, err := s.clusterUsers.FindByID(ctx, cu.ID) + if err != nil { + return fmt.Errorf("lookup compute cluster user: %w", err) + } + if existing == nil { + return ErrNotFound + } if cu.ComputeClusterID == "" { - return fmt.Errorf("%w: compute_cluster_id is required", ErrInvalidInput) + cu.ComputeClusterID = existing.ComputeClusterID } if cu.UserID == "" { - return fmt.Errorf("%w: user_id is required", ErrInvalidInput) + cu.UserID = existing.UserID } if cu.LocalUsername == "" { - return fmt.Errorf("%w: local_username is required", ErrInvalidInput) + cu.LocalUsername = existing.LocalUsername + } + if cu.Status == "" { + cu.Status = existing.Status } if err := s.inTx(ctx, func(tx *sql.Tx) error { return s.clusterUsers.Update(ctx, tx, cu) @@ -156,6 +170,33 @@ func (s *Service) UpdateComputeClusterUser(ctx context.Context, cu *models.Compu return nil } +// UpdateComputeClusterUserStatus sets the lifecycle status of the mapping +// identified by id. Other fields are preserved. 
+func (s *Service) UpdateComputeClusterUserStatus(ctx context.Context, id string, status models.AllocationStatus) (*models.ComputeClusterUser, error) { + if id == "" { + return nil, fmt.Errorf("%w: compute cluster user id is required", ErrInvalidInput) + } + if status == "" { + return nil, fmt.Errorf("%w: status is required", ErrInvalidInput) + } + existing, err := s.clusterUsers.FindByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("lookup compute cluster user: %w", err) + } + if existing == nil { + return nil, ErrNotFound + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.clusterUsers.UpdateStatus(ctx, tx, id, status) + }); err != nil { + return nil, fmt.Errorf("update compute cluster user status: %w", err) + } + existing.Status = status + + s.eventBus.Publish(events.ComputeClusterUserUpdateEvent, existing) + return existing, nil +} + // DeleteComputeClusterUser removes a compute-cluster user mapping by ID. func (s *Service) DeleteComputeClusterUser(ctx context.Context, id string) error { if id == "" { diff --git a/pkg/service/external_identity.go b/pkg/service/external_identity.go new file mode 100644 index 000000000..1fe89af74 --- /dev/null +++ b/pkg/service/external_identity.go @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package service + +import ( + "context" + "database/sql" + "fmt" + + "github.com/apache/airavata-custos/pkg/events" + "github.com/apache/airavata-custos/pkg/models" +) + +// CreateExternalIdentity persists a new external identity. If e.ID is empty, a +// new UUID is generated. The referenced user must already exist and the +// (source, external_id) pair is unique. +func (s *Service) CreateExternalIdentity(ctx context.Context, e *models.ExternalIdentity) (*models.ExternalIdentity, error) { + if e == nil { + return nil, fmt.Errorf("%w: external identity is nil", ErrInvalidInput) + } + if e.UserID == "" { + return nil, fmt.Errorf("%w: external identity user_id is required", ErrInvalidInput) + } + if e.Source == "" { + return nil, fmt.Errorf("%w: external identity source is required", ErrInvalidInput) + } + if e.ExternalID == "" { + return nil, fmt.Errorf("%w: external identity external_id is required", ErrInvalidInput) + } + + if user, err := s.users.FindByID(ctx, e.UserID); err != nil { + return nil, fmt.Errorf("verify user: %w", err) + } else if user == nil { + return nil, fmt.Errorf("%w: user %q does not exist", ErrInvalidInput, e.UserID) + } + + if existing, err := s.extIDs.FindBySourceAndExternalID(ctx, e.Source, e.ExternalID); err != nil { + return nil, fmt.Errorf("lookup external identity: %w", err) + } else if existing != nil { + return nil, fmt.Errorf("%w: external identity for source %q, external_id %q", ErrAlreadyExists, e.Source, e.ExternalID) + } + + if e.ID == "" { + e.ID = newID() + } + + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.extIDs.Create(ctx, tx, e) + }); err != nil { + return nil, fmt.Errorf("create external identity: %w", err) + } + + s.eventBus.Publish(events.ExternalIdentityCreateEvent, e) + return e, nil +} + +// GetExternalIdentity retrieves an external identity by ID. Returns +// ErrNotFound when no row matches. 
+func (s *Service) GetExternalIdentity(ctx context.Context, id string) (*models.ExternalIdentity, error) { + e, err := s.extIDs.FindByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get external identity: %w", err) + } + if e == nil { + return nil, ErrNotFound + } + return e, nil +} + +// GetExternalIdentityBySourceAndExternalID retrieves the unique external +// identity for the given (source, external_id) pair. +func (s *Service) GetExternalIdentityBySourceAndExternalID(ctx context.Context, source, externalID string) (*models.ExternalIdentity, error) { + if source == "" { + return nil, fmt.Errorf("%w: source is required", ErrInvalidInput) + } + if externalID == "" { + return nil, fmt.Errorf("%w: external_id is required", ErrInvalidInput) + } + e, err := s.extIDs.FindBySourceAndExternalID(ctx, source, externalID) + if err != nil { + return nil, fmt.Errorf("get external identity by source/external_id: %w", err) + } + if e == nil { + return nil, ErrNotFound + } + return e, nil +} + +// GetExternalIdentityByOIDCSub retrieves the first external identity matching +// the given OIDC subject. +func (s *Service) GetExternalIdentityByOIDCSub(ctx context.Context, oidcSub string) (*models.ExternalIdentity, error) { + if oidcSub == "" { + return nil, fmt.Errorf("%w: oidc_sub is required", ErrInvalidInput) + } + e, err := s.extIDs.FindByOIDCSub(ctx, oidcSub) + if err != nil { + return nil, fmt.Errorf("get external identity by oidc_sub: %w", err) + } + if e == nil { + return nil, ErrNotFound + } + return e, nil +} + +// ListExternalIdentitiesForUser returns every external identity belonging to +// the given user. 
+func (s *Service) ListExternalIdentitiesForUser(ctx context.Context, userID string) ([]models.ExternalIdentity, error) { + if userID == "" { + return nil, fmt.Errorf("%w: user_id is required", ErrInvalidInput) + } + out, err := s.extIDs.FindByUser(ctx, userID) + if err != nil { + return nil, fmt.Errorf("list external identities by user: %w", err) + } + return out, nil +} + +// UpdateExternalIdentity persists changes to an existing external identity. +// Fields left blank/zero on the supplied record fall back to the stored value. +func (s *Service) UpdateExternalIdentity(ctx context.Context, e *models.ExternalIdentity) error { + if e == nil || e.ID == "" { + return fmt.Errorf("%w: external identity id is required", ErrInvalidInput) + } + existing, err := s.extIDs.FindByID(ctx, e.ID) + if err != nil { + return fmt.Errorf("lookup external identity: %w", err) + } + if existing == nil { + return ErrNotFound + } + if e.UserID == "" { + e.UserID = existing.UserID + } + if e.Source == "" { + e.Source = existing.Source + } + if e.ExternalID == "" { + e.ExternalID = existing.ExternalID + } + if e.OIDCSub == "" { + e.OIDCSub = existing.OIDCSub + } + if e.Metadata == "" { + e.Metadata = existing.Metadata + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.extIDs.Update(ctx, tx, e) + }); err != nil { + return fmt.Errorf("update external identity: %w", err) + } + + s.eventBus.Publish(events.ExternalIdentityUpdateEvent, e) + return nil +} + +// DeleteExternalIdentity removes an external identity by ID. 
+func (s *Service) DeleteExternalIdentity(ctx context.Context, id string) error { + if id == "" { + return fmt.Errorf("%w: external identity id is required", ErrInvalidInput) + } + e, err := s.extIDs.FindByID(ctx, id) + if err != nil { + return fmt.Errorf("lookup external identity: %w", err) + } + if e == nil { + return ErrNotFound + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.extIDs.Delete(ctx, tx, id) + }); err != nil { + return fmt.Errorf("delete external identity: %w", err) + } + + s.eventBus.Publish(events.ExternalIdentityDeleteEvent, e) + return nil +} diff --git a/pkg/service/project.go b/pkg/service/project.go index 938ff4f8e..338542473 100644 --- a/pkg/service/project.go +++ b/pkg/service/project.go @@ -56,6 +56,9 @@ func (s *Service) CreateProject(ctx context.Context, project *models.Project) (* if project.ID == "" { project.ID = newID() } + if project.Status == "" { + project.Status = models.ProjectActive + } if project.CreatedTime.IsZero() { project.CreatedTime = nowUTC() } @@ -103,11 +106,37 @@ func (s *Service) ListProjectsByPI(ctx context.Context, piUserID string) ([]mode return projects, nil } -// UpdateProject persists changes to an existing project. +// UpdateProject persists changes to an existing project. Fields left +// blank/zero on the supplied record fall back to the stored value. 
func (s *Service) UpdateProject(ctx context.Context, project *models.Project) error { if project == nil || project.ID == "" { return fmt.Errorf("%w: project id is required", ErrInvalidInput) } + existing, err := s.projs.FindByID(ctx, project.ID) + if err != nil { + return fmt.Errorf("lookup project: %w", err) + } + if existing == nil { + return ErrNotFound + } + if project.OriginatedID == "" { + project.OriginatedID = existing.OriginatedID + } + if project.Title == "" { + project.Title = existing.Title + } + if project.Origination == "" { + project.Origination = existing.Origination + } + if project.ProjectPIID == "" { + project.ProjectPIID = existing.ProjectPIID + } + if project.Status == "" { + project.Status = existing.Status + } + if project.CreatedTime.IsZero() { + project.CreatedTime = existing.CreatedTime + } if err := s.inTx(ctx, func(tx *sql.Tx) error { return s.projs.Update(ctx, tx, project) }); err != nil { @@ -118,6 +147,33 @@ func (s *Service) UpdateProject(ctx context.Context, project *models.Project) er return nil } +// UpdateProjectStatus sets the lifecycle status of the project identified by +// id. Other fields are preserved. +func (s *Service) UpdateProjectStatus(ctx context.Context, id string, status models.ProjectStatus) (*models.Project, error) { + if id == "" { + return nil, fmt.Errorf("%w: project id is required", ErrInvalidInput) + } + if status == "" { + return nil, fmt.Errorf("%w: status is required", ErrInvalidInput) + } + existing, err := s.projs.FindByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("lookup project: %w", err) + } + if existing == nil { + return nil, ErrNotFound + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.projs.UpdateStatus(ctx, tx, id, status) + }); err != nil { + return nil, fmt.Errorf("update project status: %w", err) + } + existing.Status = status + + s.eventBus.Publish(events.ProjectUpdateEvent, existing) + return existing, nil +} + // DeleteProject removes a project by ID. 
func (s *Service) DeleteProject(ctx context.Context, id string) error { if id == "" { diff --git a/pkg/service/service.go b/pkg/service/service.go index 23ca9bc83..3ef5c92cd 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -34,46 +34,52 @@ import ( // mutating operation in a transaction so callers do not need to manage // *sql.Tx themselves. type Service struct { - db *sqlx.DB - eventBus *events.Bus - orgs store.OrganizationStore - users store.UserStore - projs store.ProjectStore - clusters store.ComputeClusterStore - clusterUsers store.ComputeClusterUserStore - allocs store.ComputeAllocationStore - resources store.ComputeAllocationResourceStore - resourceMappings store.ComputeAllocationResourceMappingStore - resourceRates store.ComputeAllocationResourceRateStore - allocDiffs store.ComputeAllocationDiffStore - changeRequests store.ComputeAllocationChangeRequestStore - changeEvents store.ComputeAllocationChangeRequestEventStore - memberships store.ComputeAllocationMembershipStore + db *sqlx.DB + eventBus *events.Bus + orgs store.OrganizationStore + users store.UserStore + projs store.ProjectStore + clusters store.ComputeClusterStore + clusterUsers store.ComputeClusterUserStore + allocs store.ComputeAllocationStore + resources store.ComputeAllocationResourceStore + resourceMappings store.ComputeAllocationResourceMappingStore + resourceRates store.ComputeAllocationResourceRateStore + allocDiffs store.ComputeAllocationDiffStore + changeRequests store.ComputeAllocationChangeRequestStore + changeEvents store.ComputeAllocationChangeRequestEventStore + memberships store.ComputeAllocationMembershipStore membershipOverrides store.ComputeAllocationMembershipResourceOverrideStore - usages store.ComputeAllocationUsageStore + usages store.ComputeAllocationUsageStore + extIDs store.ExternalIdentityStore + userDNs store.UserDNStore + userMerges store.UserMergeStore } // New constructs a Service backed by the supplied database handle. 
// Stores are instantiated internally using the default MySQL implementations. func New(database *sqlx.DB, eventBus *events.Bus) *Service { return &Service{ - db: database, - eventBus: eventBus, - orgs: store.NewOrganizationStore(database), - users: store.NewUserStore(database), - projs: store.NewProjectStore(database), - clusters: store.NewComputeClusterStore(database), - clusterUsers: store.NewComputeClusterUserStore(database), - allocs: store.NewComputeAllocationStore(database), - resources: store.NewComputeAllocationResourceStore(database), - resourceMappings: store.NewComputeAllocationResourceMappingStore(database), - resourceRates: store.NewComputeAllocationResourceRateStore(database), - allocDiffs: store.NewComputeAllocationDiffStore(database), - changeRequests: store.NewComputeAllocationChangeRequestStore(database), - changeEvents: store.NewComputeAllocationChangeRequestEventStore(database), - memberships: store.NewComputeAllocationMembershipStore(database), + db: database, + eventBus: eventBus, + orgs: store.NewOrganizationStore(database), + users: store.NewUserStore(database), + projs: store.NewProjectStore(database), + clusters: store.NewComputeClusterStore(database), + clusterUsers: store.NewComputeClusterUserStore(database), + allocs: store.NewComputeAllocationStore(database), + resources: store.NewComputeAllocationResourceStore(database), + resourceMappings: store.NewComputeAllocationResourceMappingStore(database), + resourceRates: store.NewComputeAllocationResourceRateStore(database), + allocDiffs: store.NewComputeAllocationDiffStore(database), + changeRequests: store.NewComputeAllocationChangeRequestStore(database), + changeEvents: store.NewComputeAllocationChangeRequestEventStore(database), + memberships: store.NewComputeAllocationMembershipStore(database), membershipOverrides: store.NewComputeAllocationMembershipResourceOverrideStore(database), - usages: store.NewComputeAllocationUsageStore(database), + usages: 
store.NewComputeAllocationUsageStore(database), + extIDs: store.NewExternalIdentityStore(database), + userDNs: store.NewUserDNStore(database), + userMerges: store.NewUserMergeStore(database), } } @@ -98,25 +104,31 @@ func NewWithStores( membershipOverrides store.ComputeAllocationMembershipResourceOverrideStore, memberships store.ComputeAllocationMembershipStore, usages store.ComputeAllocationUsageStore, + extIDs store.ExternalIdentityStore, + userDNs store.UserDNStore, + userMerges store.UserMergeStore, ) *Service { return &Service{ - db: database, - eventBus: eventBus, - orgs: orgs, - users: users, - projs: projs, - clusters: clusters, - clusterUsers: clusterUsers, - allocs: allocs, - resources: resources, - resourceMappings: resourceMappings, - resourceRates: resourceRates, - allocDiffs: allocDiffs, - changeRequests: changeRequests, - changeEvents: changeEvents, + db: database, + eventBus: eventBus, + orgs: orgs, + users: users, + projs: projs, + clusters: clusters, + clusterUsers: clusterUsers, + allocs: allocs, + resources: resources, + resourceMappings: resourceMappings, + resourceRates: resourceRates, + allocDiffs: allocDiffs, + changeRequests: changeRequests, + changeEvents: changeEvents, membershipOverrides: membershipOverrides, - memberships: memberships, - usages: usages, + memberships: memberships, + usages: usages, + extIDs: extIDs, + userDNs: userDNs, + userMerges: userMerges, } } diff --git a/pkg/service/user.go b/pkg/service/user.go index da809881c..a17bf7a27 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -22,6 +22,7 @@ import ( "database/sql" "fmt" + "github.com/apache/airavata-custos/pkg/events" "github.com/apache/airavata-custos/pkg/models" ) @@ -53,12 +54,17 @@ func (s *Service) CreateUser(ctx context.Context, user *models.User) (*models.Us if user.ID == "" { user.ID = newID() } + if user.Status == "" { + user.Status = models.UserActive + } if err := s.inTx(ctx, func(tx *sql.Tx) error { return s.users.Create(ctx, tx, user) }); err 
!= nil { return nil, fmt.Errorf("create user: %w", err) } + + s.eventBus.Publish(events.UserCreateEvent, user) return user, nil } @@ -74,6 +80,26 @@ func (s *Service) GetUser(ctx context.Context, id string) (*models.User, error) return u, nil } +// GetUserByExternalIdentity returns the user owning the external identity +// uniquely identified by (source, externalID). Returns ErrNotFound when no +// such binding exists. +func (s *Service) GetUserByExternalIdentity(ctx context.Context, source, externalID string) (*models.User, error) { + if source == "" { + return nil, fmt.Errorf("%w: source is required", ErrInvalidInput) + } + if externalID == "" { + return nil, fmt.Errorf("%w: external_id is required", ErrInvalidInput) + } + ext, err := s.extIDs.FindBySourceAndExternalID(ctx, source, externalID) + if err != nil { + return nil, fmt.Errorf("lookup external identity: %w", err) + } + if ext == nil { + return nil, ErrNotFound + } + return s.GetUser(ctx, ext.UserID) +} + // GetUserByEmail retrieves a user by email. func (s *Service) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { u, err := s.users.FindByEmail(ctx, email) @@ -95,28 +121,102 @@ func (s *Service) ListUsersByOrganization(ctx context.Context, organizationID st return users, nil } -// UpdateUser persists changes to an existing user. +// UpdateUser persists changes to an existing user. Fields left blank/zero on +// the supplied record fall back to the stored value. 
func (s *Service) UpdateUser(ctx context.Context, user *models.User) error { if user == nil || user.ID == "" { return fmt.Errorf("%w: user id is required", ErrInvalidInput) } + existing, err := s.users.FindByID(ctx, user.ID) + if err != nil { + return fmt.Errorf("lookup user: %w", err) + } + if existing == nil { + return ErrNotFound + } + + if user.Email != "" && user.Email != existing.Email { + if other, err := s.users.FindByEmail(ctx, user.Email); err != nil { + return fmt.Errorf("lookup user by email: %w", err) + } else if other != nil && other.ID != user.ID { + return fmt.Errorf("%w: user with email %q", ErrAlreadyExists, user.Email) + } + } + + if user.OrganizationID == "" { + user.OrganizationID = existing.OrganizationID + } + if user.FirstName == "" { + user.FirstName = existing.FirstName + } + if user.LastName == "" { + user.LastName = existing.LastName + } + if user.MiddleName == "" { + user.MiddleName = existing.MiddleName + } + if user.Email == "" { + user.Email = existing.Email + } + if user.Status == "" { + user.Status = existing.Status + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { return s.users.Update(ctx, tx, user) }); err != nil { return fmt.Errorf("update user: %w", err) } + + s.eventBus.Publish(events.UserUpdateEvent, user) return nil } +// UpdateUserStatus sets the lifecycle status of the user identified by id. +// Other fields are preserved. 
+func (s *Service) UpdateUserStatus(ctx context.Context, id string, status models.UserStatus) (*models.User, error) { + if id == "" { + return nil, fmt.Errorf("%w: user id is required", ErrInvalidInput) + } + if status == "" { + return nil, fmt.Errorf("%w: status is required", ErrInvalidInput) + } + existing, err := s.users.FindByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("lookup user: %w", err) + } + if existing == nil { + return nil, ErrNotFound + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.users.UpdateStatus(ctx, tx, id, status) + }); err != nil { + return nil, fmt.Errorf("update user status: %w", err) + } + existing.Status = status + + s.eventBus.Publish(events.UserUpdateEvent, existing) + return existing, nil +} + // DeleteUser removes a user by ID. func (s *Service) DeleteUser(ctx context.Context, id string) error { if id == "" { return fmt.Errorf("%w: user id is required", ErrInvalidInput) } + existing, err := s.users.FindByID(ctx, id) + if err != nil { + return fmt.Errorf("lookup user: %w", err) + } + if existing == nil { + return ErrNotFound + } if err := s.inTx(ctx, func(tx *sql.Tx) error { return s.users.Delete(ctx, tx, id) }); err != nil { return fmt.Errorf("delete user: %w", err) } + + s.eventBus.Publish(events.UserDeleteEvent, existing) return nil } diff --git a/pkg/service/user_dn.go b/pkg/service/user_dn.go new file mode 100644 index 000000000..611b93f9d --- /dev/null +++ b/pkg/service/user_dn.go @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package service + +import ( + "context" + "database/sql" + "fmt" + + "github.com/apache/airavata-custos/pkg/events" + "github.com/apache/airavata-custos/pkg/models" +) + +// AddUserDN binds a DN to a user. If d.ID is empty, a new UUID is generated. +// The referenced user must already exist; (user_id, dn) is unique. +func (s *Service) AddUserDN(ctx context.Context, d *models.UserDN) (*models.UserDN, error) { + if d == nil { + return nil, fmt.Errorf("%w: user dn is nil", ErrInvalidInput) + } + if d.UserID == "" { + return nil, fmt.Errorf("%w: user dn user_id is required", ErrInvalidInput) + } + if d.DN == "" { + return nil, fmt.Errorf("%w: user dn dn is required", ErrInvalidInput) + } + + if user, err := s.users.FindByID(ctx, d.UserID); err != nil { + return nil, fmt.Errorf("verify user: %w", err) + } else if user == nil { + return nil, fmt.Errorf("%w: user %q does not exist", ErrInvalidInput, d.UserID) + } + + if existing, err := s.userDNs.FindByDN(ctx, d.DN); err != nil { + return nil, fmt.Errorf("lookup user dn: %w", err) + } else if existing != nil { + return nil, fmt.Errorf("%w: dn %q", ErrAlreadyExists, d.DN) + } + + if d.ID == "" { + d.ID = newID() + } + + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.userDNs.Create(ctx, tx, d) + }); err != nil { + return nil, fmt.Errorf("add user dn: %w", err) + } + + s.eventBus.Publish(events.UserDNCreateEvent, d) + return d, nil +} + +// GetUserDN retrieves a DN binding by ID. Returns ErrNotFound when no row matches. 
+func (s *Service) GetUserDN(ctx context.Context, id string) (*models.UserDN, error) { + d, err := s.userDNs.FindByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get user dn: %w", err) + } + if d == nil { + return nil, ErrNotFound + } + return d, nil +} + +// GetUserDNByDN performs a reverse lookup from DN to binding. +func (s *Service) GetUserDNByDN(ctx context.Context, dn string) (*models.UserDN, error) { + if dn == "" { + return nil, fmt.Errorf("%w: dn is required", ErrInvalidInput) + } + d, err := s.userDNs.FindByDN(ctx, dn) + if err != nil { + return nil, fmt.Errorf("get user dn by dn: %w", err) + } + if d == nil { + return nil, ErrNotFound + } + return d, nil +} + +// ListUserDNs returns every DN bound to the given user. +func (s *Service) ListUserDNs(ctx context.Context, userID string) ([]models.UserDN, error) { + if userID == "" { + return nil, fmt.Errorf("%w: user_id is required", ErrInvalidInput) + } + out, err := s.userDNs.FindByUser(ctx, userID) + if err != nil { + return nil, fmt.Errorf("list user dns: %w", err) + } + return out, nil +} + +// RemoveUserDN removes a DN binding by ID. +func (s *Service) RemoveUserDN(ctx context.Context, id string) error { + if id == "" { + return fmt.Errorf("%w: user dn id is required", ErrInvalidInput) + } + d, err := s.userDNs.FindByID(ctx, id) + if err != nil { + return fmt.Errorf("lookup user dn: %w", err) + } + if d == nil { + return ErrNotFound + } + if err := s.inTx(ctx, func(tx *sql.Tx) error { + return s.userDNs.Delete(ctx, tx, id) + }); err != nil { + return fmt.Errorf("remove user dn: %w", err) + } + + s.eventBus.Publish(events.UserDNDeleteEvent, d) + return nil +} diff --git a/pkg/service/user_merge.go b/pkg/service/user_merge.go new file mode 100644 index 000000000..a2ffe2ad7 --- /dev/null +++ b/pkg/service/user_merge.go @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package service + +import ( + "context" + "database/sql" + "fmt" + + "github.com/apache/airavata-custos/pkg/events" + "github.com/apache/airavata-custos/pkg/models" +) + +// MergeUsers consolidates the retiring user into the surviving user. All +// identity-forward state moves to the survivor; historical truth stays in +// place. The retiring user is flipped to status=MERGED and a row is written +// to user_merges with the surviving user and the given reason. All work +// happens in a single transaction. 
+// +// Moved to survivor (duplicates on the retiring user are dropped first): +// - external_identities +// - user_dns +// - compute_cluster_users +// - projects.project_pi_id +// - compute_allocation_memberships +// +// Left in place (who actually did the thing): +// - compute_allocation_change_requests (requester / approver) +// - compute_allocation_usages +func (s *Service) MergeUsers(ctx context.Context, survivingID, retiringID, reason string) (*models.User, error) { + if survivingID == "" || retiringID == "" { + return nil, fmt.Errorf("%w: surviving and retiring user IDs are required", ErrInvalidInput) + } + if survivingID == retiringID { + return nil, fmt.Errorf("%w: cannot merge a user with itself", ErrInvalidInput) + } + + survivor, err := s.users.FindByID(ctx, survivingID) + if err != nil { + return nil, fmt.Errorf("lookup surviving user: %w", err) + } + if survivor == nil { + return nil, fmt.Errorf("%w: surviving user %q does not exist", ErrInvalidInput, survivingID) + } + if survivor.Status == models.UserMerged { + return nil, fmt.Errorf("%w: surviving user %q is itself merged", ErrInvalidInput, survivingID) + } + retiring, err := s.users.FindByID(ctx, retiringID) + if err != nil { + return nil, fmt.Errorf("lookup retiring user: %w", err) + } + if retiring == nil { + return nil, fmt.Errorf("%w: retiring user %q does not exist", ErrInvalidInput, retiringID) + } + + // Idempotency: re-running the same merge is a no-op; merging the same + // retiring user into a different survivor is rejected. 
+ if retiring.Status == models.UserMerged { + prior, err := s.userMerges.FindByRetiringUser(ctx, retiringID) + if err != nil { + return nil, fmt.Errorf("lookup prior merge: %w", err) + } + if prior != nil { + if prior.SurvivingUserID == survivingID { + return survivor, nil + } + return nil, fmt.Errorf("%w: user %q already merged into %q", + ErrAlreadyExists, retiringID, prior.SurvivingUserID) + } + } + + if err := s.inTx(ctx, func(tx *sql.Tx) error { + if err := s.extIDs.ReassignUser(ctx, tx, retiringID, survivingID); err != nil { + return fmt.Errorf("reassign external identities: %w", err) + } + if err := s.userDNs.ReassignUser(ctx, tx, retiringID, survivingID); err != nil { + return fmt.Errorf("reassign user dns: %w", err) + } + if err := s.clusterUsers.ReassignUser(ctx, tx, retiringID, survivingID); err != nil { + return fmt.Errorf("reassign compute cluster users: %w", err) + } + if err := s.projs.ReassignPI(ctx, tx, retiringID, survivingID); err != nil { + return fmt.Errorf("reassign project PI: %w", err) + } + if err := s.memberships.ReassignUser(ctx, tx, retiringID, survivingID); err != nil { + return fmt.Errorf("reassign memberships: %w", err) + } + if err := s.users.UpdateStatus(ctx, tx, retiringID, models.UserMerged); err != nil { + return fmt.Errorf("mark retiring user merged: %w", err) + } + return s.userMerges.Record(ctx, tx, retiringID, survivingID, reason) + }); err != nil { + return nil, fmt.Errorf("merge users: %w", err) + } + + retiring.Status = models.UserMerged + s.eventBus.Publish(events.UserUpdateEvent, retiring) + s.eventBus.Publish(events.UserUpdateEvent, survivor) + return survivor, nil +} + +// GetUserMergeByRetiringUser returns the merge record for a retiring user, or +// ErrNotFound if the user has not been merged. 
+func (s *Service) GetUserMergeByRetiringUser(ctx context.Context, retiringUserID string) (*models.UserMerge, error) { + if retiringUserID == "" { + return nil, fmt.Errorf("%w: retiring_user_id is required", ErrInvalidInput) + } + m, err := s.userMerges.FindByRetiringUser(ctx, retiringUserID) + if err != nil { + return nil, fmt.Errorf("get user merge: %w", err) + } + if m == nil { + return nil, ErrNotFound + } + return m, nil +} + +// ListUserMergesBySurvivingUser returns every merge record absorbed by the +// given surviving user, oldest first. +func (s *Service) ListUserMergesBySurvivingUser(ctx context.Context, survivingUserID string) ([]models.UserMerge, error) { + if survivingUserID == "" { + return nil, fmt.Errorf("%w: surviving_user_id is required", ErrInvalidInput) + } + out, err := s.userMerges.FindBySurvivingUser(ctx, survivingUserID) + if err != nil { + return nil, fmt.Errorf("list user merges by surviving user: %w", err) + } + return out, nil +}
