jdoerfert updated this revision to Diff 234817.
jdoerfert added a comment.

Rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D71739/new/

https://reviews.llvm.org/D71739

Files:
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/test/CodeGen/alloc-align-attr.c
  llvm/include/llvm/IR/IRBuilder.h
  llvm/lib/IR/IRBuilder.cpp
  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
  llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
  llvm/test/Transforms/AlignmentFromAssumptions/simple.ll

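The gist of the change, in IR terms (taken directly from the test updates below): an alignment assumption previously spelled as an explicit mask-and-compare on the pointer bits

  %ptrint = ptrtoint i32* %a to i64
  %maskedptr = and i64 %ptrint, 31
  %maskcond = icmp eq i64 %maskedptr, 0
  tail call void @llvm.assume(i1 %maskcond)

is now carried as an "align" operand bundle on llvm.assume, holding the pointer, the alignment, and an optional offset:

  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32)]
  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 24)]
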
Index: llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
===================================================================
--- llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
+++ llvm/test/Transforms/AlignmentFromAssumptions/simple.ll
@@ -4,10 +4,7 @@
 
 define i32 @foo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32)]
   %0 = load i32, i32* %a, align 4
   ret i32 %0
 
@@ -18,11 +15,7 @@
 
 define i32 @foo2(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %offsetptr = add i64 %ptrint, 24
-  %maskedptr = and i64 %offsetptr, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 24)]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 2
   %0 = load i32, i32* %arrayidx, align 4
   ret i32 %0
@@ -34,11 +27,7 @@
 
 define i32 @foo2a(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %offsetptr = add i64 %ptrint, 28
-  %maskedptr = and i64 %offsetptr, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 28)]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 -1
   %0 = load i32, i32* %arrayidx, align 4
   ret i32 %0
@@ -50,10 +39,7 @@
 
 define i32 @goo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 0)]
   %0 = load i32, i32* %a, align 4
   ret i32 %0
 
@@ -64,10 +50,7 @@
 
 define i32 @hoo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i64 32, i32 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -98,10 +81,7 @@
 ;         load(a, i0+i1+i2+32)
 define void @hoo2(i32* nocapture %a, i64 %id, i64 %num) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i64 0)]
   %id.mul = shl nsw i64 %id, 6
   %num.mul = shl nsw i64 %num, 6
   br label %for0.body
@@ -147,10 +127,7 @@
 
 define i32 @joo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i8 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -175,16 +152,13 @@
 
 define i32 @koo(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
   %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
   %r.06 = phi i32 [ 0, %entry ], [ %add, %for.body ]
   %arrayidx = getelementptr inbounds i32, i32* %a, i64 %indvars.iv
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i8 0)]
   %0 = load i32, i32* %arrayidx, align 4
   %add = add nsw i32 %0, %r.06
   %indvars.iv.next = add i64 %indvars.iv, 4
@@ -203,10 +177,7 @@
 
 define i32 @koo2(i32* nocapture %a) nounwind uwtable readonly {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i128 32, i128 0)]
   br label %for.body
 
 for.body:                                         ; preds = %entry, %for.body
@@ -231,10 +202,7 @@
 
 define i32 @moo(i32* nocapture %a) nounwind uwtable {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i16 32)]
   %0 = bitcast i32* %a to i8*
   tail call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 64, i1 false)
   ret i32 undef
@@ -246,15 +214,9 @@
 
 define i32 @moo2(i32* nocapture %a, i32* nocapture %b) nounwind uwtable {
 entry:
-  %ptrint = ptrtoint i32* %a to i64
-  %maskedptr = and i64 %ptrint, 31
-  %maskcond = icmp eq i64 %maskedptr, 0
-  tail call void @llvm.assume(i1 %maskcond)
-  %ptrint1 = ptrtoint i32* %b to i64
-  %maskedptr3 = and i64 %ptrint1, 127
-  %maskcond4 = icmp eq i64 %maskedptr3, 0
-  tail call void @llvm.assume(i1 %maskcond4)
+  tail call void @llvm.assume(i1 true) ["align"(i32* %b, i32 128)]
   %0 = bitcast i32* %a to i8*
+  tail call void @llvm.assume(i1 true) ["align"(i8* %0, i16 32)]
   %1 = bitcast i32* %b to i8*
   tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 %0, i8* align 4 %1, i64 64, i1 false)
   ret i32 undef
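
Note that the updated tests deliberately vary the integer width of the bundle operands (i8 up to i128 for the same alignment of 32). The extraction code in AlignmentFromAssumptions below zero-extends or truncates both values to i64, so all of the following are assumed to encode the same fact:

  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i32 32, i32 0)]
  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i8 32, i8 0)]
  tail call void @llvm.assume(i1 true) ["align"(i32* %a, i128 32, i128 0)]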
Index: llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -199,94 +199,23 @@
                                                         Value *&AAPtr,
                                                         const SCEV *&AlignSCEV,
                                                         const SCEV *&OffSCEV) {
-  // An alignment assume must be a statement about the least-significant
-  // bits of the pointer being zero, possibly with some offset.
-  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
-  if (!ICI)
-    return false;
-
-  // This must be an expression of the form: x & m == 0.
-  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
-    return false;
-
-  // Swap things around so that the RHS is 0.
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
-  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
-  if (CmpLHSSCEV->isZero())
-    std::swap(CmpLHS, CmpRHS);
-  else if (!CmpRHSSCEV->isZero())
-    return false;
-
-  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
-  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
-    return false;
-
-  // Swap things around so that the right operand of the and is a constant
-  // (the mask); we cannot deal with variable masks.
-  Value *AndLHS = CmpBO->getOperand(0);
-  Value *AndRHS = CmpBO->getOperand(1);
-  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
-  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
-  if (isa<SCEVConstant>(AndLHSSCEV)) {
-    std::swap(AndLHS, AndRHS);
-    std::swap(AndLHSSCEV, AndRHSSCEV);
+  Type *Int64Ty = Type::getInt64Ty(I->getContext());
+  Optional<OperandBundleUse> AlignOB = I->getOperandBundle("align");
+  if (AlignOB.hasValue()) {
+    assert(AlignOB.getValue().Inputs.size() >= 2);
+    AAPtr = AlignOB.getValue().Inputs[0].get();
+    // TODO: Consider accumulating the offset to the base.
+    AAPtr = AAPtr->stripPointerCastsSameRepresentation();
+    AlignSCEV = SE->getSCEV(AlignOB.getValue().Inputs[1].get());
+    AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
+    if (AlignOB.getValue().Inputs.size() == 3)
+      OffSCEV = SE->getSCEV(AlignOB.getValue().Inputs[2].get());
+    else
+      OffSCEV = SE->getZero(Int64Ty);
+    OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
+    return true;
   }
-
-  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
-  if (!MaskSCEV)
-    return false;
-
-  // The mask must have some trailing ones (otherwise the condition is
-  // trivial and tells us nothing about the alignment of the left operand).
-  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
-  if (!TrailingOnes)
-    return false;
-
-  // Cap the alignment at the maximum with which LLVM can deal (and make sure
-  // we don't overflow the shift).
-  uint64_t Alignment;
-  TrailingOnes = std::min(TrailingOnes,
-    unsigned(sizeof(unsigned) * CHAR_BIT - 1));
-  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
-
-  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
-  AlignSCEV = SE->getConstant(Int64Ty, Alignment);
-
-  // The LHS might be a ptrtoint instruction, or it might be the pointer
-  // with an offset.
-  AAPtr = nullptr;
-  OffSCEV = nullptr;
-  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
-    AAPtr = PToI->getPointerOperand();
-    OffSCEV = SE->getZero(Int64Ty);
-  } else if (const SCEVAddExpr* AndLHSAddSCEV =
-             dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
-    // Try to find the ptrtoint; subtract it and the rest is the offset.
-    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
-         JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
-      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
-        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
-          AAPtr = PToI->getPointerOperand();
-          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
-          break;
-        }
-  }
-
-  if (!AAPtr)
-    return false;
-
-  // Sign extend the offset to 64 bits (so that it is like all of the other
-  // expressions).
-  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
-  if (OffSCEVBits < 64)
-    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
-  else if (OffSCEVBits > 64)
-    return false;
-
-  AAPtr = AAPtr->stripPointerCasts();
-  return true;
+  return false;
 }
 
 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
@@ -310,7 +239,6 @@
       continue;
 
     if (Instruction *K = dyn_cast<Instruction>(J))
-      if (isValidAssumeForContext(ACall, K, DT))
-        WorkList.push_back(K);
+      WorkList.push_back(K);
   }
 
@@ -318,24 +246,30 @@
     Instruction *J = WorkList.pop_back_val();
 
     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        LI->getPointerOperand(), SE);
+                                              LI->getPointerOperand(), SE);
 
       if (NewAlignment > LI->getAlignment()) {
         LI->setAlignment(MaybeAlign(NewAlignment));
         ++NumLoadAlignChanged;
       }
     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        SI->getPointerOperand(), SE);
+                                              SI->getPointerOperand(), SE);
 
       if (NewAlignment > SI->getAlignment()) {
         SI->setAlignment(MaybeAlign(NewAlignment));
         ++NumStoreAlignChanged;
       }
     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
-      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        MI->getDest(), SE);
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
+      unsigned NewDestAlignment =
+          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
 
       LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
       if (NewDestAlignment > MI->getDestAlignment()) {
@@ -346,8 +280,8 @@
       // For memory transfers, there is also a source alignment that
       // can be set.
       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
-        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-          MTI->getSource(), SE);
+        unsigned NewSrcAlignment =
+            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
 
         LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);
 
@@ -363,7 +297,7 @@
     Visited.insert(J);
     for (User *UJ : J->users()) {
       Instruction *K = cast<Instruction>(UJ);
-      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
+      if (!Visited.count(K))
         WorkList.push_back(K);
     }
   }
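
For reference, the query side is now a plain operand-bundle lookup. A minimal sketch of the pattern extractAlignmentInfo relies on (the helper name getAlignFromAssume is hypothetical, not part of this patch; usual LLVM includes assumed):

  // Pull the pointer, alignment, and optional offset out of an "align" bundle.
  static bool getAlignFromAssume(CallInst *Assume, Value *&Ptr,
                                 Value *&AlignVal, Value *&OffsetVal) {
    Optional<OperandBundleUse> OB = Assume->getOperandBundle("align");
    if (!OB)
      return false;
    Ptr = OB->Inputs[0].get();       // the pointer the fact applies to
    AlignVal = OB->Inputs[1].get();  // alignment value, any integer width
    OffsetVal = OB->Inputs.size() == 3 ? OB->Inputs[2].get() : nullptr;
    return true;
  }
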
Index: llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3982,11 +3982,15 @@
     break;
   case Intrinsic::assume: {
     Value *IIOperand = II->getArgOperand(0);
+    SmallVector<OperandBundleDef, 4> OpBundles;
+    II->getOperandBundlesAsDefs(OpBundles);
+    bool HasOpBundles = !OpBundles.empty();
     // Remove an assume if it is followed by an identical assume.
     // TODO: Do we need this? Unless there are conflicting assumptions, the
     // computeKnownBits(IIOperand) below here eliminates redundant assumes.
     Instruction *Next = II->getNextNonDebugInstruction();
-    if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand))))
+    if (!HasOpBundles &&
+        match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand))))
       return eraseInstFromFunction(CI);
 
     // Canonicalize assume(a && b) -> assume(a); assume(b);
@@ -3996,14 +4001,15 @@
     Value *AssumeIntrinsic = II->getCalledValue();
     Value *A, *B;
     if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) {
-      Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, II->getName());
+      Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles,
+                         II->getName());
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName());
       return eraseInstFromFunction(*II);
     }
     // assume(!(a || b)) -> assume(!a); assume(!b);
     if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) {
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
-                         Builder.CreateNot(A), II->getName());
+                         Builder.CreateNot(A), OpBundles, II->getName());
       Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
                          Builder.CreateNot(B), II->getName());
       return eraseInstFromFunction(*II);
@@ -4019,7 +4025,8 @@
         isValidAssumeForContext(II, LHS, &DT)) {
       MDNode *MD = MDNode::get(II->getContext(), None);
       LHS->setMetadata(LLVMContext::MD_nonnull, MD);
-      return eraseInstFromFunction(*II);
+      if (!HasOpBundles)
+        return eraseInstFromFunction(*II);
 
       // TODO: apply nonnull return attributes to calls and invokes
       // TODO: apply range metadata for range check patterns?
@@ -4027,10 +4034,12 @@
 
     // If there is a dominating assume with the same condition as this one,
     // then this one is redundant, and should be removed.
-    KnownBits Known(1);
-    computeKnownBits(IIOperand, Known, 0, II);
-    if (Known.isAllOnes())
-      return eraseInstFromFunction(*II);
+    if (!HasOpBundles) {
+      KnownBits Known(1);
+      computeKnownBits(IIOperand, Known, 0, II);
+      if (Known.isAllOnes())
+        return eraseInstFromFunction(*II);
+    }
 
     // Update the cache of affected values for this assumption (we might be
     // here because we just simplified the condition).
@@ -4048,7 +4057,7 @@
       II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
       return II;
     }
-    
+
     // Translate facts known about a pointer before relocating into
     // facts about the relocate value, while being careful to
     // preserve relocation semantics.
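
The net effect of the assume(a && b) canonicalization above, sketched in IR: operand bundles from the original call are kept on the first of the two split assumes (the %p/16 bundle here is illustrative only):

  %c = and i1 %a, %b
  call void @llvm.assume(i1 %c) ["align"(i32* %p, i32 16)]

becomes

  call void @llvm.assume(i1 %a) ["align"(i32* %p, i32 16)]
  call void @llvm.assume(i1 %b)
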
Index: llvm/lib/IR/IRBuilder.cpp
===================================================================
--- llvm/lib/IR/IRBuilder.cpp
+++ llvm/lib/IR/IRBuilder.cpp
@@ -71,11 +71,12 @@
   return BCI;
 }
 
-static CallInst *createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
-                                  IRBuilderBase *Builder,
-                                  const Twine &Name = "",
-                                  Instruction *FMFSource = nullptr) {
-  CallInst *CI = CallInst::Create(Callee, Ops, Name);
+static CallInst *
+createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
+                 IRBuilderBase *Builder, const Twine &Name = "",
+                 Instruction *FMFSource = nullptr,
+                 ArrayRef<OperandBundleDef> OpBundles = llvm::None) {
+  CallInst *CI = CallInst::Create(Callee, Ops, OpBundles, Name);
   if (FMFSource)
     CI->copyFastMathFlags(FMFSource);
   Builder->GetInsertBlock()->getInstList().insert(Builder->GetInsertPoint(),CI);
@@ -457,14 +458,16 @@
   return createCallHelper(TheFn, Ops, this);
 }
 
-CallInst *IRBuilderBase::CreateAssumption(Value *Cond) {
+CallInst *
+IRBuilderBase::CreateAssumption(Value *Cond,
+                                ArrayRef<OperandBundleDef> OpBundles) {
   assert(Cond->getType() == getInt1Ty() &&
          "an assumption condition must be of type i1");
 
   Value *Ops[] = { Cond };
   Module *M = BB->getParent()->getParent();
   Function *FnAssume = Intrinsic::getDeclaration(M, Intrinsic::assume);
-  return createCallHelper(FnAssume, Ops, this);
+  return createCallHelper(FnAssume, Ops, this, "", nullptr, OpBundles);
 }
 
 /// Create a call to a Masked Load intrinsic.
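
Direct usage of the new overload, for illustration (a sketch; Builder of type IRBuilder<> and the i32* value Ptr are assumed to be in scope):

  // Attach an "align" bundle when emitting llvm.assume by hand.
  SmallVector<Value *, 3> Vals = {Ptr, Builder.getInt64(32),
                                  Builder.getInt64(0)};
  OperandBundleDefT<Value *> AlignOpB("align", Vals);
  Builder.CreateAssumption(Builder.getTrue(), {AlignOpB});
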
Index: llvm/include/llvm/IR/IRBuilder.h
===================================================================
--- llvm/include/llvm/IR/IRBuilder.h
+++ llvm/include/llvm/IR/IRBuilder.h
@@ -692,7 +692,11 @@
 
   /// Create an assume intrinsic call that allows the optimizer to
   /// assume that the provided condition will be true.
-  CallInst *CreateAssumption(Value *Cond);
+  ///
+  /// The optional argument \p OpBundles specifies operand bundles that are
+  /// added to the call instruction.
+  CallInst *CreateAssumption(Value *Cond,
+                             ArrayRef<OperandBundleDef> OpBundles = llvm::None);
 
   /// Create a call to the experimental.gc.statepoint intrinsic to
   /// start a new statepoint sequence.
@@ -2704,35 +2708,16 @@
 
 private:
   /// Helper function that creates an assume intrinsic call that
-  /// represents an alignment assumption on the provided Ptr, Mask, Type
-  /// and Offset. It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
+  /// represents an alignment assumption on the provided pointer \p PtrValue
+  /// with offset \p OffsetValue and alignment value \p AlignValue.
   CallInst *CreateAlignmentAssumptionHelper(const DataLayout &DL,
-                                            Value *PtrValue, Value *Mask,
-                                            Type *IntPtrTy, Value *OffsetValue,
-                                            Value **TheCheck) {
-    Value *PtrIntValue = CreatePtrToInt(PtrValue, IntPtrTy, "ptrint");
-
-    if (OffsetValue) {
-      bool IsOffsetZero = false;
-      if (const auto *CI = dyn_cast<ConstantInt>(OffsetValue))
-        IsOffsetZero = CI->isZero();
-
-      if (!IsOffsetZero) {
-        if (OffsetValue->getType() != IntPtrTy)
-          OffsetValue = CreateIntCast(OffsetValue, IntPtrTy, /*isSigned*/ true,
-                                      "offsetcast");
-        PtrIntValue = CreateSub(PtrIntValue, OffsetValue, "offsetptr");
-      }
-    }
-
-    Value *Zero = ConstantInt::get(IntPtrTy, 0);
-    Value *MaskedPtr = CreateAnd(PtrIntValue, Mask, "maskedptr");
-    Value *InvCond = CreateICmpEQ(MaskedPtr, Zero, "maskcond");
-    if (TheCheck)
-      *TheCheck = InvCond;
-
-    return CreateAssumption(InvCond);
+                                            Value *PtrValue, Value *AlignValue,
+                                            Value *OffsetValue) {
+    SmallVector<Value *, 2> Vals({PtrValue, AlignValue});
+    if (OffsetValue)
+      Vals.push_back(OffsetValue);
+    OperandBundleDefT<Value *> AlignOpB("align", Vals);
+    return CreateAssumption(ConstantInt::getTrue(getContext()), {AlignOpB});
   }
 
 public:
@@ -2742,22 +2727,17 @@
   /// An optional offset can be provided, and if it is provided, the offset
   /// must be subtracted from the provided pointer to get the pointer with the
   /// specified alignment.
-  ///
-  /// It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
   CallInst *CreateAlignmentAssumption(const DataLayout &DL, Value *PtrValue,
                                       unsigned Alignment,
-                                      Value *OffsetValue = nullptr,
-                                      Value **TheCheck = nullptr) {
+                                      Value *OffsetValue = nullptr) {
     assert(isa<PointerType>(PtrValue->getType()) &&
            "trying to create an alignment assumption on a non-pointer?");
     assert(Alignment != 0 && "Invalid Alignment");
     auto *PtrTy = cast<PointerType>(PtrValue->getType());
     Type *IntPtrTy = getIntPtrTy(DL, PtrTy->getAddressSpace());
-
-    Value *Mask = ConstantInt::get(IntPtrTy, Alignment - 1);
-    return CreateAlignmentAssumptionHelper(DL, PtrValue, Mask, IntPtrTy,
-                                           OffsetValue, TheCheck);
+    Value *AlignValue = ConstantInt::get(IntPtrTy, Alignment);
+    return CreateAlignmentAssumptionHelper(DL, PtrValue, AlignValue,
+                                           OffsetValue);
   }
 
   /// Create an assume intrinsic call that represents an alignment
@@ -2767,15 +2747,11 @@
   /// must be subtracted from the provided pointer to get the pointer with the
   /// specified alignment.
   ///
-  /// It may be sometimes useful to do some other logic
-  /// based on this alignment check, thus it can be stored into 'TheCheck'.
-  ///
   /// This overload handles the condition where the Alignment is dependent
   /// on an existing value rather than a static value.
   CallInst *CreateAlignmentAssumption(const DataLayout &DL, Value *PtrValue,
                                       Value *Alignment,
-                                      Value *OffsetValue = nullptr,
-                                      Value **TheCheck = nullptr) {
+                                      Value *OffsetValue = nullptr) {
     assert(isa<PointerType>(PtrValue->getType()) &&
            "trying to create an alignment assumption on a non-pointer?");
     auto *PtrTy = cast<PointerType>(PtrValue->getType());
@@ -2785,10 +2761,8 @@
       Alignment = CreateIntCast(Alignment, IntPtrTy, /*isSigned*/ false,
                                 "alignmentcast");
 
-    Value *Mask = CreateSub(Alignment, ConstantInt::get(IntPtrTy, 1), "mask");
-
-    return CreateAlignmentAssumptionHelper(DL, PtrValue, Mask, IntPtrTy,
-                                           OffsetValue, TheCheck);
+    return CreateAlignmentAssumptionHelper(DL, PtrValue, Alignment,
+                                           OffsetValue);
   }
 };
 
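For most clients nothing changes at the call site; the public entry points keep their signatures modulo the removed TheCheck out-parameter. A usage sketch (DL, Builder, and Ptr assumed in scope, 64-bit target assumed for the i64):

  // Now emits: call void @llvm.assume(i1 true) [ "align"(<ptr>, i64 32) ]
  Builder.CreateAlignmentAssumption(DL, Ptr, /*Alignment=*/32);
  // With an offset operand in the bundle:
  Builder.CreateAlignmentAssumption(DL, Ptr, /*Alignment=*/32,
                                    /*OffsetValue=*/Builder.getInt64(24));
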
Index: clang/test/CodeGen/alloc-align-attr.c
===================================================================
--- clang/test/CodeGen/alloc-align-attr.c
+++ clang/test/CodeGen/alloc-align-attr.c
@@ -6,26 +6,18 @@
 __INT32_TYPE__ test1(__INT32_TYPE__ a) {
 // CHECK: define i32 @test1
   return *m1(a);
-// CHECK: call i32* @m1(i32 [[PARAM1:%[^\)]+]])
+// CHECK: [[CALL1:%.+]] = call i32* @m1(i32 [[PARAM1:%[^\)]+]])
 // CHECK: [[ALIGNCAST1:%.+]] = zext i32 [[PARAM1]] to i64
-// CHECK: [[MASK1:%.+]] = sub i64 [[ALIGNCAST1]], 1
-// CHECK: [[PTRINT1:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR1:%.+]] = and i64 [[PTRINT1]], [[MASK1]]
-// CHECK: [[MASKCOND1:%.+]] = icmp eq i64 [[MASKEDPTR1]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND1]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL1]], i64 [[ALIGNCAST1]]) ]
 }
 // Condition where test2 param needs casting.
 __INT32_TYPE__ test2(__SIZE_TYPE__ a) {
 // CHECK: define i32 @test2
   return *m1(a);
 // CHECK: [[CONV2:%.+]] = trunc i64 %{{.+}} to i32
-// CHECK: call i32* @m1(i32 [[CONV2]])
+// CHECK: [[CALL2:%.+]] = call i32* @m1(i32 [[CONV2]])
 // CHECK: [[ALIGNCAST2:%.+]] = zext i32 [[CONV2]] to i64
-// CHECK: [[MASK2:%.+]] = sub i64 [[ALIGNCAST2]], 1
-// CHECK: [[PTRINT2:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR2:%.+]] = and i64 [[PTRINT2]], [[MASK2]]
-// CHECK: [[MASKCOND2:%.+]] = icmp eq i64 [[MASKEDPTR2]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND2]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL2]], i64 [[ALIGNCAST2]]) ]
 }
 __INT32_TYPE__ *m2(__SIZE_TYPE__ i) __attribute__((alloc_align(1)));
 
@@ -34,24 +26,16 @@
 // CHECK: define i32 @test3
   return *m2(a);
 // CHECK: [[CONV3:%.+]] = sext i32 %{{.+}} to i64
-// CHECK: call i32* @m2(i64 [[CONV3]])
-// CHECK: [[MASK3:%.+]] = sub i64 [[CONV3]], 1
-// CHECK: [[PTRINT3:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR3:%.+]] = and i64 [[PTRINT3]], [[MASK3]]
-// CHECK: [[MASKCOND3:%.+]] = icmp eq i64 [[MASKEDPTR3]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND3]])
+// CHECK: [[CALL3:%.+]] = call i32* @m2(i64 [[CONV3]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL3]], i64 [[CONV3]]) ]
 }
 
 // Every type matches, canonical example.
 __INT32_TYPE__ test4(__SIZE_TYPE__ a) {
 // CHECK: define i32 @test4
   return *m2(a);
-// CHECK: call i32* @m2(i64 [[PARAM4:%[^\)]+]])
-// CHECK: [[MASK4:%.+]] = sub i64 [[PARAM4]], 1
-// CHECK: [[PTRINT4:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR4:%.+]] = and i64 [[PTRINT4]], [[MASK4]]
-// CHECK: [[MASKCOND4:%.+]] = icmp eq i64 [[MASKEDPTR4]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND4]])
+// CHECK: [[CALL4:%.+]] = call i32* @m2(i64 [[PARAM4:%[^\)]+]])
+// CHECK: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL4]], i64 [[PARAM4]]) ]
 }
 
 
@@ -64,13 +48,9 @@
 // CHECK: define i32 @test5
   struct Empty e;
   return *m3(e, a);
-// CHECK: call i32* @m3(i64 %{{.*}}, i64 %{{.*}})
+// CHECK: [[CALL5:%.+]] = call i32* @m3(i64 %{{.*}}, i64 %{{.*}})
 // CHECK: [[ALIGNCAST5:%.+]] = trunc i128 %{{.*}} to i64
-// CHECK: [[MASK5:%.+]] = sub i64 [[ALIGNCAST5]], 1
-// CHECK: [[PTRINT5:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR5:%.+]] = and i64 [[PTRINT5]], [[MASK5]]
-// CHECK: [[MASKCOND5:%.+]] = icmp eq i64 [[MASKEDPTR5]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND5]])
+// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL5]], i64 [[ALIGNCAST5]]) ]
 }
 // Struct parameter takes up 2 parameters, 'i' takes up 2.
 __INT32_TYPE__ *m4(struct MultiArgs s, __int128_t i) __attribute__((alloc_align(2)));
@@ -78,12 +58,8 @@
 // CHECK: define i32 @test6
   struct MultiArgs e;
   return *m4(e, a);
-// CHECK: call i32* @m4(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
-// CHECK: [[ALIGNCAST6:%.+]] = trunc i128 %{{.*}} to i64
-// CHECK: [[MASK6:%.+]] = sub i64 [[ALIGNCAST6]], 1
-// CHECK: [[PTRINT6:%.+]] = ptrtoint
-// CHECK: [[MASKEDPTR6:%.+]] = and i64 [[PTRINT6]], [[MASK6]]
-// CHECK: [[MASKCOND6:%.+]] = icmp eq i64 [[MASKEDPTR6]], 0
-// CHECK: call void @llvm.assume(i1 [[MASKCOND6]])
+// CHECK: [[CALL6:%.+]] = call i32* @m4(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+// CHECK-NEXT: [[ALIGNCAST6:%.+]] = trunc i128 %{{.*}} to i64
+// CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(i32* [[CALL6]], i64 [[ALIGNCAST6]]) ]
 }
 
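The C-level pattern these tests exercise, for context (my_alloc and use are illustrative names, not from the test file):

  /* The returned pointer is assumed aligned to the runtime value of the
     first argument; Clang now emits the "align" bundle form for it. */
  int *my_alloc(int align) __attribute__((alloc_align(1)));
  int use(int a) { return *my_alloc(a); }
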
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2126,13 +2126,35 @@
                                               SourceLocation AssumptionLoc,
                                               llvm::Value *Alignment,
                                               llvm::Value *OffsetValue) {
-  llvm::Value *TheCheck;
   llvm::Instruction *Assumption = Builder.CreateAlignmentAssumption(
-      CGM.getDataLayout(), PtrValue, Alignment, OffsetValue, &TheCheck);
-  if (SanOpts.has(SanitizerKind::Alignment)) {
-    EmitAlignmentAssumptionCheck(PtrValue, Ty, Loc, AssumptionLoc, Alignment,
-                                 OffsetValue, TheCheck, Assumption);
+      CGM.getDataLayout(), PtrValue, Alignment, OffsetValue);
+  if (!SanOpts.has(SanitizerKind::Alignment))
+    return;
+
+  llvm::Value *PtrIntValue =
+      Builder.CreatePtrToInt(PtrValue, IntPtrTy, "ptrint");
+
+  if (OffsetValue) {
+    bool IsOffsetZero = false;
+    if (const auto *CI = dyn_cast<llvm::ConstantInt>(OffsetValue))
+      IsOffsetZero = CI->isZero();
+
+    if (!IsOffsetZero) {
+      if (OffsetValue->getType() != IntPtrTy)
+        OffsetValue = Builder.CreateIntCast(OffsetValue, IntPtrTy,
+                                            /*isSigned*/ true, "offsetcast");
+      PtrIntValue = Builder.CreateSub(PtrIntValue, OffsetValue, "offsetptr");
+    }
   }
+
+  llvm::Value *Zero = llvm::ConstantInt::get(IntPtrTy, 0);
+  llvm::Value *Mask =
+      Builder.CreateSub(Alignment, llvm::ConstantInt::get(IntPtrTy, 1));
+  llvm::Value *MaskedPtr = Builder.CreateAnd(PtrIntValue, Mask, "maskedptr");
+  llvm::Value *TheCheck = Builder.CreateICmpEQ(MaskedPtr, Zero, "maskcond");
+
+  EmitAlignmentAssumptionCheck(PtrValue, Ty, Loc, AssumptionLoc, Alignment,
+                               OffsetValue, TheCheck, Assumption);
 }
 
 void CodeGenFunction::EmitAlignmentAssumption(llvm::Value *PtrValue,
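
With this change the mask-and-compare sequence survives only on the UBSan path, emitted directly by CodeGen. Roughly, for the sanitizer check (value names as in the code above; the sub is skipped for a null or zero offset, and the %mask name is illustrative since CodeGen leaves that value unnamed):

  %ptrint = ptrtoint i32* %ptr to i64
  %offsetptr = sub i64 %ptrint, %offset
  %maskedptr = and i64 %offsetptr, %mask   ; %mask = alignment - 1
  %maskcond = icmp eq i64 %maskedptr, 0    ; fed to EmitAlignmentAssumptionCheck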