https://github.com/ahmedbougacha created 
https://github.com/llvm/llvm-project/pull/97647

Enabled in clang using:

    -fptrauth-indirect-gotos

and at the IR level using function attribute:

    "ptrauth-indirect-gotos"

Signing uses IA and a per-function integer discriminator. The discriminator 
isn't ABI-visible, and is currently:

    ptrauth_string_discriminator("<function_name> blockaddress")

A sufficiently sophisticated frontend could benefit from per-indirectbr 
discrimination, which would need additional machinery, such as allowing 
"ptrauth" bundles on indirectbr. For our purposes, the simple scheme above is 
sufficient.

>From f0d8af86161c6037e9e0d1fe800e5876dd090092 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ah...@bougacha.org>
Date: Tue, 12 Mar 2024 14:40:17 -0700
Subject: [PATCH] [AArch64][PAC] Sign block addresses used in indirectbr.

Enabled in clang using:
  -fptrauth-indirect-gotos

and at the IR level using function attribute:
  "ptrauth-indirect-gotos"

Signing uses IA and a per-function integer discriminator.
The discriminator isn't ABI-visible, and is currently:
  ptrauth_string_discriminator("<function_name> blockaddress")

A sufficiently sophisticated frontend could benefit from
per-indirectbr discrimination, which would need additional
machinery, such as allowing "ptrauth" bundles on indirectbr.
For our purposes, the simple scheme above is sufficient.
---
 clang/include/clang/Basic/Features.def        |   1 +
 clang/include/clang/Basic/LangOptions.def     |   1 +
 .../include/clang/Basic/PointerAuthOptions.h  |   3 +
 clang/include/clang/Driver/Options.td         |   2 +
 clang/lib/CodeGen/CodeGenFunction.cpp         |   2 +
 clang/lib/Driver/ToolChains/Clang.cpp         |   3 +
 clang/lib/Frontend/CompilerInvocation.cpp     |   6 +-
 .../CodeGen/ptrauth-function-attributes.c     |   5 +
 llvm/docs/PointerAuth.md                      |  24 ++++
 llvm/include/llvm/CodeGen/AsmPrinter.h        |   3 +
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp    |   6 +-
 llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp |  35 +++++-
 llvm/lib/Target/AArch64/AArch64FastISel.cpp   |   4 +
 .../Target/AArch64/AArch64ISelLowering.cpp    |  54 ++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   1 +
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  18 +++
 llvm/lib/Target/AArch64/AArch64Subtarget.cpp  |  11 ++
 llvm/lib/Target/AArch64/AArch64Subtarget.h    |   9 ++
 .../GISel/AArch64InstructionSelector.cpp      |  26 +++++
 .../CodeGen/AArch64/ptrauth-indirectbr.ll     | 106 ++++++++++++++++++
 20 files changed, 309 insertions(+), 11 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/ptrauth-indirectbr.ll

diff --git a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def
index 53f410d3cb4bd..cf800afe08557 100644
--- a/clang/include/clang/Basic/Features.def
+++ b/clang/include/clang/Basic/Features.def
@@ -108,6 +108,7 @@ FEATURE(ptrauth_calls, LangOpts.PointerAuthCalls)
 FEATURE(ptrauth_returns, LangOpts.PointerAuthReturns)
 FEATURE(ptrauth_vtable_pointer_address_discrimination, LangOpts.PointerAuthVTPtrAddressDiscrimination)
 FEATURE(ptrauth_vtable_pointer_type_discrimination, LangOpts.PointerAuthVTPtrTypeDiscrimination)
