This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch access-integration in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit e5a58a57d0067ca1528a5c914851a2e23c7521e6 Author: lahiruj <[email protected]> AuthorDate: Tue May 19 22:07:39 2026 -0400 Complete AMIE inactivate, reactivate, and identity sync handlers --- .../AMIE-Processor/handler/data_account_create.go | 24 ++-- .../AMIE-Processor/handler/data_project_create.go | 24 ++-- .../ACCESS/AMIE-Processor/handler/handler.go | 61 ++++++++++ .../handler/request_account_create.go | 14 ++- .../handler/request_account_inactivate.go | 12 +- .../handler/request_account_reactivate.go | 12 +- .../handler/request_project_create.go | 49 ++++++-- .../handler/request_project_inactivate.go | 80 ++++++++++--- .../handler/request_project_reactivate.go | 83 ++++++++++--- .../AMIE-Processor/handler/request_user_modify.go | 129 ++++++++++++++++++--- 10 files changed, 409 insertions(+), 79 deletions(-) diff --git a/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go b/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go index 5d3626b8a..bfbe57df0 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go +++ b/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go @@ -54,17 +54,21 @@ func (h *DataAccountCreateHandler) Handle(ctx context.Context, tx *sql.Tx, packe return err } - dns := getDNList(body) - if len(dns) > 0 { - user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, personGlobalID) - if err != nil { - if errors.Is(err, service.ErrNotFound) { - slog.WarnContext(ctx, "data_account_create: user not found for AMIE PersonID; skipping DN persistence", - "personGlobalID", personGlobalID) - } else { - return fmt.Errorf("data_account_create: resolve user: %w", err) - } + user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, personGlobalID) + if err != nil { + if errors.Is(err, service.ErrNotFound) { + slog.WarnContext(ctx, "data_account_create: user not found for AMIE PersonID; skipping DN persistence and ExternalIdentity upsert", + "personGlobalID", personGlobalID) } else { + return fmt.Errorf("data_account_create: resolve user: %w", err) + } + } + if user != nil { + if err := ensureExternalIdentity(ctx, h.svc, user.ID, personGlobalID); err != nil { + return fmt.Errorf("data_account_create: ensure external identity: %w", err) + } + dns := getDNList(body) + if len(dns) > 0 { for _, dn := range dns { if _, err := h.svc.AddUserDN(ctx, &models.UserDN{UserID: user.ID, DN: dn}); err != nil { if errors.Is(err, service.ErrAlreadyExists) { diff --git a/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go b/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go index 476143494..e5dd3ce26 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go +++ b/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go @@ -54,17 +54,21 @@ func (h *DataProjectCreateHandler) Handle(ctx context.Context, tx *sql.Tx, packe return err } - dns := getDNList(body) - if len(dns) > 0 { - user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, piGlobalID) - if err != nil { - if errors.Is(err, service.ErrNotFound) { - slog.WarnContext(ctx, "data_project_create: PI user not found; skipping DN persistence", - "piGlobalID", piGlobalID) - } else { - return fmt.Errorf("data_project_create: resolve PI user: %w", err) - } + user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, piGlobalID) + if err != nil { + if errors.Is(err, service.ErrNotFound) { + slog.WarnContext(ctx, "data_project_create: PI user not found; skipping DN persistence and ExternalIdentity upsert", + "piGlobalID", piGlobalID) } else { + return fmt.Errorf("data_project_create: resolve PI user: %w", err) + } + } + if user != nil { + if err := ensureExternalIdentity(ctx, h.svc, user.ID, piGlobalID); err != nil { + return fmt.Errorf("data_project_create: ensure external identity: %w", err) + } + dns := getDNList(body) + if len(dns) > 0 { for _, dn := range dns { if _, err := h.svc.AddUserDN(ctx, &models.UserDN{UserID: user.ID, DN: dn}); err != nil { if errors.Is(err, service.ErrAlreadyExists) { diff --git a/connectors/ACCESS/AMIE-Processor/handler/handler.go b/connectors/ACCESS/AMIE-Processor/handler/handler.go index aca530b1c..3fd2664b7 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/handler.go +++ b/connectors/ACCESS/AMIE-Processor/handler/handler.go @@ -111,6 +111,67 @@ func getResourceList(body map[string]any) []string { return result } +// ensureExternalIdentity is the idempotent upsert used by data_project_create +// and data_account_create. It is a no-op when the user already has an AMIE +// ExternalIdentity row for this globalID; creates one otherwise. Pre-existing +// rows are NOT touched here, attribute updates (org / orgCode / nsfStatus) +// are owned by request_user_modify. +func ensureExternalIdentity(ctx context.Context, svc *service.Service, userID, globalID string) error { + if existing, err := svc.GetExternalIdentityBySourceAndExternalID(ctx, amieIdentitySource, globalID); err == nil { + if existing.UserID == userID { + return nil + } + return fmt.Errorf("external identity for %s=%s is bound to user %s (not %s)", amieIdentitySource, globalID, existing.UserID, userID) + } else if !errors.Is(err, service.ErrNotFound) { + return err + } + if _, err := svc.CreateExternalIdentity(ctx, &models.ExternalIdentity{ + UserID: userID, + Source: amieIdentitySource, + ExternalID: globalID, + }); err != nil { + return fmt.Errorf("create external identity: %w", err) + } + return nil +} + +// flipUserMemberships flips every ComputeAllocationMembership the user holds +// under any allocation belonging to the given project (Custos project.id) to +// the given status. Returns the rows that were updated. Silently returns an +// empty slice when the project or user is unknown, the reply still goes back +// to AMIE but no state changes. +func flipUserMemberships(ctx context.Context, svc *service.Service, projectID, userID string, status models.AllocationStatus) ([]models.ComputeAllocationMembership, error) { + project, err := svc.GetProject(ctx, projectID) + if err != nil { + if errors.Is(err, service.ErrNotFound) { + return nil, nil + } + return nil, fmt.Errorf("lookup project: %w", err) + } + allocations, err := svc.ListComputeAllocationsByProject(ctx, project.ID) + if err != nil { + return nil, fmt.Errorf("list allocations: %w", err) + } + var updated []models.ComputeAllocationMembership + for _, a := range allocations { + members, err := svc.ListMembersForAllocation(ctx, a.ID) + if err != nil { + return nil, fmt.Errorf("list memberships for allocation %s: %w", a.ID, err) + } + for _, m := range members { + if m.UserID != userID { + continue + } + flipped, err := svc.UpdateMembershipStatus(ctx, m.ID, status) + if err != nil { + return nil, fmt.Errorf("update membership %s: %w", m.ID, err) + } + updated = append(updated, *flipped) + } + } + return updated, nil +} + // ensureOrganization looks up an Organization by its originated_id (the // AMIE-side org code such as "TEST123"); creates one if missing, using the // human-readable organization name from the packet. diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go b/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go index 3d228c62e..3ef0a152e 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go @@ -51,8 +51,8 @@ func (h *RequestAccountCreateHandler) Handle(ctx context.Context, tx *sql.Tx, pa if err != nil { return err } - projectOriginatedID := getString(body, "ProjectID") - if err := requireText(projectOriginatedID, "ProjectID"); err != nil { + projectID := getString(body, "ProjectID") + if err := requireText(projectID, "ProjectID"); err != nil { return err } if err := requireText(getString(body, "GrantNumber"), "GrantNumber"); err != nil { @@ -74,9 +74,11 @@ func (h *RequestAccountCreateHandler) Handle(ctx context.Context, tx *sql.Tx, pa return fmt.Errorf("request_account_create: audit CREATE_PERSON: %w", err) } - project, err := h.svc.GetProjectByOriginatedID(ctx, projectOriginatedID) + // AMIE replies to notify_project_create with project.id (Custos UUID), so + // subsequent packets carry that id back to us as body.ProjectID. + project, err := h.svc.GetProject(ctx, projectID) if err != nil { - return fmt.Errorf("request_account_create: project %q not found (request_project_create must precede this packet): %w", projectOriginatedID, err) + return fmt.Errorf("request_account_create: project %q not found (request_project_create must precede this packet): %w", projectID, err) } allocations, err := h.svc.ListComputeAllocationsByProject(ctx, project.ID) @@ -84,7 +86,7 @@ func (h *RequestAccountCreateHandler) Handle(ctx context.Context, tx *sql.Tx, pa return fmt.Errorf("request_account_create: list allocations: %w", err) } if len(allocations) == 0 { - return fmt.Errorf("request_account_create: project %q has no ComputeAllocation; request_project_create did not provision one", projectOriginatedID) + return fmt.Errorf("request_account_create: project %q has no ComputeAllocation; request_project_create did not provision one", projectID) } allocation := allocations[0] @@ -107,7 +109,7 @@ func (h *RequestAccountCreateHandler) Handle(ctx context.Context, tx *sql.Tx, pa } replyBody := map[string]any{ - "ProjectID": projectOriginatedID, + "ProjectID": projectID, "GrantNumber": getString(body, "GrantNumber"), "UserPersonID": user.ID, "UserRemoteSiteLogin": account.Username, diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go b/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go index ef4d3d42c..ff5e188b6 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go @@ -23,6 +23,7 @@ import ( "fmt" "github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model" + "github.com/apache/airavata-custos/pkg/models" "github.com/apache/airavata-custos/pkg/service" ) @@ -52,8 +53,15 @@ func (h *RequestAccountInactivateHandler) Handle(ctx context.Context, tx *sql.Tx return err } - if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditInactivateMembership, "membership_request", personID, fmt.Sprintf("project=%s person=%s (membership persistence pending allocation mapping)", projectID, personID)); err != nil { - return fmt.Errorf("request_account_inactivate: audit INACTIVATE_MEMBERSHIP: %w", err) + flipped, err := flipUserMemberships(ctx, h.svc, projectID, personID, models.INACTIVE) + if err != nil { + return fmt.Errorf("request_account_inactivate: flip memberships: %w", err) + } + for _, m := range flipped { + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditInactivateMembership, "compute_allocation_membership", m.ID, + fmt.Sprintf("user=%s allocation=%s", m.UserID, m.ComputeAllocationID)); err != nil { + return fmt.Errorf("request_account_inactivate: audit INACTIVATE_MEMBERSHIP: %w", err) + } } reply := map[string]any{ diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go b/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go index 709d362eb..50f248493 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go @@ -23,6 +23,7 @@ import ( "fmt" "github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model" + "github.com/apache/airavata-custos/pkg/models" "github.com/apache/airavata-custos/pkg/service" ) @@ -52,8 +53,15 @@ func (h *RequestAccountReactivateHandler) Handle(ctx context.Context, tx *sql.Tx return err } - if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditReactivateMembership, "membership_request", personID, fmt.Sprintf("project=%s person=%s (membership persistence pending allocation mapping)", projectID, personID)); err != nil { - return fmt.Errorf("request_account_reactivate: audit REACTIVATE_MEMBERSHIP: %w", err) + flipped, err := flipUserMemberships(ctx, h.svc, projectID, personID, models.ACTIVE) + if err != nil { + return fmt.Errorf("request_account_reactivate: flip memberships: %w", err) + } + for _, m := range flipped { + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditReactivateMembership, "compute_allocation_membership", m.ID, + fmt.Sprintf("user=%s allocation=%s", m.UserID, m.ComputeAllocationID)); err != nil { + return fmt.Errorf("request_account_reactivate: audit REACTIVATE_MEMBERSHIP: %w", err) + } } reply := map[string]any{ diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go b/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go index 33a06dff5..db8cfce76 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go @@ -22,6 +22,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model" "github.com/apache/airavata-custos/pkg/models" @@ -63,7 +64,12 @@ func (h *RequestProjectCreateHandler) Handle(ctx context.Context, tx *sql.Tx, pa } // AMIE protocol: request_project_create does not carry a ProjectID. The // receiving site assigns one. We use the GrantNumber as the originated_id - // since it is the stable cross-site identifier on the AMIE side. + // (the stable cross-site identifier on the AMIE side) so first delivery + // and supplement/renewal re-deliveries map to the same Project row. + // + // TODO(amie-integration, grant-number-modeling): verify whether + // GrantNumber should be promoted to a first-class column on `projects` + // (with its own UNIQUE constraint) instead of being injected into originated_id. projectOriginatedID := grantNumber pi, err := h.ensurePIUser(ctx, body, piGlobalID) @@ -150,20 +156,18 @@ func (h *RequestProjectCreateHandler) ensureProject(ctx context.Context, origina }) } -// ensureAllocation creates a ComputeAllocation for the project if none exists -// yet. If one already exists (e.g. a repeat request_project_create signaling a -// supplement/renewal), the existing row is returned unchanged. -// -// TODO(amie-integration, allocation-type): branch on body["AllocationType"] -// (new / renewal / supplement / extension) and adjust the allocation -// accordingly. +// ensureAllocation creates a ComputeAllocation for the project on first +// delivery. On repeat delivery (supplement / renewal / extension / adjustment), +// the existing row is preserved and a ComputeAllocationDiff is recorded +// capturing the grant event. InitialSUAmount on the parent row stays as the +// original grant; effective SUs = InitialSUAmount + sum(grant diffs). func (h *RequestProjectCreateHandler) ensureAllocation(ctx context.Context, body map[string]any, projectID, grantNumber string) (*models.ComputeAllocation, error) { existing, err := h.svc.ListComputeAllocationsByProject(ctx, projectID) if err != nil { return nil, fmt.Errorf("list allocations: %w", err) } if len(existing) > 0 { - return &existing[0], nil + return h.recordAllocationDiff(ctx, body, &existing[0]) } su, err := getInt64(body, "ServiceUnitsAllocated") @@ -188,3 +192,30 @@ func (h *RequestProjectCreateHandler) ensureAllocation(ctx context.Context, body EndTime: end, }) } + +// recordAllocationDiff writes a ComputeAllocationDiff for a re-delivered +// request_project_create against an existing allocation (the AMIE pattern for +// supplements / renewals / extensions / adjustments). DiffType is the upper-cased +// AllocationType from the packet body; falls back to "GRANT" when AMIE did not supply one. +// NewSUAmount carries this packet's ServiceUnitsAllocated, the delta granted +// by this event, not the cumulative total. +func (h *RequestProjectCreateHandler) recordAllocationDiff(ctx context.Context, body map[string]any, existing *models.ComputeAllocation) (*models.ComputeAllocation, error) { + su, err := getInt64(body, "ServiceUnitsAllocated") + if err != nil { + return nil, err + } + diffType := strings.ToUpper(strings.TrimSpace(getString(body, "AllocationType"))) + if diffType == "" { + diffType = "GRANT" + } + if _, err := h.svc.CreateComputeAllocationDiff(ctx, &models.ComputeAllocationDiff{ + ComputeAllocationID: existing.ID, + DiffType: diffType, + NewSUAmount: su, + Status: models.ACTIVE, + Description: fmt.Sprintf("AMIE %s of %d SUs", diffType, su), + }); err != nil { + return nil, fmt.Errorf("record allocation diff: %w", err) + } + return existing, nil +} diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go b/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go index 286953255..64b76036b 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go @@ -46,31 +46,78 @@ func (h *RequestProjectInactivateHandler) Handle(ctx context.Context, tx *sql.Tx if err != nil { return err } - originatedID := getString(body, "ProjectID") - if err := requireText(originatedID, "ProjectID"); err != nil { + projectID := getString(body, "ProjectID") + if err := requireText(projectID, "ProjectID"); err != nil { return err } - project, err := h.svc.GetProjectByOriginatedID(ctx, originatedID) + // AMIE carries the Custos project.id we returned in notify_project_create. + project, err := h.svc.GetProject(ctx, projectID) if err != nil { if errors.Is(err, service.ErrNotFound) { - slog.WarnContext(ctx, "request_project_inactivate: project not found in core; skipping status flip", - "originatedID", originatedID) - } else { - return fmt.Errorf("request_project_inactivate: lookup project: %w", err) + slog.WarnContext(ctx, "request_project_inactivate: project not found in core; skipping", + "projectID", projectID) + return h.reply(ctx, tx, packet, eventID, projectID) } - } else { - if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, models.INACTIVE); err != nil { - return fmt.Errorf("request_project_inactivate: update status: %w", err) + return fmt.Errorf("request_project_inactivate: lookup project: %w", err) + } + + if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, models.INACTIVE); err != nil { + return fmt.Errorf("request_project_inactivate: update project status: %w", err) + } + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditInactivateProject, "project", project.ID, ""); err != nil { + return fmt.Errorf("request_project_inactivate: audit INACTIVATE_PROJECT: %w", err) + } + + if err := h.deactivateAllocations(ctx, tx, packet, eventID, project.ID, getString(body, "Comment")); err != nil { + return err + } + + return h.reply(ctx, tx, packet, eventID, projectID) +} + +// deactivateAllocations flips every ComputeAllocation under the project to +// INACTIVE, writes a status-change Diff per allocation, and flips every +// ComputeAllocationMembership under those allocations to INACTIVE. +func (h *RequestProjectInactivateHandler) deactivateAllocations(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID, projectID, comment string) error { + allocations, err := h.svc.ListComputeAllocationsByProject(ctx, projectID) + if err != nil { + return fmt.Errorf("list allocations: %w", err) + } + for _, a := range allocations { + a.Status = models.INACTIVE + if err := h.svc.UpdateComputeAllocation(ctx, &a); err != nil { + return fmt.Errorf("inactivate allocation %s: %w", a.ID, err) + } + if _, err := h.svc.CreateComputeAllocationDiff(ctx, &models.ComputeAllocationDiff{ + ComputeAllocationID: a.ID, + DiffType: "ALLOCATION_STATUS_CHANGE", + Status: models.INACTIVE, + Description: describeStatusChange("Inactivated by AMIE request_project_inactivate", comment), + }); err != nil { + return fmt.Errorf("record inactivate diff for %s: %w", a.ID, err) } - if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditInactivateProject, "project", project.ID, ""); err != nil { - return fmt.Errorf("request_project_inactivate: audit INACTIVATE_PROJECT: %w", err) + members, err := h.svc.ListMembersForAllocation(ctx, a.ID) + if err != nil { + return fmt.Errorf("list memberships for allocation %s: %w", a.ID, err) + } + for _, m := range members { + if _, err := h.svc.UpdateMembershipStatus(ctx, m.ID, models.INACTIVE); err != nil { + return fmt.Errorf("inactivate membership %s: %w", m.ID, err) + } + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditInactivateMembership, "compute_allocation_membership", m.ID, + fmt.Sprintf("user=%s allocation=%s", m.UserID, a.ID)); err != nil { + return fmt.Errorf("audit INACTIVATE_MEMBERSHIP for %s: %w", m.ID, err) + } } } + return nil +} +func (h *RequestProjectInactivateHandler) reply(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID, projectID string) error { reply := map[string]any{ "type": "notify_project_inactivate", - "body": map[string]any{"ProjectID": originatedID}, + "body": map[string]any{"ProjectID": projectID}, } if err := h.amieClient.ReplyToPacket(ctx, packet.AmieID, reply); err != nil { return fmt.Errorf("request_project_inactivate: sending reply: %w", err) @@ -80,3 +127,10 @@ func (h *RequestProjectInactivateHandler) Handle(ctx context.Context, tx *sql.Tx } return nil } + +func describeStatusChange(primary, comment string) string { + if comment == "" { + return primary + } + return primary + " (" + comment + ")" +} diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go b/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go index 4a6820bd1..409ad6976 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go @@ -41,36 +41,93 @@ func NewRequestProjectReactivateHandler(svc *service.Service, amieClient AmieCli func (h *RequestProjectReactivateHandler) SupportsType() string { return "request_project_reactivate" } +// Handle flips the Project and all of its ComputeAllocations back to ACTIVE. +// Per AMIE protocol, only the PI's membership is reactivated automatically; +// other members must request reactivation via request_account_reactivate. func (h *RequestProjectReactivateHandler) Handle(ctx context.Context, tx *sql.Tx, packetJSON map[string]any, packet *model.Packet, eventID string) error { body, err := getBody(packetJSON) if err != nil { return err } - originatedID := getString(body, "ProjectID") - if err := requireText(originatedID, "ProjectID"); err != nil { + projectID := getString(body, "ProjectID") + if err := requireText(projectID, "ProjectID"); err != nil { return err } - project, err := h.svc.GetProjectByOriginatedID(ctx, originatedID) + // AMIE carries the Custos project.id we returned in notify_project_create. + project, err := h.svc.GetProject(ctx, projectID) if err != nil { if errors.Is(err, service.ErrNotFound) { - slog.WarnContext(ctx, "request_project_reactivate: project not found in core; skipping status flip", - "originatedID", originatedID) - } else { - return fmt.Errorf("request_project_reactivate: lookup project: %w", err) + slog.WarnContext(ctx, "request_project_reactivate: project not found in core; skipping", + "projectID", projectID) + return h.reply(ctx, tx, packet, eventID, projectID) } - } else { - if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, models.ACTIVE); err != nil { - return fmt.Errorf("request_project_reactivate: update status: %w", err) + return fmt.Errorf("request_project_reactivate: lookup project: %w", err) + } + + if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, models.ACTIVE); err != nil { + return fmt.Errorf("request_project_reactivate: update project status: %w", err) + } + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditReactivateProject, "project", project.ID, ""); err != nil { + return fmt.Errorf("request_project_reactivate: audit REACTIVATE_PROJECT: %w", err) + } + + if err := h.reactivateAllocationsAndPI(ctx, tx, packet, eventID, project); err != nil { + return err + } + + return h.reply(ctx, tx, packet, eventID, projectID) +} + +// reactivateAllocationsAndPI flips every ComputeAllocation under the project +// back to ACTIVE, writes a status-change Diff per allocation, and reactivates +// only the PI's membership on each allocation. Other members stay INACTIVE +// until their own request_account_reactivate arrives. +func (h *RequestProjectReactivateHandler) reactivateAllocationsAndPI(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID string, project *models.Project) error { + allocations, err := h.svc.ListComputeAllocationsByProject(ctx, project.ID) + if err != nil { + return fmt.Errorf("list allocations: %w", err) + } + for _, a := range allocations { + a.Status = models.ACTIVE + if err := h.svc.UpdateComputeAllocation(ctx, &a); err != nil { + return fmt.Errorf("reactivate allocation %s: %w", a.ID, err) + } + if _, err := h.svc.CreateComputeAllocationDiff(ctx, &models.ComputeAllocationDiff{ + ComputeAllocationID: a.ID, + DiffType: "ALLOCATION_STATUS_CHANGE", + Status: models.ACTIVE, + Description: "Reactivated by AMIE request_project_reactivate", + }); err != nil { + return fmt.Errorf("record reactivate diff for %s: %w", a.ID, err) } - if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditReactivateProject, "project", project.ID, ""); err != nil { - return fmt.Errorf("request_project_reactivate: audit REACTIVATE_PROJECT: %w", err) + if project.ProjectPIID == "" { + continue + } + members, err := h.svc.ListMembersForAllocation(ctx, a.ID) + if err != nil { + return fmt.Errorf("list memberships for allocation %s: %w", a.ID, err) + } + for _, m := range members { + if m.UserID != project.ProjectPIID { + continue + } + if _, err := h.svc.UpdateMembershipStatus(ctx, m.ID, models.ACTIVE); err != nil { + return fmt.Errorf("reactivate PI membership %s: %w", m.ID, err) + } + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditReactivateMembership, "compute_allocation_membership", m.ID, + fmt.Sprintf("PI user=%s allocation=%s", m.UserID, a.ID)); err != nil { + return fmt.Errorf("audit REACTIVATE_MEMBERSHIP for %s: %w", m.ID, err) + } } } + return nil +} +func (h *RequestProjectReactivateHandler) reply(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID, projectID string) error { reply := map[string]any{ "type": "notify_project_reactivate", - "body": map[string]any{"ProjectID": originatedID}, + "body": map[string]any{"ProjectID": projectID}, } if err := h.amieClient.ReplyToPacket(ctx, packet.AmieID, reply); err != nil { return fmt.Errorf("request_project_reactivate: sending reply: %w", err) diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go b/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go index 511f2ca28..b72031b22 100644 --- a/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go +++ b/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go @@ -20,11 +20,13 @@ package handler import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "strings" "github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model" + "github.com/apache/airavata-custos/pkg/models" "github.com/apache/airavata-custos/pkg/service" ) @@ -65,20 +67,8 @@ func (h *RequestUserModifyHandler) Handle(ctx context.Context, tx *sql.Tx, packe switch { case strings.EqualFold(actionType, "replace"): if user != nil { - if v := getString(body, "UserFirstName"); v != "" { - user.FirstName = v - } - if v := getString(body, "UserLastName"); v != "" { - user.LastName = v - } - if v := getString(body, "UserEmail"); v != "" { - user.Email = v - } - if err := h.svc.UpdateUser(ctx, user); err != nil { - return fmt.Errorf("request_user_modify: update user: %w", err) - } - if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditUpdatePerson, "user", user.ID, ""); err != nil { - return fmt.Errorf("request_user_modify: audit UPDATE_PERSON: %w", err) + if err := h.applyReplace(ctx, tx, packet, eventID, body, user, userGlobalID); err != nil { + return err } } case strings.EqualFold(actionType, "delete"): @@ -110,3 +100,114 @@ func (h *RequestUserModifyHandler) Handle(ctx context.Context, tx *sql.Tx, packe } return nil } + +// applyReplace updates the User row (basic profile), the ExternalIdentity row +// (org / orgCode / NSF status carried as metadata), and reconciles the DN list. +func (h *RequestUserModifyHandler) applyReplace(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID string, body map[string]any, user *models.User, userGlobalID string) error { + if v := getString(body, "UserFirstName"); v != "" { + user.FirstName = v + } + if v := getString(body, "UserLastName"); v != "" { + user.LastName = v + } + if v := getString(body, "UserEmail"); v != "" { + user.Email = v + } + if err := h.svc.UpdateUser(ctx, user); err != nil { + return fmt.Errorf("update user: %w", err) + } + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditUpdatePerson, "user", user.ID, ""); err != nil { + return fmt.Errorf("audit UPDATE_PERSON: %w", err) + } + + if err := h.updateExternalIdentity(ctx, body, user.ID, userGlobalID); err != nil { + return fmt.Errorf("update external identity: %w", err) + } + if err := h.syncDNs(ctx, tx, packet, eventID, body, user.ID); err != nil { + return fmt.Errorf("sync user DNs: %w", err) + } + return nil +} + +// updateExternalIdentity refreshes the AMIE ExternalIdentity row's metadata +// (organization, org_code, NSF status) — these are AMIE-side attributes that +// may shift over a user's lifetime. The row's source / external_id are +// immutable identifiers. +func (h *RequestUserModifyHandler) updateExternalIdentity(ctx context.Context, body map[string]any, userID, userGlobalID string) error { + ext, err := h.svc.GetExternalIdentityBySourceAndExternalID(ctx, amieIdentitySource, userGlobalID) + if err != nil { + if errors.Is(err, service.ErrNotFound) { + return nil + } + return err + } + metadata := map[string]any{} + if v := getString(body, "UserOrganization"); v != "" { + metadata["organization"] = v + } + if v := getString(body, "UserOrgCode"); v != "" { + metadata["org_code"] = v + } + if v := getString(body, "NsfStatusCode"); v != "" { + metadata["nsf_status_code"] = v + } + if len(metadata) == 0 { + return nil + } + encoded, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("encode metadata: %w", err) + } + ext.UserID = userID + ext.Metadata = string(encoded) + return h.svc.UpdateExternalIdentity(ctx, ext) +} + +// syncDNs reconciles the user's DN list with the packet body's UserDnList: +// new DNs are added, DNs missing from the packet are removed. AMIE's +// request_user_modify with ActionType=replace is the authoritative source. +func (h *RequestUserModifyHandler) syncDNs(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID string, body map[string]any, userID string) error { + incoming := getDNList(body) + if incoming == nil { + return nil + } + desired := make(map[string]struct{}, len(incoming)) + for _, dn := range incoming { + desired[dn] = struct{}{} + } + existing, err := h.svc.ListUserDNs(ctx, userID) + if err != nil { + return fmt.Errorf("list user DNs: %w", err) + } + have := make(map[string]struct{}, len(existing)) + added := 0 + removed := 0 + for _, e := range existing { + have[e.DN] = struct{}{} + if _, keep := desired[e.DN]; !keep { + if err := h.svc.RemoveUserDN(ctx, e.ID); err != nil { + return fmt.Errorf("remove DN %s: %w", e.DN, err) + } + removed++ + } + } + for dn := range desired { + if _, exists := have[dn]; exists { + continue + } + if _, err := h.svc.AddUserDN(ctx, &models.UserDN{UserID: userID, DN: dn}); err != nil { + if errors.Is(err, service.ErrAlreadyExists) { + continue + } + return fmt.Errorf("add DN %s: %w", dn, err) + } + added++ + } + if added > 0 || removed > 0 { + if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditPersistDNs, "user", userID, + fmt.Sprintf("DN sync: +%d -%d", added, removed)); err != nil { + return fmt.Errorf("audit PERSIST_DNS: %w", err) + } + } + return nil +}
