This is an automated email from the ASF dual-hosted git repository.

lahirujayathilake pushed a commit to branch access-integration
in repository https://gitbox.apache.org/repos/asf/airavata-custos.git

commit e5a58a57d0067ca1528a5c914851a2e23c7521e6
Author: lahiruj <[email protected]>
AuthorDate: Tue May 19 22:07:39 2026 -0400

    Complete AMIE inactivate, reactivate, and identity sync handlers
---
 .../AMIE-Processor/handler/data_account_create.go  |  24 ++--
 .../AMIE-Processor/handler/data_project_create.go  |  24 ++--
 .../ACCESS/AMIE-Processor/handler/handler.go       |  61 ++++++++++
 .../handler/request_account_create.go              |  14 ++-
 .../handler/request_account_inactivate.go          |  12 +-
 .../handler/request_account_reactivate.go          |  12 +-
 .../handler/request_project_create.go              |  49 ++++++--
 .../handler/request_project_inactivate.go          |  80 ++++++++++---
 .../handler/request_project_reactivate.go          |  83 ++++++++++---
 .../AMIE-Processor/handler/request_user_modify.go  | 129 ++++++++++++++++++---
 10 files changed, 409 insertions(+), 79 deletions(-)

diff --git a/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go 
b/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go
index 5d3626b8a..bfbe57df0 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/data_account_create.go
@@ -54,17 +54,21 @@ func (h *DataAccountCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, packe
                return err
        }
 