+FEATURE(ptrauth_indirect_gotos, LangOpts.PointerAuthIndirectGotos)
 FEATURE(ptrauth_member_function_pointer_type_discrimination, LangOpts.PointerAuthCalls)
 FEATURE(ptrauth_init_fini, LangOpts.PointerAuthInitFini)
 EXTENSION(swiftcc,
diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index 491759e2fcdbb..bdf77a5b35208 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -165,6 +165,7 @@ LANGOPT(ExperimentalLibrary, 1, 0, "enable unstable and experimental library fea
 LANGOPT(PointerAuthIntrinsics, 1, 0, "pointer authentication intrinsics")
 LANGOPT(PointerAuthCalls  , 1, 0, "function pointer authentication")
 LANGOPT(PointerAuthReturns, 1, 0, "return pointer authentication")
+LANGOPT(PointerAuthIndirectGotos, 1, 0, "indirect gotos pointer authentication")
 LANGOPT(PointerAuthAuthTraps, 1, 0, "pointer authentication failure traps")
 LANGOPT(PointerAuthVTPtrAddressDiscrimination, 1, 0, "incorporate address discrimination in authenticated vtable pointers")
 LANGOPT(PointerAuthVTPtrTypeDiscrimination, 1, 0, "incorporate type discrimination in authenticated vtable pointers")
diff --git a/clang/include/clang/Basic/PointerAuthOptions.h b/clang/include/clang/Basic/PointerAuthOptions.h
index 197d63642ca6d..2711639dbe299 100644
--- a/clang/include/clang/Basic/PointerAuthOptions.h
+++ b/clang/include/clang/Basic/PointerAuthOptions.h
@@ -154,6 +154,9 @@ class PointerAuthSchema {
 };
 
 struct PointerAuthOptions {
+  /// Do indirect goto label addresses need to be authenticated?
+  bool IndirectGotos = false;
+
   /// The ABI for C function pointers.
   PointerAuthSchema FunctionPointers;
 
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 58ca6f2bea9e4..791b7261ddbda 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4228,6 +4228,8 @@ defm ptrauth_vtable_pointer_address_discrimination :
 defm ptrauth_vtable_pointer_type_discrimination :
   OptInCC1FFlag<"ptrauth-vtable-pointer-type-discrimination", "Enable type discrimination of vtable pointers">;
 defm ptrauth_init_fini : OptInCC1FFlag<"ptrauth-init-fini", "Enable signing of function pointers in init/fini arrays">;
+defm ptrauth_indirect_gotos : OptInCC1FFlag<"ptrauth-indirect-gotos",
+  "Enable signing and authentication of indirect goto targets">;
 }
 
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 26deeca95d326..33cb07a5bc832 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -865,6 +865,8 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   const CodeGenOptions &CodeGenOpts = CGM.getCodeGenOpts();
   if (CodeGenOpts.PointerAuth.FunctionPointers)
     Fn->addFnAttr("ptrauth-calls");
+  if (CodeGenOpts.PointerAuth.IndirectGotos)
+    Fn->addFnAttr("ptrauth-indirect-gotos");
 
   // Apply xray attributes to the function (as a string, for now)
   bool AlwaysXRayAttr = false;
diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index aa285c39f14b4..f9dba2c9f22cd 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -1789,6 +1789,9 @@ void Clang::AddAArch64TargetArgs(const ArgList &Args,
       options::OPT_fno_ptrauth_vtable_pointer_type_discrimination);
   Args.addOptInFlag(CmdArgs, options::OPT_fptrauth_init_fini,
                     options::OPT_fno_ptrauth_init_fini);
+
+  Args.addOptInFlag(CmdArgs, options::OPT_fptrauth_indirect_gotos,
+                    options::OPT_fno_ptrauth_indirect_gotos);
 }
 
 void Clang::AddLoongArchTargetArgs(const ArgList &Args,
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index f42e28ba7e629..a64e394a7754e 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -1480,13 +1480,14 @@ void CompilerInvocation::setDefaultPointerAuthOptions(
     Opts.CXXVirtualFunctionPointers = Opts.CXXVirtualVariadicFunctionPointers =
         PointerAuthSchema(Key::ASIA, true, Discrimination::Decl);
   }
+  Opts.IndirectGotos = LangOpts.PointerAuthIndirectGotos;
 }
 
 static void parsePointerAuthOptions(PointerAuthOptions &Opts,
                                     const LangOptions &LangOpts,
                                     const llvm::Triple &Triple,
                                     DiagnosticsEngine &Diags) {
-  if (!LangOpts.PointerAuthCalls)
+  if (!LangOpts.PointerAuthCalls && !LangOpts.PointerAuthIndirectGotos)
     return;
 
   CompilerInvocation::setDefaultPointerAuthOptions(Opts, LangOpts, Triple);
@@ -3390,6 +3391,8 @@ static void GeneratePointerAuthArgs(const LangOptions &Opts,
     GenerateArg(Consumer, OPT_fptrauth_calls);
   if (Opts.PointerAuthReturns)
     GenerateArg(Consumer, OPT_fptrauth_returns);
+  if (Opts.PointerAuthIndirectGotos)
+    GenerateArg(Consumer, OPT_fptrauth_indirect_gotos);
   if (Opts.PointerAuthAuthTraps)
     GenerateArg(Consumer, OPT_fptrauth_auth_traps);
   if (Opts.PointerAuthVTPtrAddressDiscrimination)
@@ -3405,6 +3408,7 @@ static void ParsePointerAuthArgs(LangOptions &Opts, ArgList &Args,
   Opts.PointerAuthIntrinsics = Args.hasArg(OPT_fptrauth_intrinsics);
   Opts.PointerAuthCalls = Args.hasArg(OPT_fptrauth_calls);
   Opts.PointerAuthReturns = Args.hasArg(OPT_fptrauth_returns);
+  Opts.PointerAuthIndirectGotos = Args.hasArg(OPT_fptrauth_indirect_gotos);
   Opts.PointerAuthAuthTraps = Args.hasArg(OPT_fptrauth_auth_traps);
   Opts.PointerAuthVTPtrAddressDiscrimination =
       Args.hasArg(OPT_fptrauth_vtable_pointer_address_discrimination);
diff --git a/clang/test/CodeGen/ptrauth-function-attributes.c b/clang/test/CodeGen/ptrauth-function-attributes.c
index 7ec30498b9d35..7f93ccc7c4bce 100644
--- a/clang/test/CodeGen/ptrauth-function-attributes.c
+++ b/clang/test/CodeGen/ptrauth-function-attributes.c
@@ -4,10 +4,15 @@
 // RUN: %clang_cc1 -triple arm64-apple-ios  -fptrauth-calls   -emit-llvm %s  -o - | FileCheck %s --check-prefixes=ALL,CALLS
 // RUN: %clang_cc1 -triple aarch64-linux-gnu -fptrauth-calls  -emit-llvm %s  -o - | FileCheck %s --check-prefixes=ALL,CALLS
 
+// RUN: %clang_cc1 -triple arm64-apple-ios  -fptrauth-indirect-gotos -emit-llvm %s -o - | FileCheck %s --check-prefixes=ALL,GOTOS
+// RUN: %clang_cc1 -triple arm64e-apple-ios -fptrauth-indirect-gotos -emit-llvm %s -o - | FileCheck %s --check-prefixes=ALL,GOTOS
+
 // ALL: define {{(dso_local )?}}void @test() #0
 void test() {
 }
 
 // CALLS: attributes #0 = {{{.*}} "ptrauth-calls" {{.*}}}
 
+// GOTOS: attributes #0 = {{{.*}} "ptrauth-indirect-gotos" {{.*}}}
+
 // OFF-NOT: attributes {{.*}} "ptrauth-
diff --git a/llvm/docs/PointerAuth.md b/llvm/docs/PointerAuth.md
index cf2cc6305f130..e027c902e58e1 100644
--- a/llvm/docs/PointerAuth.md
+++ b/llvm/docs/PointerAuth.md
@@ -18,6 +18,9 @@ At the IR level, it is represented using:
 * a [set of intrinsics](#intrinsics) (to sign/authenticate pointers)
 * a [signed pointer constant](#constant) (to sign globals)
 * a [call operand bundle](#operand-bundle) (to authenticate called pointers)
+* a [set of function attributes](#function-attributes) (to describe what
+  pointers are signed and how, to control implicit codegen in the backend, as
+  well as preserve invariants in the mid-level optimizer)
 
 The current implementation leverages the
 [Armv8.3-A PAuth/Pointer Authentication Code](#armv8-3-a-pauth-pointer-authentication-code)
@@ -287,6 +290,27 @@ but with the added guarantee that `%fp_i`, `%fp_auth`, and `%fp_auth_p`
 are not stored to (and reloaded from) memory.
 
 
+### Function Attributes
+
+Some function attributes are used to describe other pointer authentication
+operations that are not otherwise explicitly expressed in IR.
+
+#### ``ptrauth-indirect-gotos``
+
+``ptrauth-indirect-gotos`` specifies that indirect gotos in this function
+should authenticate their target.  At the IR level, no other change is needed.
+When lowering [``blockaddress`` constants](https://llvm.org/docs/LangRef.html#blockaddress),
+and [``indirectbr`` instructions](https://llvm.org/docs/LangRef.html#i-indirectbr),
+this tells the backend to respectively sign and authenticate the pointers.
+
+The specific scheme isn't ABI-visible.  Currently, the AArch64 backend
+signs blockaddresses using the `ASIA` key, with an integer discriminator
+derived from the parent function's name, using the SipHash stable discriminator:
+```
+  ptrauth_string_discriminator("<function_name> blockaddress")
+```
+
+
 ## AArch64 Support
 
 AArch64 is currently the only architecture with full support of the pointer
diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index a60dce30c4a6c..290da83cee35f 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -577,6 +577,9 @@ class AsmPrinter : public MachineFunctionPass {
     report_fatal_error("ptrauth constant lowering not implemented");
   }
 
+  /// Lower the specified BlockAddress to an MCExpr.
+  virtual const MCExpr *lowerBlockAddressConstant(const BlockAddress &BA);
+
   /// Return true if the basic block has exactly one predecessor and the control
   /// transfer mechanism between the predecessor and this block is a
   /// fall-through.
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index c52cbff689dc5..724b7cd94adee 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -3139,7 +3139,7 @@ const MCExpr *AsmPrinter::lowerConstant(const Constant *CV) {
     return MCSymbolRefExpr::create(getSymbol(GV), Ctx);
 
   if (const BlockAddress *BA = dyn_cast<BlockAddress>(CV))
-    return MCSymbolRefExpr::create(GetBlockAddressSymbol(BA), Ctx);
+    return lowerBlockAddressConstant(*BA);
 
   if (const auto *Equiv = dyn_cast<DSOLocalEquivalent>(CV))
     return getObjFileLowering().lowerDSOLocalEquivalent(Equiv, TM);
@@ -3821,6 +3821,10 @@ MCSymbol *AsmPrinter::GetBlockAddressSymbol(const BasicBlock *BB) const {
   return const_cast<AsmPrinter *>(this)->getAddrLabelSymbol(BB);
 }
 
+const MCExpr *AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
+  return MCSymbolRefExpr::create(GetBlockAddressSymbol(&BA), OutContext);
+}
+
 /// GetCPISymbol - Return the symbol for the specified constant pool entry.
 MCSymbol *AsmPrinter::GetCPISymbol(unsigned CPID) const {
   if (getSubtargetInfo().getTargetTriple().isWindowsMSVCEnvironment()) {
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 64d41d4147644..093208bf70cda 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -93,6 +93,8 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   const MCExpr *lowerConstantPtrAuth(const ConstantPtrAuth &CPA) override;
 
+  const MCExpr *lowerBlockAddressConstant(const BlockAddress &BA) override;
+
   void emitStartOfAsmFile(Module &M) override;
   void emitJumpTableInfo() override;
   std::tuple<const MCSymbol *, uint64_t, const MCSymbol *,
@@ -128,7 +130,7 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   void emitSled(const MachineInstr &MI, SledKind Kind);
 
-  // Emit the sequence for BLRA (authenticate + branch).
+  // Emit the sequence for BRA/BLRA (authenticate + branch/call).
   void emitPtrauthBranch(const MachineInstr *MI);
   // Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
   unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
@@ -1581,6 +1583,7 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
 
 void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   unsigned InstsEmitted = 0;
+  bool IsCall = MI->getOpcode() == AArch64::BLRA;
   unsigned BrTarget = MI->getOperand(0).getReg();
 
   auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
@@ -1597,10 +1600,17 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   bool IsZeroDisc = DiscReg == AArch64::XZR;
 
   unsigned Opc;
-  if (Key == AArch64PACKey::IA)
-    Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
-  else
-    Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
+  if (IsCall) {
+    if (Key == AArch64PACKey::IA)
+      Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
+    else
+      Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
+  } else {
+    if (Key == AArch64PACKey::IA)
+      Opc = IsZeroDisc ? AArch64::BRAAZ : AArch64::BRAA;
+    else
+      Opc = IsZeroDisc ? AArch64::BRABZ : AArch64::BRAB;
+  }
 
   MCInst BRInst;
   BRInst.setOpcode(Opc);
@@ -1866,6 +1876,20 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
   assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
 }
 
+const MCExpr *
+AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
+  const MCExpr *BAE = AsmPrinter::lowerBlockAddressConstant(BA);
+  const Function &Fn = *BA.getFunction();
+
+  if (std::optional<uint16_t> BADisc =
+          STI->getPtrAuthBlockAddressDiscriminator(Fn))
+    return AArch64AuthMCExpr::create(BAE, *BADisc, AArch64PACKey::IA,
+                                     /* HasAddressDiversity= */ false,
+                                     OutContext);
+
+  return BAE;
+}
+
 // Simple pseudo-instructions have their lowering (with expansion to real
 // instructions) auto-generated.
 #include "AArch64GenMCPseudoLowering.inc"
@@ -2010,6 +2034,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     LowerMOVaddrPAC(*MI);
     return;
 
+  case AArch64::BRA:
   case AArch64::BLRA:
     emitPtrauthBranch(MI);
     return;
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index e3c5a143b2889..1a99a905a47d8 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -2516,6 +2516,10 @@ bool AArch64FastISel::selectIndirectBr(const Instruction *I) {
   if (AddrReg == 0)
     return false;
 
+  // Authenticated indirectbr is not implemented yet.
+  if (FuncInfo.MF->getFunction().hasFnAttribute("ptrauth-indirect-gotos"))
+    return false;
+
   // Emit the indirect branch.
   const MCInstrDesc &II = TII.get(AArch64::BR);
   AddrReg = constrainOperandRegClass(II, AddrReg,  II.getNumDefs());
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e0c3cc5eddb82..8a373c3a46d66 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -85,6 +85,7 @@
 #include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SipHash.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
@@ -509,6 +510,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::JumpTable, MVT::i64, Custom);
+  setOperationAction(ISD::BRIND, MVT::Other, Custom);
   setOperationAction(ISD::SETCCCARRY, MVT::i64, Custom);
 
   setOperationAction(ISD::PtrAuthGlobalAddress, MVT::i64, Custom);
@@ -6694,6 +6696,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerJumpTable(Op, DAG);
   case ISD::BR_JT:
     return LowerBR_JT(Op, DAG);
+  case ISD::BRIND:
+    return LowerBRIND(Op, DAG);
   case ISD::ConstantPool:
     return LowerConstantPool(Op, DAG);
   case ISD::BlockAddress:
@@ -10685,6 +10689,27 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
   return DAG.getNode(ISD::BRIND, DL, MVT::Other, JTInfo, SDValue(Dest, 0));
 }
 
+SDValue AArch64TargetLowering::LowerBRIND(SDValue Op,
+                                          SelectionDAG &DAG) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  std::optional<uint16_t> BADisc =
+      Subtarget->getPtrAuthBlockAddressDiscriminator(MF.getFunction());
+  if (!BADisc)
+    return SDValue();
+
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  SDValue Dest = Op.getOperand(1);
+
+  SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
+  SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
+  SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
+
+  SDNode *BrA = DAG.getMachineNode(AArch64::BRA, DL, MVT::Other,
+                                   {Dest, Key, Disc, AddrDisc, Chain});
+  return SDValue(BrA, 0);
+}
+
 SDValue AArch64TargetLowering::LowerConstantPool(SDValue Op,
                                                  SelectionDAG &DAG) const {
   ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(Op);
@@ -10704,15 +10729,36 @@ SDValue AArch64TargetLowering::LowerConstantPool(SDValue Op,
 
 SDValue AArch64TargetLowering::LowerBlockAddress(SDValue Op,
                                                SelectionDAG &DAG) const {
-  BlockAddressSDNode *BA = cast<BlockAddressSDNode>(Op);
+  BlockAddressSDNode *BAN = cast<BlockAddressSDNode>(Op);
+  const BlockAddress *BA = BAN->getBlockAddress();
+
+  if (std::optional<uint16_t> BADisc =
+          Subtarget->getPtrAuthBlockAddressDiscriminator(*BA->getFunction())) {
+    SDLoc DL(Op);
+
+    // This isn't cheap, but BRIND is rare.
+    SDValue TargetBA = DAG.getTargetBlockAddress(BA, BAN->getValueType(0));
+
+    SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
+
+    SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
+    SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
+
+    SDNode *MOV =
+      DAG.getMachineNode(AArch64::MOVaddrPAC, DL, {MVT::Other, MVT::Glue},
+                         {TargetBA, Key, AddrDisc, Disc});
+    return DAG.getCopyFromReg(SDValue(MOV, 0), DL, AArch64::X16, MVT::i64,
+                              SDValue(MOV, 1));
+  }
+
   CodeModel::Model CM = getTargetMachine().getCodeModel();
   if (CM == CodeModel::Large && !Subtarget->isTargetMachO()) {
     if (!getTargetMachine().isPositionIndependent())
-      return getAddrLarge(BA, DAG);
+      return getAddrLarge(BAN, DAG);
   } else if (CM == CodeModel::Tiny) {
-    return getAddrTiny(BA, DAG);
+    return getAddrTiny(BAN, DAG);
   }
-  return getAddr(BA, DAG);
+  return getAddr(BAN, DAG);
 }
 
 SDValue AArch64TargetLowering::LowerDarwin_VASTART(SDValue Op,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 047c852bb01d2..69386dac318ca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1143,6 +1143,7 @@ class AArch64TargetLowering : public TargetLowering {
                          SelectionDAG &DAG) const;
   SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerBRIND(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerAAPCS_VASTART(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 1e06d5fdc7562..a65e5c68a24ad 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1755,6 +1755,24 @@ let Predicates = [HasPAuth] in {
     let Uses = [SP];
   }
 
+  // BRA pseudo, generalized version of BRAA/BRAB/Z.
+  // This directly manipulates x16/x17, which are the only registers the OS
+  // guarantees are safe to use for sensitive operations.
+  def BRA : Pseudo<(outs), (ins GPR64noip:$Rn, i32imm:$Key, i64imm:$Disc,
+                                GPR64noip:$AddrDisc), []>, Sched<[]> {
+    let isCodeGenOnly = 1;
+    let hasNoSchedulingInfo = 1;
+    let hasSideEffects = 1;
+    let mayStore = 0;
+    let mayLoad = 0;
+    let isBranch = 1;
+    let isTerminator = 1;
+    let isBarrier = 1;
+    let isIndirectBranch = 1;
+    let Size = 12; // 4 fixed + 8 variable, to compute discriminator.
+    let Defs = [X17];
+  }
+
   let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
     def RETAA   : AuthReturn<0b010, 0, "retaa">;
     def RETAB   : AuthReturn<0b010, 1, "retab">;
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 1fad1d5ca6d7d..7275116e38f2e 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineScheduler.h"
 #include "llvm/IR/GlobalValue.h"
+#include "llvm/Support/SipHash.h"
 #include "llvm/TargetParser/AArch64TargetParser.h"
 
 using namespace llvm;
@@ -574,6 +575,16 @@ AArch64Subtarget::getAuthenticatedLRCheckMethod() const {
   return AArch64PAuth::AuthCheckMethod::None;
 }
 
+std::optional<uint16_t> AArch64Subtarget::getPtrAuthBlockAddressDiscriminator(
+    const Function &ParentFn) const {
+  if (!ParentFn.hasFnAttribute("ptrauth-indirect-gotos"))
+    return std::nullopt;
+  // We currently have one simple mechanism for all targets.
+  // This isn't ABI, so we can always do better in the future.
+  return getPointerAuthStableSipHash(
+      (Twine(ParentFn.getName()) + " blockaddress").str());
+}
+
 bool AArch64Subtarget::enableMachinePipeliner() const {
   return getSchedModel().hasInstrSchedModel();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 5faba09aa67bd..172beaaaabc01 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -412,6 +412,15 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
   /// Choose a method of checking LR before performing a tail call.
   AArch64PAuth::AuthCheckMethod getAuthenticatedLRCheckMethod() const;
 
+  /// Compute the integer discriminator for a given BlockAddress constant, if
+  /// blockaddress signing is enabled (using function attribute
+  /// "ptrauth-indirect-gotos").
+  /// Note that this assumes the discriminator is independent of the indirect
+  /// goto branch site itself, i.e., it's the same for all BlockAddresses in
+  /// a function.
+  std::optional<uint16_t>
+  getPtrAuthBlockAddressDiscriminator(const Function &ParentFn) const;
+
   const PseudoSourceValue *getAddressCheckPSV() const {
     return AddressCheckPSV.get();
   }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 9e0860934f777..b5787b6bd8b82 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2547,6 +2547,15 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     return selectCompareBranch(I, MF, MRI);
 
   case TargetOpcode::G_BRINDIRECT: {
+    if (std::optional<uint16_t> BADisc =
+            STI.getPtrAuthBlockAddressDiscriminator(MF.getFunction())) {
+      auto MI = MIB.buildInstr(AArch64::BRA, {}, {I.getOperand(0).getReg()});
+      MI.addImm(AArch64PACKey::IA);
+      MI.addImm(*BADisc);
+      MI.addReg(/*AddrDisc=*/AArch64::XZR);
+      I.eraseFromParent();
+      return constrainSelectedInstRegOperands(*MI, TII, TRI, RBI);
+    }
     I.setDesc(TII.get(AArch64::BR));
     return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
   }
@@ -3461,6 +3470,23 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     return true;
   }
   case TargetOpcode::G_BLOCK_ADDR: {
+    Function *BAFn = I.getOperand(1).getBlockAddress()->getFunction();
+    if (std::optional<uint16_t> BADisc =
+            STI.getPtrAuthBlockAddressDiscriminator(*BAFn)) {
+      MIB.buildInstr(TargetOpcode::IMPLICIT_DEF, {AArch64::X16}, {});
+      MIB.buildInstr(TargetOpcode::IMPLICIT_DEF, {AArch64::X17}, {});
+      MIB.buildInstr(AArch64::MOVaddrPAC)
+          .addBlockAddress(I.getOperand(1).getBlockAddress())
+          .addImm(AArch64PACKey::IA)
+          .addReg(/*AddrDisc=*/AArch64::XZR)
+          .addImm(*BADisc)
+          .constrainAllUses(TII, TRI, RBI);
+      MIB.buildCopy(I.getOperand(0).getReg(), Register(AArch64::X16));
+      RBI.constrainGenericRegister(I.getOperand(0).getReg(),
+                                   AArch64::GPR64RegClass, MRI);
+      I.eraseFromParent();
+      return true;
+    }
     if (TM.getCodeModel() == CodeModel::Large && !TM.isPositionIndependent()) {
       materializeLargeCMVal(I, I.getOperand(1).getBlockAddress(), 0);
       I.eraseFromParent();
diff --git a/llvm/test/CodeGen/AArch64/ptrauth-indirectbr.ll b/llvm/test/CodeGen/AArch64/ptrauth-indirectbr.ll
new file mode 100644
index 0000000000000..db49422cf1abb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ptrauth-indirectbr.ll
@@ -0,0 +1,106 @@
+; RUN: llc -mtriple arm64e-apple-darwin \
+; RUN:   -asm-verbose=false -aarch64-enable-collect-loh=false \
+; RUN:   -o - %s | FileCheck %s
+
+; RUN: llc -mtriple arm64e-apple-darwin \
+; RUN:   -global-isel -global-isel-abort=1 -verify-machineinstrs \
+; RUN:   -asm-verbose=false -aarch64-enable-collect-loh=false \
+; RUN:   -o - %s | FileCheck %s
+
+; The discriminator is the same for all blockaddresses in the function.
+; ptrauth_string_discriminator("test_blockaddress blockaddress") == 52152
+
+; CHECK-LABEL: _test_blockaddress:
+; CHECK:         adrp x16, [[F1BB1ADDR:Ltmp[0-9]+]]@PAGE
+; CHECK-NEXT:    add x16, x16, [[F1BB1ADDR]]@PAGEOFF
+; CHECK-NEXT:    mov x17, #[[F1DISCVAL:52152]]
+; CHECK-NEXT:    pacia x16, x17
+; CHECK-NEXT:    mov x0, x16
+; CHECK-NEXT:    adrp x16, [[F1BB2ADDR:Ltmp[0-9]+]]@PAGE
+; CHECK-NEXT:    add x16, x16, [[F1BB2ADDR]]@PAGEOFF
+; CHECK-NEXT:    mov x17, #[[F1DISCVAL]]
+; CHECK-NEXT:    pacia x16, x17
+; CHECK-NEXT:    mov x1, x16
+; CHECK-NEXT:    bl _dummy_choose
+; CHECK-NEXT:    mov x17, #[[F1DISCVAL]]
+; CHECK-NEXT:    braa x0, x17
+; CHECK:        [[F1BB1ADDR]]:
+; CHECK-NEXT:   [[F1BB1:LBB[0-9_]+]]:
+; CHECK-NEXT:    mov w0, #1
+; CHECK:        [[F1BB2ADDR]]:
+; CHECK-NEXT:   [[F1BB2:LBB[0-9_]+]]:
+; CHECK-NEXT:    mov w0, #2
+define i32 @test_blockaddress() #0 {
+entry:
+  %tmp0 = call ptr @dummy_choose(ptr blockaddress(@test_blockaddress, %bb1), ptr blockaddress(@test_blockaddress, %bb2))
+  indirectbr ptr %tmp0, [label %bb1, label %bb2]
+
+bb1:
+  ret i32 1
+
+bb2:
+  ret i32 2
+}
+
+; Test another function to compare the discriminator.
+; ptrauth_string_discriminator("test_blockaddress_2 blockaddress") == 22012
+
+; CHECK-LABEL: _test_blockaddress_2:
+; CHECK:         adrp x16, [[F2BB1ADDR:Ltmp[0-9]+]]@PAGE
+; CHECK-NEXT:    add x16, x16, [[F2BB1ADDR]]@PAGEOFF
+; CHECK-NEXT:    mov x17, #[[F2DISCVAL:22012]]
+; CHECK-NEXT:    pacia x16, x17
+; CHECK-NEXT:    mov x0, x16
+; CHECK-NEXT:    adrp x16, [[F2BB2ADDR:Ltmp[0-9]+]]@PAGE
+; CHECK-NEXT:    add x16, x16, [[F2BB2ADDR]]@PAGEOFF
+; CHECK-NEXT:    mov x17, #[[F2DISCVAL]]
+; CHECK-NEXT:    pacia x16, x17
+; CHECK-NEXT:    mov x1, x16
+; CHECK-NEXT:    bl _dummy_choose
+; CHECK-NEXT:    mov x17, #[[F2DISCVAL]]
+; CHECK-NEXT:    braa x0, x17
+; CHECK:        [[F2BB1ADDR]]:
+; CHECK-NEXT:   [[F2BB1:LBB[0-9_]+]]:
+; CHECK-NEXT:    mov w0, #1
+; CHECK:        [[F2BB2ADDR]]:
+; CHECK-NEXT:   [[F2BB2:LBB[0-9_]+]]:
+; CHECK-NEXT:    mov w0, #2
+define i32 @test_blockaddress_2() #0 {
+entry:
+  %tmp0 = call ptr @dummy_choose(ptr blockaddress(@test_blockaddress_2, %bb1), ptr blockaddress(@test_blockaddress_2, %bb2))
+  indirectbr ptr %tmp0, [label %bb1, label %bb2]
+
+bb1:
+  ret i32 1
+
+bb2:
+  ret i32 2
+}
+
+; CHECK-LABEL: _test_blockaddress_other_function:
+; CHECK:         adrp x16, [[F1BB1ADDR]]@PAGE
+; CHECK-NEXT:    add x16, x16, [[F1BB1ADDR]]@PAGEOFF
+; CHECK-NEXT:    mov x17, #[[F1DISCVAL]]
+; CHECK-NEXT:    pacia x16, x17
+; CHECK-NEXT:    mov x0, x16
+; CHECK-NEXT:    ret
+define ptr @test_blockaddress_other_function() #0 {
+  ret ptr blockaddress(@test_blockaddress, %bb1)
+}
+
+; CHECK-LABEL: .section __DATA,__const
+; CHECK-NEXT:  .globl _test_blockaddress_array
+; CHECK-NEXT:  .p2align 4
+; CHECK-NEXT:  _test_blockaddress_array:
+; CHECK-NEXT:   .quad [[F1BB1ADDR]]@AUTH(ia,[[F1DISCVAL]])
+; CHECK-NEXT:   .quad [[F1BB2ADDR]]@AUTH(ia,[[F1DISCVAL]])
+; CHECK-NEXT:   .quad [[F2BB1ADDR]]@AUTH(ia,[[F2DISCVAL]])
+; CHECK-NEXT:   .quad [[F2BB2ADDR]]@AUTH(ia,[[F2DISCVAL]])
+@test_blockaddress_array = constant [4 x ptr] [
+  ptr blockaddress(@test_blockaddress, %bb1), ptr blockaddress(@test_blockaddress, %bb2),
+  ptr blockaddress(@test_blockaddress_2, %bb1), ptr blockaddress(@test_blockaddress_2, %bb2)
+]
+
+declare ptr @dummy_choose(ptr, ptr)
+
+attributes #0 = { "ptrauth-indirect-gotos" nounwind }

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to