jdoerfert created this revision.
jdoerfert added reviewers: hfinkel, xbolva00, lebedev.ri, nikic, Tyker, 
rjmccall, spatel.
Herald added subscribers: bollu, hiraditya.
Herald added projects: clang, LLVM.

NOTE: This is a prototype, not a finished patch!
NOTE: There is a mailing list discussion on this: 
http://lists.llvm.org/pipermail/llvm-dev/2019-December/137632.html

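Complementary to the assumption outliner prototype in D71692 
<https://reviews.llvm.org/D71692>, this patch
shows how we could simplify the code emitted for an alignment
assumption: instead of a ptrtoint/and/icmp sequence feeding llvm.assume,
we emit `llvm.assume(i1 true)` carrying an "align"(pointer, alignment[, offset])
operand bundle. The generated code is smaller, less fragile, and it makes it
easier to recognize the additional use as an "assumption use".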

As mentioned in D71692 <https://reviews.llvm.org/D71692> and on the mailing 
list, we could adopt this
scheme, and similar schemes for other patterns, without adopting the
assumption outlining.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D71739

Files:
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/test/CodeGen/alloc-align-attr.c
  llvm/include/llvm/IR/IRBuilder.h
  llvm/lib/IR/IRBuilder.cpp
  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
  llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
  llvm/test/Transforms/AlignmentFromAssumptions/simple.ll

Index: llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
===================================================================
--- llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
+++ llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
@@ -6,10 +6,7 @@
 
 define i32 @foo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32)]
   %0 = load i32, i32* %a, align 4
   ret i32 %0
 
@@ -20,11 +17,7 @@
 
 define i32 @foo2(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %offsetptr = add i64 %ptrint, 24
-  %maskedptr = and i64 %offsetptr, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 24)]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 2
   %0 = load i32, i32* %arrayidx, align 4
   ret i32 %0
@@ -36,11 +29,7 @@
 
 define i32 @foo2a(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %offsetptr = add i64 %ptrint, 28
-  %maskedptr = and i64 %offsetptr, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 28)]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 -1
   %0 = load i32, i32* %arrayidx, align 4
   ret i32 %0
@@ -52,10 +41,7 @@
 
 define i32 @goo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 0)]
   %0 = load i32, i32* %a, align 4
   ret i32 %0
 
@@ -66,10 +52,7 @@
 
 define i32 @hoo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i64 32, i32 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -100,10 +83,7 @@
 ;         load(a, i0+i1+i2+32)
 define void @hoo2(i32* nocapture %a, i64 %id, i64 %num) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i64 0)]
   %id.mul = shl nsw i64 %id, 6
   %num.mul = shl nsw i64 %num, 6
   br label %for0.body
@@ -149,10 +129,7 @@
 
 define i32 @joo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i8 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -177,16 +154,13 @@
 
 define i32 @koo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
   %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
   %r.06 = phi i32 [ 0, %entry ], [ %add, %for.body ]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 %indvars.iv
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i8 0)]
   %0 = load i32, i32* %arrayidx, align 4
   %add = add nsw i32 %0, %r.06
   %indvars.iv.next = add i64 %indvars.iv, 4
@@ -205,10 +179,7 @@
 
 define i32 @koo2(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i128 32, i128 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -233,10 +204,7 @@
 
 define i32 @moo(i32* nocapture %a) nounwind uwtable {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i16 32)]
   %0 = bitcast i32* %a to i8*
   tail call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 64, i1 false)
   ret i32 undef
@@ -248,15 +216,9 @@
 
 define i32 @moo2(i32* nocapture %a, i32* nocapture %b) nounwind uwtable {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
-  %ptrint1 = ptrtoint i32* %b to i64
-  %maskedptr3 = and i64 %ptrint1, 127
-  %maskcond4 = icmp eq i64 %maskedptr3, 0
-  tail call void @llvm.assume(i1 %maskcond4)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %b, i32 128)]
   %0 = bitcast i32* %a to i8*
+  tail call void @llvm.assume(i1 true) ["align"(i8* %0, i16 32)]
   %1 = bitcast i32* %b to i8*
   tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 %0, i8* align 4 %1, i64 64, i1 false)
   ret i32 undef
Index: llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -204,102 +204,18 @@
   if (AlignOB.hasValue()) {
     assert(AlignOB.getValue().Inputs.size() >= 2);
     AAPtr = AlignOB.getValue().Inputs[0].get();
+    // TODO: Consider accumulating the offset to the base.
+    AAPtr = AAPtr->stripPointerCastsSameRepresentation();
     AlignSCEV = SE->getSCEV(AlignOB.getValue().Inputs[1].get());
+    AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
     if (AlignOB.getValue().Inputs.size() == 3)
       OffSCEV = SE->getSCEV(AlignOB.getValue().Inputs[2].get());
     else
       OffSCEV = SE->getZero(Int64Ty);
+    OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
     return true;
   }
-
-  // An alignment assume must be a statement about the least-significant
-  // bits of the pointer being zero, possibly with some offset.
-  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
-  if (!ICI)
-    return false;
-
-  // This must be an expression of the form: x & m == 0.
-  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
-    return false;
-
-  // Swap things around so that the RHS is 0.
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
-  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
-  if (CmpLHSSCEV->isZero())
-    std::swap(CmpLHS, CmpRHS);
-  else if (!CmpRHSSCEV->isZero())
-    return false;
-
-  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
-  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
-    return false;
-
-  // Swap things around so that the right operand of the and is a constant
-  // (the mask); we cannot deal with variable masks.
-  Value *AndLHS = CmpBO->getOperand(0);
-  Value *AndRHS = CmpBO->getOperand(1);
-  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
-  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
-  if (isa<SCEVConstant>(AndLHSSCEV)) {
-    std::swap(AndLHS, AndRHS);
-    std::swap(AndLHSSCEV, AndRHSSCEV);
-  }
-
-  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
-  if (!MaskSCEV)
-    return false;
-
-  // The mask must have some trailing ones (otherwise the condition is
-  // trivial and tells us nothing about the alignment of the left operand).
-  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
-  if (!TrailingOnes)
-    return false;
-
-  // Cap the alignment at the maximum with which LLVM can deal (and make sure
-  // we don't overflow the shift).
-  uint64_t Alignment;
-  TrailingOnes = std::min(TrailingOnes,
-    unsigned(sizeof(unsigned) * CHAR_BIT - 1));
-  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
-
-  AlignSCEV = SE->getConstant(Int64Ty, Alignment);
-
-  // The LHS might be a ptrtoint instruction, or it might be the pointer
-  // with an offset.
-  AAPtr = nullptr;
-  OffSCEV = nullptr;
-  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
-    AAPtr = PToI->getPointerOperand();
-    OffSCEV = SE->getZero(Int64Ty);
-  } else if (const SCEVAddExpr* AndLHSAddSCEV =
-             dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
-    // Try to find the ptrtoint; subtract it and the rest is the offset.
-    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
-                                  JE = AndLHSAddSCEV->op_end();
-         J != JE; ++J)
-      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
-        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
-          AAPtr = PToI->getPointerOperand();
-          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
-          break;
-        }
-  }
-
-  if (!AAPtr)
-    return false;
-
-  // Sign extend the offset to 64 bits (so that it is like all of the other
-  // expressions).
-  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
-  if (OffSCEVBits < 64)
-    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
-  else if (OffSCEVBits > 64)
-    return false;
-
-  AAPtr = AAPtr->stripPointerCasts();
-  return true;
+  return false;
 }
 
 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
@@ -323,7 +239,6 @@
       continue;
 
     if (Instruction *K = dyn_cast<Instruction>(J))
-      if (isValidAssumeForContext(ACall, K, DT))
         WorkList.push_back(K);
   }
 
@@ -331,24 +246,30 @@
     Instruction *J = WorkList.pop_back_val();
 
     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        LI->getPointerOperand(), SE);
+                                              LI->getPointerOperand(), SE);
 
       if (NewAlignment > LI->getAlignment()) {
         LI->setAlignment(MaybeAlign(NewAlignment));
         ++NumLoadAlignChanged;
       }
     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        SI->getPointerOperand(), SE);
+                                              SI->getPointerOperand(), SE);
 
       if (NewAlignment > SI->getAlignment()) {
         SI->setAlignment(MaybeAlign(NewAlignment));
         ++NumStoreAlignChanged;
       }
     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
-      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        MI->getDest(), SE);
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
+      unsigned NewDestAlignment =
+          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
 
       LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
       if (NewDestAlignment > MI->getDestAlignment()) {
@@ -359,8 +280,8 @@
       // For memory transfers, there is also a source alignment that
       // can be set.
       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
-        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-          MTI->getSource(), SE);
+        unsigned NewSrcAlignment =
+            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
 
         LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);
 
@@ -376,7 +297,7 @@
     Visited.insert(J);
     for (User *UJ : J->users()) {
       Instruction *K = cast<Instruction>(UJ);
-      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
+      if (!Visited.count(K))
         WorkList.push_back(K);
     }
   }
Index: llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3982,11 +3982,16 @@
     break;
   case Intrinsic::assume: {
     Value *IIOperand = II->getArgOperand(0);
+    SmallVector<OperandBundleDef, 4> OpBundles;
+    II->getOperandBundlesAsDefs(OpBundles);
+    bool HasOpBundles = !OpBundles.empty();
     // Remove an assume if it is followed by an identical assume.
     // TODO: Do we need this? Unless there are conflicting assumptions, the
     // computeKnownBits(IIOperand) below here eliminates redundant assumes.
     Instruction *Next = II->getNextNonDebugInstruction();
-    if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand))))
+    if (HasOpBundles &&
+        match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand))) &&
+        !cast<IntrinsicInst>(Next)->hasOperandBundles())
       return eraseInstFromFunction(CI);
 
     // Canonicalize assume(a && b) -> assume(a); assume(b);
@@ -3996,14 +4001,15 @@
     Value *AssumeIntrinsic = II->getCalledValue();
     Value *A, *B;
     if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) {
-      Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, II->getName());
+      Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles,
+                         II->getName());
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName());
       return eraseInstFromFunction(*II);
     }
     // assume(!(a || b)) -> assume(!a); assume(!b);
     if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) {
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
-                         Builder.CreateNot(A), II->getName());
+                         Builder.CreateNot(A), OpBundles, II->getName());
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
                          Builder.CreateNot(B), II->getName());
       return eraseInstFromFunction(*II);
@@ -4019,7 +4025,8 @@
         isValidAssumeForContext(II, LHS, &DT)) {
       MDNode *MD = MDNode::get(II->getContext(), None);
       LHS->setMetadata(LLVMContext::MD_nonnull, MD);
-      return eraseInstFromFunction(*II);
+      if (!HasOpBundles)
+        return eraseInstFromFunction(*II);
 
       // TODO: apply nonnull return attributes to calls and invokes
       // TODO: apply range metadata for range check patterns?
@@ -4027,10 +4034,12 @@
 
     // If there is a dominating assume with the same condition as this one,
     // then this one is redundant, and should be removed.
-    KnownBits Known(1);
-    computeKnownBits(IIOperand, Known, 0, II);
-    if (Known.isAllOnes())
-      return eraseInstFromFunction(*II);
+    if (!HasOpBundles) {
+      KnownBits Known(1);
+      computeKnownBits(IIOperand, Known, 0, II);
+      if (Known.isAllOnes())
+        return eraseInstFromFunction(*II);
+    }
 
     // Update the cache of affected values for this assumption (we might be
     // here because we just simplified the condition).
@@ -4048,7 +4057,7 @@
       II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
       return II;
     }
-    
+
     // Translate facts known about a pointer before relocating into
     // facts about the relocate value, while being careful to
     // preserve relocation semantics.
Index: llvm/lib/IR/IRBuilder.cpp
===================================================================
--- llvm/lib/IR/IRBuilder.cpp
+++ llvm/lib/IR/IRBuilder.cpp
@@ -71,11 +71,12 @@
   return BCI;
 }
 
-static CallInst *createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
-                                  IRBuilderBase *Builder,
-                                  const Twine &Name = "",
-                                  Instruction *FMFSource = nullptr) {
-  CallInst *CI = CallInst::Create(Callee, Ops, Name);
+static CallInst *
+createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
+                 IRBuilderBase *Builder, const Twine &Name = "",
+                 Instruction *FMFSource = nullptr,
+                 ArrayRef<OperandBundleDef> OpBundles = llvm::None) {
+  CallInst *CI = CallInst::Create(Callee, Ops, OpBundles, Name);
   if (FMFSource)
     CI->copyFastMathFlags(FMFSource);
   Builder->GetInsertBlock()->getInstList().insert(Builder->GetInsertPoint(),CI);
@@ -457,14 +458,16 @@
   return createCallHelper(TheFn, Ops, this);
 }
 
-CallInst *IRBuilderBase::CreateAssumption(Value *Cond) {
+CallInst *
+IRBuilderBase::CreateAssumption(Value *Cond,
+                                ArrayRef<OperandBundleDef> OpBundles) {
   assert(Cond->getType() == getInt1Ty() &&
          "an assumption condition must be of type i1");
 
   Value *Ops[] = { Cond };
   Module *M = BB->getParent()->getParent();
   Function *FnAssume = Intrinsic::getDeclaration(M, Intrinsic::assume);
-  return createCallHelper(FnAssume, Ops, this);
+  return createCallHelper(FnAssume, Ops, this, "", nullptr, OpBundles);
 }
 
 /// Create a call to a Masked Load intrinsic.
Index: llvm/include/llvm/IR/IRBuilder.h
===================================================================
--- llvm/include/llvm/IR/IRBuilder.h
+++ llvm/include/llvm/IR/IRBuilder.h
@@ -692,7 +692,11 @@
 
   /// Create an assume intrinsic call that allows the optimizer to
   /// assume that the provided condition will be true.
-  CallInst *CreateAssumption(Value *Cond);
+  ///
+  /// The optional argument \p OpBundles specifies operand bundles that are
+  /// added to the call instruction.
+  CallInst *CreateAssumption(Value *Cond,
+                             ArrayRef<OperandBundleDef> OpBundles = llvm::None);
 
   /// Create a call to the experimental.gc.statepoint intrinsic to
   /// start a new statepoint sequence.
@@ -2704,35 +2708,16 @@
 
 private:
   /// Helper function that creates an assume intrinsic call that
-  /// represents an alignment assumption on the provided Ptr, Mask, Type
-  /// and Offset. It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
+  /// represents an alignment assumption on the provided pointer \p PtrValue
+  /// with offset \p OffsetValue and alignment value \p AlignValue.
   CallInst *CreateAlignmentAssumptionHelper(const DataLayout &DL,
-                                            Value *PtrValue, Value *Mask,
-                                            Type *IntPtrTy, Value *OffsetValue,
-                                            Value **TheCheck) {
-    Value *PtrIntValue = CreatePtrToInt(PtrValue, IntPtrTy, "ptrint");
-
-    if (OffsetValue) {
-      bool IsOffsetZero = false;
-      if (const auto *CI = dyn_cast<ConstantInt>(OffsetValue))
-        IsOffsetZero = CI->isZero();
-
-      if (!IsOffsetZero) {
-        if (OffsetValue->getType() != IntPtrTy)
-          OffsetValue = CreateIntCast(OffsetValue, IntPtrTy, /*isSigned*/ true,
-                                      "offsetcast");
-        PtrIntValue = CreateSub(PtrIntValue, OffsetValue, "offsetptr");
-      }
-    }
-
-    Value *Zero = ConstantInt::get(IntPtrTy, 0);
-    Value *MaskedPtr = CreateAnd(PtrIntValue, Mask, "maskedptr");
-    Value *InvCond = CreateICmpEQ(MaskedPtr, Zero, "maskcond");
-    if (TheCheck)
-      *TheCheck = InvCond;
-
-    return CreateAssumption(InvCond);
+                                            Value *PtrValue, Value *AlignValue,
+                                            Value *OffsetValue) {
+    SmallVector<Value *, 2> Vals({PtrValue, AlignValue});
+    if (OffsetValue)
+      Vals.push_back(OffsetValue);
+    OperandBundleDefT<Value *> AlignOpB("align", Vals);
+    return CreateAssumption(ConstantInt::getTrue(getContext()), {AlignOpB});
   }
 
 public:
@@ -2742,22 +2727,17 @@
   /// An optional offset can be provided, and if it is provided, the offset
   /// must be subtracted from the provided pointer to get the pointer with the
   /// specified alignment.
-  ///
-  /// It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
   CallInst *CreateAlignmentAssumption(const DataLayout &DL, Value *PtrValue,
                                       unsigned Alignment,
-                                      Value *OffsetValue = nullptr,
-                                      Value **TheCheck = nullptr) {
+                                      Value *OffsetValue = nullptr) {
     assert(isa<PointerType>(PtrValue->getType()) &&
            "trying to create an alignment assumption on a non-pointer?");
     assert(Alignment != 0 && "Invalid Alignment");
     auto *PtrTy = cast<PointerType>(PtrValue->getType());
     Type *IntPtrTy = getIntPtrTy(DL, PtrTy->getAddressSpace());
-
-    Value *Mask = ConstantInt::get(IntPtrTy, Alignment - 1);
-    return CreateAlignmentAssumptionHelper(DL, PtrValue, Mask, IntPtrTy,
-                                           OffsetValue, TheCheck);
+    Value *AlignValue = ConstantInt::get(IntPtrTy, Alignment);
+    return CreateAlignmentAssumptionHelper(DL, PtrValue, AlignValue,
+                                           OffsetValue);
   }
 
   /// Create an assume intrinsic call that represents an alignment
@@ -2767,15 +2747,11 @@
   /// must be subtracted from the provided pointer to get the pointer with the
   /// specified alignment.
   ///
-  /// It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
-  ///
   /// This overload handles the condition where the Alignment is dependent
   /// on an existing value rather than a static value.
   CallInst *CreateAlignmentAssumption(const DataLayout &DL, Value *PtrValue,
                                       Value *Alignment,
-                                      Value *OffsetValue = nullptr,
-                                      Value **TheCheck = nullptr) {
+                                      Value *OffsetValue = nullptr) {
     assert(isa<PointerType>(PtrValue->getType()) &&
            "trying to create an alignment assumption on a non-pointer?");
     auto *PtrTy = cast<PointerType>(PtrValue->getType());
@@ -2785,10 +2761,8 @@
       Alignment = CreateIntCast(Alignment, IntPtrTy, /*isSigned*/ false,
                                 "alignmentcast");
 
-    Value *Mask = CreateSub(Alignment, ConstantInt::get(IntPtrTy, 1), "mask");
-
-    return CreateAlignmentAssumptionHelper(DL, PtrValue, Mask, IntPtrTy,
-                                           OffsetValue, TheCheck);
+    return CreateAlignmentAssumptionHelper(DL, PtrValue, Alignment,
+                                           OffsetValue);
   }
 };
 
Index: clang/test/CodeGen/alloc-align-attr.c
===================================================================
--- clang/test/CodeGen/alloc-align-attr.c
+++ clang/test/CodeGen/alloc-align-attr.c
@@ -6,26 +6,18 @@
 __INT32_TYPE__ test1(__INT32_TYPE__ a) {
 // CHECK: define i32 @test1
   return *m1(a);
-// CHECK: call i32* @m1(i32 [[PARAM1:%[^\)]+]])
+// CHECK: [[CALL1:%.+]] = call i32* @m1(i32 [[PARAM1:%[^\)]+]])
 // CHECK: [[ALIGNCAST1:%.+]] = zext i32 [[PARAM1]] to i64
-// CHECK: [[MASK1:%.+]] = sub i64 [[ALIGNCAST1]], 1
-// CHECK: [[PTRINT1:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR1:%.+]] = and i64 [[PTRINT1]], [[MASK1]]
-// CHECK: [[MASKCOND1:%.+]] = icmp eq i64 [[MASKEDPTR1]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND1]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL1]], i64 [[ALIGNCAST1]]) ]
 }
 // Condition where test2 param needs casting.
 __INT32_TYPE__ test2(__SIZE_TYPE__ a) {
 // CHECK: define i32 @test2
   return *m1(a);
 // CHECK: [[CONV2:%.+]] = trunc i64 %{{.+}} to i32
-// CHECK: call i32* @m1(i32 [[CONV2]])
+// CHECK: [[CALL2:%.+]] = call i32* @m1(i32 [[CONV2]])
 // CHECK: [[ALIGNCAST2:%.+]] = zext i32 [[CONV2]] to i64
-// CHECK: [[MASK2:%.+]] = sub i64 [[ALIGNCAST2]], 1
-// CHECK: [[PTRINT2:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR2:%.+]] = and i64 [[PTRINT2]], [[MASK2]]
-// CHECK: [[MASKCOND2:%.+]] = icmp eq i64 [[MASKEDPTR2]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND2]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL2]], i64 [[ALIGNCAST2]]) ]
 }
 __INT32_TYPE__ *m2(__SIZE_TYPE__ i) __attribute__((alloc_align(1)));
 
@@ -34,24 +26,16 @@
 // CHECK: define i32 @test3
   return *m2(a);
 // CHECK: [[CONV3:%.+]] = sext i32 %{{.+}} to i64
-// CHECK: call i32* @m2(i64 [[CONV3]])
-// CHECK: [[MASK3:%.+]] = sub i64 [[CONV3]], 1
-// CHECK: [[PTRINT3:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR3:%.+]] = and i64 [[PTRINT3]], [[MASK3]]
-// CHECK: [[MASKCOND3:%.+]] = icmp eq i64 [[MASKEDPTR3]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND3]])
+// CHECK: [[CALL3:%.+]] = call i32* @m2(i64 [[CONV3]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL3]], i64 [[CONV3]]) ]
 }
 
 // Every type matches, canonical example.
 __INT32_TYPE__ test4(__SIZE_TYPE__ a) {
 // CHECK: define i32 @test4
   return *m2(a);
-// CHECK: call i32* @m2(i64 [[PARAM4:%[^\)]+]])
-// CHECK: [[MASK4:%.+]] = sub i64 [[PARAM4]], 1
-// CHECK: [[PTRINT4:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR4:%.+]] = and i64 [[PTRINT4]], [[MASK4]]
-// CHECK: [[MASKCOND4:%.+]] = icmp eq i64 [[MASKEDPTR4]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND4]])
+// CHECK: [[CALL4:%.+]] = call i32* @m2(i64 [[PARAM4:%[^\)]+]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL4]], i64 [[PARAM4]]) ]
 }
 
 
@@ -64,13 +48,9 @@
 // CHECK: define i32 @test5
   struct Empty e;
   return *m3(e, a);
-// CHECK: call i32* @m3(i64 %{{.*}}, i64 %{{.*}})
+// CHECK: [[CALL5:%.+]] = call i32* @m3(i64 %{{.*}}, i64 %{{.*}})
 // CHECK: [[ALIGNCAST5:%.+]] = trunc i128 %{{.*}} to i64
-// CHECK: [[MASK5:%.+]] = sub i64 [[ALIGNCAST5]], 1
-// CHECK: [[PTRINT5:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR5:%.+]] = and i64 [[PTRINT5]], [[MASK5]]
-// CHECK: [[MASKCOND5:%.+]] = icmp eq i64 [[MASKEDPTR5]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND5]])
+// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL5]], i64 [[ALIGNCAST5]]) ]
 }
 // Struct parameter takes up 2 parameters, 'i' takes up 2.
 __INT32_TYPE__ *m4(struct MultiArgs s, __int128_t i) __attribute__((alloc_align(2)));
@@ -78,12 +58,8 @@
 // CHECK: define i32 @test6
   struct MultiArgs e;
   return *m4(e, a);
-// CHECK: call i32* @m4(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
-// CHECK: [[ALIGNCAST6:%.+]] = trunc i128 %{{.*}} to i64
-// CHECK: [[MASK6:%.+]] = sub i64 [[ALIGNCAST6]], 1
-// CHECK: [[PTRINT6:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR6:%.+]] = and i64 [[PTRINT6]], [[MASK6]]
-// CHECK: [[MASKCOND6:%.+]] = icmp eq i64 [[MASKEDPTR6]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND6]])
+// CHECK: [[CALL6:%.+]] = call i32* @m4(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+// CHECK-NEXT: [[ALIGNCAST6:%.+]] = trunc i128 %{{.*}} to i64
+// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL6]], i64 [[ALIGNCAST6]]) ]
 }
 
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2126,13 +2126,35 @@
                                               SourceLocation AssumptionLoc,
                                               llvm::Value *Alignment,
                                               llvm::Value *OffsetValue) {
-  llvm::Value *TheCheck;
   llvm::Instruction *Assumption = Builder.CreateAlignmentAssumption(
-      CGM.getDataLayout(), PtrValue, Alignment, OffsetValue, &TheCheck);
-  if (SanOpts.has(SanitizerKind::Alignment)) {
-    EmitAlignmentAssumptionCheck(PtrValue, Ty, Loc, AssumptionLoc, Alignment,
-                                 OffsetValue, TheCheck, Assumption);
+      CGM.getDataLayout(), PtrValue, Alignment, OffsetValue);
+  if (!SanOpts.has(SanitizerKind::Alignment))
+    return;
+
+  llvm::Value *PtrIntValue =
+      Builder.CreatePtrToInt(PtrValue, IntPtrTy, "ptrint");
+
+  if (OffsetValue) {
+    bool IsOffsetZero = false;
+    if (const auto *CI = dyn_cast<llvm::ConstantInt>(OffsetValue))
+      IsOffsetZero = CI->isZero();
+
+    if (!IsOffsetZero) {
+      if (OffsetValue->getType() != IntPtrTy)
+        OffsetValue = Builder.CreateIntCast(OffsetValue, IntPtrTy,
+                                            /*isSigned*/ true, "offsetcast");
+      PtrIntValue = Builder.CreateSub(PtrIntValue, OffsetValue, "offsetptr");
+    }
   }
+
+  llvm::Value *Zero = llvm::ConstantInt::get(IntPtrTy, 0);
+  llvm::Value *Mask =
+      Builder.CreateSub(Alignment, llvm::ConstantInt::get(IntPtrTy, 1));
+  llvm::Value *MaskedPtr = Builder.CreateAnd(PtrIntValue, Mask, "maskedptr");
+  llvm::Value *TheCheck = Builder.CreateICmpEQ(MaskedPtr, Zero, "maskcond");
+
+  EmitAlignmentAssumptionCheck(PtrValue, Ty, Loc, AssumptionLoc, Alignment,
+                               OffsetValue, TheCheck, Assumption);
 }
 
 void CodeGenFunction::EmitAlignmentAssumption(llvm::Value *PtrValue,
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D71739: [WI... Johannes Doerfert via Phabricator via cfe-commits

Reply via email to