-       dns := getDNList(body)
-       if len(dns) > 0 {
-               user, err := h.svc.GetUserByExternalIdentity(ctx, 
amieIdentitySource, personGlobalID)
-               if err != nil {
-                       if errors.Is(err, service.ErrNotFound) {
-                               slog.WarnContext(ctx, "data_account_create: 
user not found for AMIE PersonID; skipping DN persistence",
-                                       "personGlobalID", personGlobalID)
-                       } else {
-                               return fmt.Errorf("data_account_create: resolve 
user: %w", err)
-                       }
+       user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, 
personGlobalID)
+       if err != nil {
+               if errors.Is(err, service.ErrNotFound) {
+                       slog.WarnContext(ctx, "data_account_create: user not 
found for AMIE PersonID; skipping DN persistence and ExternalIdentity upsert",
+                               "personGlobalID", personGlobalID)
                } else {
+                       return fmt.Errorf("data_account_create: resolve user: 
%w", err)
+               }
+       }
+       if user != nil {
+               if err := ensureExternalIdentity(ctx, h.svc, user.ID, 
personGlobalID); err != nil {
+                       return fmt.Errorf("data_account_create: ensure external 
identity: %w", err)
+               }
+               dns := getDNList(body)
+               if len(dns) > 0 {
                        for _, dn := range dns {
                                if _, err := h.svc.AddUserDN(ctx, 
&models.UserDN{UserID: user.ID, DN: dn}); err != nil {
                                        if errors.Is(err, 
service.ErrAlreadyExists) {
diff --git a/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go 
b/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go
index 476143494..e5dd3ce26 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/data_project_create.go
@@ -54,17 +54,21 @@ func (h *DataProjectCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, packe
                return err
        }
 
-       dns := getDNList(body)
-       if len(dns) > 0 {
-               user, err := h.svc.GetUserByExternalIdentity(ctx, 
amieIdentitySource, piGlobalID)
-               if err != nil {
-                       if errors.Is(err, service.ErrNotFound) {
-                               slog.WarnContext(ctx, "data_project_create: PI 
user not found; skipping DN persistence",
-                                       "piGlobalID", piGlobalID)
-                       } else {
-                               return fmt.Errorf("data_project_create: resolve 
PI user: %w", err)
-                       }
+       user, err := h.svc.GetUserByExternalIdentity(ctx, amieIdentitySource, 
piGlobalID)
+       if err != nil {
+               if errors.Is(err, service.ErrNotFound) {
+                       slog.WarnContext(ctx, "data_project_create: PI user not 
found; skipping DN persistence and ExternalIdentity upsert",
+                               "piGlobalID", piGlobalID)
                } else {
+                       return fmt.Errorf("data_project_create: resolve PI 
user: %w", err)
+               }
+       }
+       if user != nil {
+               if err := ensureExternalIdentity(ctx, h.svc, user.ID, 
piGlobalID); err != nil {
+                       return fmt.Errorf("data_project_create: ensure external 
identity: %w", err)
+               }
+               dns := getDNList(body)
+               if len(dns) > 0 {
                        for _, dn := range dns {
                                if _, err := h.svc.AddUserDN(ctx, 
&models.UserDN{UserID: user.ID, DN: dn}); err != nil {
                                        if errors.Is(err, 
service.ErrAlreadyExists) {
diff --git a/connectors/ACCESS/AMIE-Processor/handler/handler.go 
b/connectors/ACCESS/AMIE-Processor/handler/handler.go
index aca530b1c..3fd2664b7 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/handler.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/handler.go
@@ -111,6 +111,67 @@ func getResourceList(body map[string]any) []string {
        return result
 }
 
+// ensureExternalIdentity is the idempotent upsert used by data_project_create
+// and data_account_create. It is a no-op when the user already has an AMIE
+// ExternalIdentity row for this globalID; creates one otherwise. Pre-existing
+// rows are NOT touched here; attribute updates (org / orgCode / nsfStatus)
+// are owned by request_user_modify.
+func ensureExternalIdentity(ctx context.Context, svc *service.Service, userID, 
globalID string) error {
+       if existing, err := svc.GetExternalIdentityBySourceAndExternalID(ctx, 
amieIdentitySource, globalID); err == nil {
+               if existing.UserID == userID {
+                       return nil
+               }
+               return fmt.Errorf("external identity for %s=%s is bound to user 
%s (not %s)", amieIdentitySource, globalID, existing.UserID, userID)
+       } else if !errors.Is(err, service.ErrNotFound) {
+               return err
+       }
+       if _, err := svc.CreateExternalIdentity(ctx, &models.ExternalIdentity{
+               UserID:     userID,
+               Source:     amieIdentitySource,
+               ExternalID: globalID,
+       }); err != nil {
+               return fmt.Errorf("create external identity: %w", err)
+       }
+       return nil
+}
+
+// flipUserMemberships flips every ComputeAllocationMembership the user holds
+// under any allocation belonging to the given project (Custos project.id) to
+// the given status. Returns the rows that were updated. Silently returns an
+// empty slice when the project or user is unknown; the reply still goes back
+// to AMIE but no state changes.
+func flipUserMemberships(ctx context.Context, svc *service.Service, projectID, 
userID string, status models.AllocationStatus) 
([]models.ComputeAllocationMembership, error) {
+       project, err := svc.GetProject(ctx, projectID)
+       if err != nil {
+               if errors.Is(err, service.ErrNotFound) {
+                       return nil, nil
+               }
+               return nil, fmt.Errorf("lookup project: %w", err)
+       }
+       allocations, err := svc.ListComputeAllocationsByProject(ctx, project.ID)
+       if err != nil {
+               return nil, fmt.Errorf("list allocations: %w", err)
+       }
+       var updated []models.ComputeAllocationMembership
+       for _, a := range allocations {
+               members, err := svc.ListMembersForAllocation(ctx, a.ID)
+               if err != nil {
+                       return nil, fmt.Errorf("list memberships for allocation 
%s: %w", a.ID, err)
+               }
+               for _, m := range members {
+                       if m.UserID != userID {
+                               continue
+                       }
+                       flipped, err := svc.UpdateMembershipStatus(ctx, m.ID, 
status)
+                       if err != nil {
+                               return nil, fmt.Errorf("update membership %s: 
%w", m.ID, err)
+                       }
+                       updated = append(updated, *flipped)
+               }
+       }
+       return updated, nil
+}
+
 // ensureOrganization looks up an Organization by its originated_id (the
 // AMIE-side org code such as "TEST123"); creates one if missing, using the
 // human-readable organization name from the packet.
diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go
index 3d228c62e..3ef0a152e 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_create.go
@@ -51,8 +51,8 @@ func (h *RequestAccountCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, pa
        if err != nil {
                return err
        }
-       projectOriginatedID := getString(body, "ProjectID")
-       if err := requireText(projectOriginatedID, "ProjectID"); err != nil {
+       projectID := getString(body, "ProjectID")
+       if err := requireText(projectID, "ProjectID"); err != nil {
                return err
        }
        if err := requireText(getString(body, "GrantNumber"), "GrantNumber"); 
err != nil {
@@ -74,9 +74,11 @@ func (h *RequestAccountCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, pa
                return fmt.Errorf("request_account_create: audit CREATE_PERSON: 
%w", err)
        }
 
-       project, err := h.svc.GetProjectByOriginatedID(ctx, projectOriginatedID)
+       // Our notify_project_create reply carries project.id (Custos UUID), so
+       // subsequent packets carry that id back to us as body.ProjectID.
+       project, err := h.svc.GetProject(ctx, projectID)
        if err != nil {
-               return fmt.Errorf("request_account_create: project %q not found 
(request_project_create must precede this packet): %w", projectOriginatedID, 
err)
+               return fmt.Errorf("request_account_create: project %q not found 
(request_project_create must precede this packet): %w", projectID, err)
        }
 
        allocations, err := h.svc.ListComputeAllocationsByProject(ctx, 
project.ID)
@@ -84,7 +86,7 @@ func (h *RequestAccountCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, pa
                return fmt.Errorf("request_account_create: list allocations: 
%w", err)
        }
        if len(allocations) == 0 {
-               return fmt.Errorf("request_account_create: project %q has no 
ComputeAllocation; request_project_create did not provision one", 
projectOriginatedID)
+               return fmt.Errorf("request_account_create: project %q has no 
ComputeAllocation; request_project_create did not provision one", projectID)
        }
        allocation := allocations[0]
 
@@ -107,7 +109,7 @@ func (h *RequestAccountCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, pa
        }
 
        replyBody := map[string]any{
-               "ProjectID":           projectOriginatedID,
+               "ProjectID":           projectID,
                "GrantNumber":         getString(body, "GrantNumber"),
                "UserPersonID":        user.ID,
                "UserRemoteSiteLogin": account.Username,
diff --git 
a/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go
index ef4d3d42c..ff5e188b6 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_inactivate.go
@@ -23,6 +23,7 @@ import (
        "fmt"
 
        
"github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model"
+       "github.com/apache/airavata-custos/pkg/models"
        "github.com/apache/airavata-custos/pkg/service"
 )
 
@@ -52,8 +53,15 @@ func (h *RequestAccountInactivateHandler) Handle(ctx 
context.Context, tx *sql.Tx
                return err
        }
 
-       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditInactivateMembership, "membership_request", personID, 
fmt.Sprintf("project=%s person=%s (membership persistence pending allocation 
mapping)", projectID, personID)); err != nil {
-               return fmt.Errorf("request_account_inactivate: audit 
INACTIVATE_MEMBERSHIP: %w", err)
+       flipped, err := flipUserMemberships(ctx, h.svc, projectID, personID, 
models.INACTIVE)
+       if err != nil {
+               return fmt.Errorf("request_account_inactivate: flip 
memberships: %w", err)
+       }
+       for _, m := range flipped {
+               if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditInactivateMembership, "compute_allocation_membership", m.ID,
+                       fmt.Sprintf("user=%s allocation=%s", m.UserID, 
m.ComputeAllocationID)); err != nil {
+                       return fmt.Errorf("request_account_inactivate: audit 
INACTIVATE_MEMBERSHIP: %w", err)
+               }
        }
 
        reply := map[string]any{
diff --git 
a/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go
index 709d362eb..50f248493 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_account_reactivate.go
@@ -23,6 +23,7 @@ import (
        "fmt"
 
        
"github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model"
+       "github.com/apache/airavata-custos/pkg/models"
        "github.com/apache/airavata-custos/pkg/service"
 )
 
@@ -52,8 +53,15 @@ func (h *RequestAccountReactivateHandler) Handle(ctx 
context.Context, tx *sql.Tx
                return err
        }
 
-       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditReactivateMembership, "membership_request", personID, 
fmt.Sprintf("project=%s person=%s (membership persistence pending allocation 
mapping)", projectID, personID)); err != nil {
-               return fmt.Errorf("request_account_reactivate: audit 
REACTIVATE_MEMBERSHIP: %w", err)
+       flipped, err := flipUserMemberships(ctx, h.svc, projectID, personID, 
models.ACTIVE)
+       if err != nil {
+               return fmt.Errorf("request_account_reactivate: flip 
memberships: %w", err)
+       }
+       for _, m := range flipped {
+               if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditReactivateMembership, "compute_allocation_membership", m.ID,
+                       fmt.Sprintf("user=%s allocation=%s", m.UserID, 
m.ComputeAllocationID)); err != nil {
+                       return fmt.Errorf("request_account_reactivate: audit 
REACTIVATE_MEMBERSHIP: %w", err)
+               }
        }
 
        reply := map[string]any{
diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go
index 33a06dff5..db8cfce76 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_create.go
@@ -22,6 +22,7 @@ import (
        "database/sql"
        "errors"
        "fmt"
+       "strings"
 
        
"github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model"
        "github.com/apache/airavata-custos/pkg/models"
@@ -63,7 +64,12 @@ func (h *RequestProjectCreateHandler) Handle(ctx 
context.Context, tx *sql.Tx, pa
        }
        // AMIE protocol: request_project_create does not carry a ProjectID. The
        // receiving site assigns one. We use the GrantNumber as the 
originated_id
-       // since it is the stable cross-site identifier on the AMIE side.
+       // (the stable cross-site identifier on the AMIE side) so first delivery
+       // and supplement/renewal re-deliveries map to the same Project row.
+       //
+       // TODO(amie-integration, grant-number-modeling): verify whether
+       // GrantNumber should be promoted to a first-class column on `projects`
+       // (with its own UNIQUE constraint) instead of being injected into 
originated_id.
        projectOriginatedID := grantNumber
 
        pi, err := h.ensurePIUser(ctx, body, piGlobalID)
@@ -150,20 +156,18 @@ func (h *RequestProjectCreateHandler) ensureProject(ctx 
context.Context, origina
        })
 }
 
-// ensureAllocation creates a ComputeAllocation for the project if none exists
-// yet. If one already exists (e.g. a repeat request_project_create signaling a
-// supplement/renewal), the existing row is returned unchanged.
-//
-// TODO(amie-integration, allocation-type): branch on body["AllocationType"]
-// (new / renewal / supplement / extension) and adjust the allocation
-// accordingly.
+// ensureAllocation creates a ComputeAllocation for the project on first
+// delivery. On repeat delivery (supplement / renewal / extension / 
adjustment),
+// the existing row is preserved and a ComputeAllocationDiff is recorded
+// capturing the grant event. InitialSUAmount on the parent row stays as the
+// original grant; effective SUs = InitialSUAmount + sum(grant diffs).
 func (h *RequestProjectCreateHandler) ensureAllocation(ctx context.Context, 
body map[string]any, projectID, grantNumber string) (*models.ComputeAllocation, 
error) {
        existing, err := h.svc.ListComputeAllocationsByProject(ctx, projectID)
        if err != nil {
                return nil, fmt.Errorf("list allocations: %w", err)
        }
        if len(existing) > 0 {
-               return &existing[0], nil
+               return h.recordAllocationDiff(ctx, body, &existing[0])
        }
 
        su, err := getInt64(body, "ServiceUnitsAllocated")
@@ -188,3 +192,30 @@ func (h *RequestProjectCreateHandler) ensureAllocation(ctx 
context.Context, body
                EndTime:          end,
        })
 }
+
+// recordAllocationDiff writes a ComputeAllocationDiff for a re-delivered
+// request_project_create against an existing allocation (the AMIE pattern for
+// supplements / renewals / extensions / adjustments). DiffType is the 
upper-cased
+// AllocationType from the packet body; falls back to "GRANT" when AMIE did 
not supply one.
+// NewSUAmount carries this packet's ServiceUnitsAllocated: the delta granted
+// by this event, not the cumulative total.
+func (h *RequestProjectCreateHandler) recordAllocationDiff(ctx 
context.Context, body map[string]any, existing *models.ComputeAllocation) 
(*models.ComputeAllocation, error) {
+       su, err := getInt64(body, "ServiceUnitsAllocated")
+       if err != nil {
+               return nil, err
+       }
+       diffType := strings.ToUpper(strings.TrimSpace(getString(body, 
"AllocationType")))
+       if diffType == "" {
+               diffType = "GRANT"
+       }
+       if _, err := h.svc.CreateComputeAllocationDiff(ctx, 
&models.ComputeAllocationDiff{
+               ComputeAllocationID: existing.ID,
+               DiffType:            diffType,
+               NewSUAmount:         su,
+               Status:              models.ACTIVE,
+               Description:         fmt.Sprintf("AMIE %s of %d SUs", diffType, 
su),
+       }); err != nil {
+               return nil, fmt.Errorf("record allocation diff: %w", err)
+       }
+       return existing, nil
+}
diff --git 
a/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go
index 286953255..64b76036b 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_inactivate.go
@@ -46,31 +46,78 @@ func (h *RequestProjectInactivateHandler) Handle(ctx 
context.Context, tx *sql.Tx
        if err != nil {
                return err
        }
-       originatedID := getString(body, "ProjectID")
-       if err := requireText(originatedID, "ProjectID"); err != nil {
+       projectID := getString(body, "ProjectID")
+       if err := requireText(projectID, "ProjectID"); err != nil {
                return err
        }
 
-       project, err := h.svc.GetProjectByOriginatedID(ctx, originatedID)
+       // AMIE carries the Custos project.id we returned in 
notify_project_create.
+       project, err := h.svc.GetProject(ctx, projectID)
        if err != nil {
                if errors.Is(err, service.ErrNotFound) {
-                       slog.WarnContext(ctx, "request_project_inactivate: 
project not found in core; skipping status flip",
-                               "originatedID", originatedID)
-               } else {
-                       return fmt.Errorf("request_project_inactivate: lookup 
project: %w", err)
+                       slog.WarnContext(ctx, "request_project_inactivate: 
project not found in core; skipping",
+                               "projectID", projectID)
+                       return h.reply(ctx, tx, packet, eventID, projectID)
                }
-       } else {
-               if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, 
models.INACTIVE); err != nil {
-                       return fmt.Errorf("request_project_inactivate: update 
status: %w", err)
+               return fmt.Errorf("request_project_inactivate: lookup project: 
%w", err)
+       }
+
+       if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, 
models.INACTIVE); err != nil {
+               return fmt.Errorf("request_project_inactivate: update project 
status: %w", err)
+       }
+       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditInactivateProject, "project", project.ID, ""); err != nil {
+               return fmt.Errorf("request_project_inactivate: audit 
INACTIVATE_PROJECT: %w", err)
+       }
+
+       if err := h.deactivateAllocations(ctx, tx, packet, eventID, project.ID, 
getString(body, "Comment")); err != nil {
+               return err
+       }
+
+       return h.reply(ctx, tx, packet, eventID, projectID)
+}
+
+// deactivateAllocations flips every ComputeAllocation under the project to
+// INACTIVE, writes a status-change Diff per allocation, and flips every
+// ComputeAllocationMembership under those allocations to INACTIVE.
+func (h *RequestProjectInactivateHandler) deactivateAllocations(ctx 
context.Context, tx *sql.Tx, packet *model.Packet, eventID, projectID, comment 
string) error {
+       allocations, err := h.svc.ListComputeAllocationsByProject(ctx, 
projectID)
+       if err != nil {
+               return fmt.Errorf("list allocations: %w", err)
+       }
+       for _, a := range allocations {
+               a.Status = models.INACTIVE
+               if err := h.svc.UpdateComputeAllocation(ctx, &a); err != nil {
+                       return fmt.Errorf("inactivate allocation %s: %w", a.ID, 
err)
+               }
+               if _, err := h.svc.CreateComputeAllocationDiff(ctx, 
&models.ComputeAllocationDiff{
+                       ComputeAllocationID: a.ID,
+                       DiffType:            "ALLOCATION_STATUS_CHANGE",
+                       Status:              models.INACTIVE,
+                       Description:         describeStatusChange("Inactivated 
by AMIE request_project_inactivate", comment),
+               }); err != nil {
+                       return fmt.Errorf("record inactivate diff for %s: %w", 
a.ID, err)
                }
-               if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditInactivateProject, "project", project.ID, ""); err != nil {
-                       return fmt.Errorf("request_project_inactivate: audit 
INACTIVATE_PROJECT: %w", err)
+               members, err := h.svc.ListMembersForAllocation(ctx, a.ID)
+               if err != nil {
+                       return fmt.Errorf("list memberships for allocation %s: 
%w", a.ID, err)
+               }
+               for _, m := range members {
+                       if _, err := h.svc.UpdateMembershipStatus(ctx, m.ID, 
models.INACTIVE); err != nil {
+                               return fmt.Errorf("inactivate membership %s: 
%w", m.ID, err)
+                       }
+                       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditInactivateMembership, "compute_allocation_membership", m.ID,
+                               fmt.Sprintf("user=%s allocation=%s", m.UserID, 
a.ID)); err != nil {
+                               return fmt.Errorf("audit INACTIVATE_MEMBERSHIP 
for %s: %w", m.ID, err)
+                       }
                }
        }
+       return nil
+}
 
+func (h *RequestProjectInactivateHandler) reply(ctx context.Context, tx 
*sql.Tx, packet *model.Packet, eventID, projectID string) error {
        reply := map[string]any{
                "type": "notify_project_inactivate",
-               "body": map[string]any{"ProjectID": originatedID},
+               "body": map[string]any{"ProjectID": projectID},
        }
        if err := h.amieClient.ReplyToPacket(ctx, packet.AmieID, reply); err != 
nil {
                return fmt.Errorf("request_project_inactivate: sending reply: 
%w", err)
@@ -80,3 +127,10 @@ func (h *RequestProjectInactivateHandler) Handle(ctx 
context.Context, tx *sql.Tx
        }
        return nil
 }
+
+func describeStatusChange(primary, comment string) string {
+       if comment == "" {
+               return primary
+       }
+       return primary + " (" + comment + ")"
+}
diff --git 
a/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go
index 4a6820bd1..409ad6976 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_project_reactivate.go
@@ -41,36 +41,93 @@ func NewRequestProjectReactivateHandler(svc 
*service.Service, amieClient AmieCli
 
 func (h *RequestProjectReactivateHandler) SupportsType() string { return 
"request_project_reactivate" }
 
+// Handle flips the Project and all of its ComputeAllocations back to ACTIVE.
+// Per AMIE protocol, only the PI's membership is reactivated automatically;
+// other members must request reactivation via request_account_reactivate.
 func (h *RequestProjectReactivateHandler) Handle(ctx context.Context, tx 
*sql.Tx, packetJSON map[string]any, packet *model.Packet, eventID string) error 
{
        body, err := getBody(packetJSON)
        if err != nil {
                return err
        }
-       originatedID := getString(body, "ProjectID")
-       if err := requireText(originatedID, "ProjectID"); err != nil {
+       projectID := getString(body, "ProjectID")
+       if err := requireText(projectID, "ProjectID"); err != nil {
                return err
        }
 
-       project, err := h.svc.GetProjectByOriginatedID(ctx, originatedID)
+       // AMIE carries the Custos project.id we returned in 
notify_project_create.
+       project, err := h.svc.GetProject(ctx, projectID)
        if err != nil {
                if errors.Is(err, service.ErrNotFound) {
-                       slog.WarnContext(ctx, "request_project_reactivate: 
project not found in core; skipping status flip",
-                               "originatedID", originatedID)
-               } else {
-                       return fmt.Errorf("request_project_reactivate: lookup 
project: %w", err)
+                       slog.WarnContext(ctx, "request_project_reactivate: 
project not found in core; skipping",
+                               "projectID", projectID)
+                       return h.reply(ctx, tx, packet, eventID, projectID)
                }
-       } else {
-               if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, 
models.ACTIVE); err != nil {
-                       return fmt.Errorf("request_project_reactivate: update 
status: %w", err)
+               return fmt.Errorf("request_project_reactivate: lookup project: 
%w", err)
+       }
+
+       if _, err := h.svc.UpdateProjectStatus(ctx, project.ID, models.ACTIVE); 
err != nil {
+               return fmt.Errorf("request_project_reactivate: update project 
status: %w", err)
+       }
+       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditReactivateProject, "project", project.ID, ""); err != nil {
+               return fmt.Errorf("request_project_reactivate: audit 
REACTIVATE_PROJECT: %w", err)
+       }
+
+       if err := h.reactivateAllocationsAndPI(ctx, tx, packet, eventID, 
project); err != nil {
+               return err
+       }
+
+       return h.reply(ctx, tx, packet, eventID, projectID)
+}
+
+// reactivateAllocationsAndPI flips every ComputeAllocation under the project
+// back to ACTIVE, writes a status-change Diff per allocation, and reactivates
+// only the PI's membership on each allocation. Other members stay INACTIVE
+// until their own request_account_reactivate arrives.
+func (h *RequestProjectReactivateHandler) reactivateAllocationsAndPI(ctx 
context.Context, tx *sql.Tx, packet *model.Packet, eventID string, project 
*models.Project) error {
+       allocations, err := h.svc.ListComputeAllocationsByProject(ctx, 
project.ID)
+       if err != nil {
+               return fmt.Errorf("list allocations: %w", err)
+       }
+       for _, a := range allocations {
+               a.Status = models.ACTIVE
+               if err := h.svc.UpdateComputeAllocation(ctx, &a); err != nil {
+                       return fmt.Errorf("reactivate allocation %s: %w", a.ID, 
err)
+               }
+               if _, err := h.svc.CreateComputeAllocationDiff(ctx, 
&models.ComputeAllocationDiff{
+                       ComputeAllocationID: a.ID,
+                       DiffType:            "ALLOCATION_STATUS_CHANGE",
+                       Status:              models.ACTIVE,
+                       Description:         "Reactivated by AMIE 
request_project_reactivate",
+               }); err != nil {
+                       return fmt.Errorf("record reactivate diff for %s: %w", 
a.ID, err)
                }
-               if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditReactivateProject, "project", project.ID, ""); err != nil {
-                       return fmt.Errorf("request_project_reactivate: audit 
REACTIVATE_PROJECT: %w", err)
+               if project.ProjectPIID == "" {
+                       continue
+               }
+               members, err := h.svc.ListMembersForAllocation(ctx, a.ID)
+               if err != nil {
+                       return fmt.Errorf("list memberships for allocation %s: 
%w", a.ID, err)
+               }
+               for _, m := range members {
+                       if m.UserID != project.ProjectPIID {
+                               continue
+                       }
+                       if _, err := h.svc.UpdateMembershipStatus(ctx, m.ID, 
models.ACTIVE); err != nil {
+                               return fmt.Errorf("reactivate PI membership %s: 
%w", m.ID, err)
+                       }
+                       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, 
model.AuditReactivateMembership, "compute_allocation_membership", m.ID,
+                               fmt.Sprintf("PI user=%s allocation=%s", 
m.UserID, a.ID)); err != nil {
+                               return fmt.Errorf("audit REACTIVATE_MEMBERSHIP 
for %s: %w", m.ID, err)
+                       }
                }
        }
+       return nil
+}
 
+func (h *RequestProjectReactivateHandler) reply(ctx context.Context, tx 
*sql.Tx, packet *model.Packet, eventID, projectID string) error {
        reply := map[string]any{
                "type": "notify_project_reactivate",
-               "body": map[string]any{"ProjectID": originatedID},
+               "body": map[string]any{"ProjectID": projectID},
        }
        if err := h.amieClient.ReplyToPacket(ctx, packet.AmieID, reply); err != 
nil {
                return fmt.Errorf("request_project_reactivate: sending reply: 
%w", err)
diff --git a/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go 
b/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go
index 511f2ca28..b72031b22 100644
--- a/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go
+++ b/connectors/ACCESS/AMIE-Processor/handler/request_user_modify.go
@@ -20,11 +20,13 @@ package handler
 import (
        "context"
        "database/sql"
+       "encoding/json"
        "errors"
        "fmt"
        "strings"
 
        "github.com/apache/airavata-custos/connectors/ACCESS/AMIE-Processor/model"
+       "github.com/apache/airavata-custos/pkg/models"
        "github.com/apache/airavata-custos/pkg/service"
 )
 
@@ -65,20 +67,8 @@ func (h *RequestUserModifyHandler) Handle(ctx context.Context, tx *sql.Tx, packe
        switch {
        case strings.EqualFold(actionType, "replace"):
                if user != nil {
-                       if v := getString(body, "UserFirstName"); v != "" {
-                               user.FirstName = v
-                       }
-                       if v := getString(body, "UserLastName"); v != "" {
-                               user.LastName = v
-                       }
-                       if v := getString(body, "UserEmail"); v != "" {
-                               user.Email = v
-                       }
-                       if err := h.svc.UpdateUser(ctx, user); err != nil {
-                               return fmt.Errorf("request_user_modify: update user: %w", err)
-                       }
-                       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditUpdatePerson, "user", user.ID, ""); err != nil {
-                               return fmt.Errorf("request_user_modify: audit UPDATE_PERSON: %w", err)
+                       if err := h.applyReplace(ctx, tx, packet, eventID, body, user, userGlobalID); err != nil {
+                               return err
                        }
                }
        case strings.EqualFold(actionType, "delete"):
@@ -110,3 +100,114 @@ func (h *RequestUserModifyHandler) Handle(ctx context.Context, tx *sql.Tx, packe
        }
        return nil
 }
+
+// applyReplace updates the User row (basic profile), the ExternalIdentity row
+// (org / orgCode / NSF status carried as metadata), and reconciles the DN list.
+func (h *RequestUserModifyHandler) applyReplace(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID string, body map[string]any, user *models.User, userGlobalID string) error {
+       if v := getString(body, "UserFirstName"); v != "" {
+               user.FirstName = v
+       }
+       if v := getString(body, "UserLastName"); v != "" {
+               user.LastName = v
+       }
+       if v := getString(body, "UserEmail"); v != "" {
+               user.Email = v
+       }
+       if err := h.svc.UpdateUser(ctx, user); err != nil {
+               return fmt.Errorf("update user: %w", err)
+       }
+       if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditUpdatePerson, "user", user.ID, ""); err != nil {
+               return fmt.Errorf("audit UPDATE_PERSON: %w", err)
+       }
+
+       if err := h.updateExternalIdentity(ctx, body, user.ID, userGlobalID); err != nil {
+               return fmt.Errorf("update external identity: %w", err)
+       }
+       if err := h.syncDNs(ctx, tx, packet, eventID, body, user.ID); err != nil {
+               return fmt.Errorf("sync user DNs: %w", err)
+       }
+       return nil
+}
+
+// updateExternalIdentity refreshes the AMIE ExternalIdentity row's metadata
+// (organization, org_code, NSF status) — these are AMIE-side attributes that
+// may shift over a user's lifetime. The row's source / external_id are
+// immutable identifiers.
+func (h *RequestUserModifyHandler) updateExternalIdentity(ctx context.Context, body map[string]any, userID, userGlobalID string) error {
+       ext, err := h.svc.GetExternalIdentityBySourceAndExternalID(ctx, amieIdentitySource, userGlobalID)
+       if err != nil {
+               if errors.Is(err, service.ErrNotFound) {
+                       return nil
+               }
+               return err
+       }
+       metadata := map[string]any{}
+       if v := getString(body, "UserOrganization"); v != "" {
+               metadata["organization"] = v
+       }
+       if v := getString(body, "UserOrgCode"); v != "" {
+               metadata["org_code"] = v
+       }
+       if v := getString(body, "NsfStatusCode"); v != "" {
+               metadata["nsf_status_code"] = v
+       }
+       if len(metadata) == 0 {
+               return nil
+       }
+       encoded, err := json.Marshal(metadata)
+       if err != nil {
+               return fmt.Errorf("encode metadata: %w", err)
+       }
+       ext.UserID = userID
+       ext.Metadata = string(encoded)
+       return h.svc.UpdateExternalIdentity(ctx, ext)
+}
+
+// syncDNs reconciles the user's DN list with the packet body's UserDnList:
+// new DNs are added, DNs missing from the packet are removed. AMIE's
+// request_user_modify with ActionType=replace is the authoritative source.
+func (h *RequestUserModifyHandler) syncDNs(ctx context.Context, tx *sql.Tx, packet *model.Packet, eventID string, body map[string]any, userID string) error {
+       incoming := getDNList(body)
+       if incoming == nil {
+               return nil
+       }
+       desired := make(map[string]struct{}, len(incoming))
+       for _, dn := range incoming {
+               desired[dn] = struct{}{}
+       }
+       existing, err := h.svc.ListUserDNs(ctx, userID)
+       if err != nil {
+               return fmt.Errorf("list user DNs: %w", err)
+       }
+       have := make(map[string]struct{}, len(existing))
+       added := 0
+       removed := 0
+       for _, e := range existing {
+               have[e.DN] = struct{}{}
+               if _, keep := desired[e.DN]; !keep {
+                       if err := h.svc.RemoveUserDN(ctx, e.ID); err != nil {
+                               return fmt.Errorf("remove DN %s: %w", e.DN, err)
+                       }
+                       removed++
+               }
+       }
+       for dn := range desired {
+               if _, exists := have[dn]; exists {
+                       continue
+               }
+               if _, err := h.svc.AddUserDN(ctx, &models.UserDN{UserID: userID, DN: dn}); err != nil {
+                       if errors.Is(err, service.ErrAlreadyExists) {
+                               continue
+                       }
+                       return fmt.Errorf("add DN %s: %w", dn, err)
+               }
+               added++
+       }
+       if added > 0 || removed > 0 {
+               if err := h.auditSvc.Log(ctx, tx, packet.ID, eventID, model.AuditPersistDNs, "user", userID,
+                       fmt.Sprintf("DN sync: +%d -%d", added, removed)); err != nil {
+                       return fmt.Errorf("audit PERSIST_DNS: %w", err)
+               }
+       }
+       return nil
+}

Reply via email to