vsk created this revision.

UBSan can check that scalar loads provide in-range values. When we load
a value from a bitfield, we know that the range of the value is
constrained by the bitfield's width. This patch teaches UBSan how to use
that information to skip emitting some range checks.

This depends on / is a follow-up to: https://reviews.llvm.org/D30423


https://reviews.llvm.org/D30729

Files:
  lib/CodeGen/CGExpr.cpp
  lib/CodeGen/CodeGenFunction.h
  test/CodeGenCXX/ubsan-bitfields.cpp

Index: test/CodeGenCXX/ubsan-bitfields.cpp
===================================================================
--- test/CodeGenCXX/ubsan-bitfields.cpp
+++ test/CodeGenCXX/ubsan-bitfields.cpp
@@ -7,7 +7,9 @@
 };
 
 struct S {
-  E e1 : 10;
+  E e1 : 10; //< Wide enough for the unsigned range check.
+  E e2 : 2; //< Still wide enough, 3 < 2^2.
+  E e3 : 1; //< Not wide enough, 3 > 2^1.
 };
 
 // CHECK-LABEL: define i32 @_Z4loadP1S
@@ -19,3 +21,65 @@
   // CHECK: call void @__ubsan_handle_load_invalid_value
   return s->e1;
 }
+
+// CHECK-LABEL: define i32 @_Z5load2P1S
+E load2(S *s) {
+  // CHECK: [[LOAD:%.*]] = load i16, i16* {{.*}}
+  // CHECK: [[LSHR:%.*]] = lshr i16 [[LOAD]], 10
+  // CHECK: [[CLEAR:%.*]] = and i16 [[LSHR]], 3
+  // CHECK: [[CAST:%.*]] = zext i16 [[CLEAR]] to i32
+  // CHECK: icmp ule i32 [[CAST]], 3, !nosanitize
+  // CHECK: call void @__ubsan_handle_load_invalid_value
+  return s->e2;
+}
+
+// CHECK-LABEL: define i32 @_Z5load3P1S
+E load3(S *s) {
+  // CHECK-NOT: !nosanitize
+  return s->e3;
+}
+
+enum E2 {
+  x = -3,
+  y = 3
+};
+
+struct S2 {
+  E2 e1 : 4; //< Wide enough for signed range checks.
+  E2 e2 : 3; //< Still wide enough, -3 > -2^2 and 3 < 2^2.
+  E2 e3 : 2; //< Not wide enough, -3 < -2^1.
+};
+
+// CHECK-LABEL: define i32 @_Z5load4P2S2
+E2 load4(S2 *s) {
+  // CHECK: [[LOAD:%.*]] = load i16, i16* {{.*}}
+  // CHECK: [[SHL:%.*]] = shl i16 [[LOAD]], 12
+  // CHECK: [[ASHR:%.*]] = ashr i16 [[SHL]], 12
+  // CHECK: [[CAST:%.*]] = sext i16 [[ASHR]] to i32
+  // CHECK: [[UPPER_BOUND:%.*]] = icmp sle i32 [[CAST]], 3, !nosanitize
+  // CHECK: [[LOWER_BOUND:%.*]] = icmp sge i32 [[CAST]], -4, !nosanitize
+  // CHECK: [[BOUND:%.*]] = and i1 [[UPPER_BOUND]], [[LOWER_BOUND]], !nosanitize
+  // CHECK: br i1 [[BOUND]], {{.*}}, !nosanitize
+  // CHECK: call void @__ubsan_handle_load_invalid_value
+  return s->e1;
+}
+
+// CHECK-LABEL: define i32 @_Z5load5P2S2
+E2 load5(S2 *s) {
+  // CHECK: [[LOAD:%.*]] = load i16, i16* {{.*}}
+  // CHECK: [[SHL:%.*]] = shl i16 [[LOAD]], 9
+  // CHECK: [[ASHR:%.*]] = ashr i16 [[SHL]], 13
+  // CHECK: [[CAST:%.*]] = sext i16 [[ASHR]] to i32
+  // CHECK: [[UPPER_BOUND:%.*]] = icmp sle i32 [[CAST]], 3, !nosanitize
+  // CHECK: [[LOWER_BOUND:%.*]] = icmp sge i32 [[CAST]], -4, !nosanitize
+  // CHECK: [[BOUND:%.*]] = and i1 [[UPPER_BOUND]], [[LOWER_BOUND]], !nosanitize
+  // CHECK: br i1 [[BOUND]], {{.*}}, !nosanitize
+  // CHECK: call void @__ubsan_handle_load_invalid_value
+  return s->e2;
+}
+
+// CHECK-LABEL: define i32 @_Z5load6P2S2
+E2 load6(S2 *s) {
+  // CHECK-NOT: !nosanitize
+  return s->e3;
+}
Index: lib/CodeGen/CodeGenFunction.h
===================================================================
--- lib/CodeGen/CodeGenFunction.h
+++ lib/CodeGen/CodeGenFunction.h
@@ -2892,11 +2892,12 @@
   llvm::Value *EmitFromMemory(llvm::Value *Value, QualType Ty);
 
   /// Check if the scalar \p Value is within the valid range for the given
-  /// type \p Ty.
+  /// type \p Ty. If \p BitWidth is provided, it must be the width in bits of
+  /// the bit-field \p Value was loaded from (not of its storage unit).
   ///
-  /// Returns true if a check is needed (even if the range is unknown).
-  bool EmitScalarRangeCheck(llvm::Value *Value, QualType Ty,
-                            SourceLocation Loc);
+  /// Returns true if range metadata for the scalar should be dropped.
+  bool EmitScalarRangeCheck(llvm::Value *Value, QualType Ty, SourceLocation Loc,
+                            Optional<unsigned> BitWidth = None);
 
   /// EmitLoadOfScalar - Load a scalar value from an address, taking
   /// care to appropriately convert from the memory representation to
Index: lib/CodeGen/CGExpr.cpp
===================================================================
--- lib/CodeGen/CGExpr.cpp
+++ lib/CodeGen/CGExpr.cpp
@@ -1302,7 +1302,8 @@
 }
 
 bool CodeGenFunction::EmitScalarRangeCheck(llvm::Value *Value, QualType Ty,
-                                           SourceLocation Loc) {
+                                           SourceLocation Loc,
+                                           Optional<unsigned> BitWidth) {
   bool HasBoolCheck = SanOpts.has(SanitizerKind::Bool);
   bool HasEnumCheck = SanOpts.has(SanitizerKind::Enum);
   if (!HasBoolCheck && !HasEnumCheck)
@@ -1319,9 +1320,32 @@
   if (!getRangeForType(*this, Ty, Min, End, /*StrictEnums=*/true, IsBool))
     return true;
 
+  --End;
+  if (BitWidth) {
+    if (!Min) {
+      // If End > MaxValueForWidth, then Value < End. Skip the range check.
+      auto MaxValueForWidth =
+          llvm::APInt::getMaxValue(*BitWidth).zextOrSelf(End.getBitWidth());
+      if (End.ugt(MaxValueForWidth))
+        return false;
+    } else {
+      assert(Min.eq(-(End + 1)) &&
+             "The full range check assumes Min = -(End + 1)");
+      assert(Ty->hasSignedIntegerRepresentation() &&
+             "The full range check works on signed integers only");
+
+      // If Min < MinValueForWidth, then Value > Min. We know Min = -(End + 1),
+      // so End > MaxValueForWidth, and Value < End. Skip the range check.
+      auto MinValueForWidth =
+          llvm::APInt::getSignedMinValue(*BitWidth).sextOrSelf(
+              End.getBitWidth());
+      if (Min.slt(MinValueForWidth))
+        return false;
+    }
+  }
+
   SanitizerScope SanScope(this);
   llvm::Value *Check;
-  --End;
   if (!Min) {
     Check = Builder.CreateICmpULE(
         Value, llvm::ConstantInt::get(getLLVMContext(), End));
@@ -1566,7 +1590,7 @@
                               "bf.clear");
   }
   Val = Builder.CreateIntCast(Val, ResLTy, Info.IsSigned, "bf.cast");
-  EmitScalarRangeCheck(Val, LV.getType(), Loc);
+  EmitScalarRangeCheck(Val, LV.getType(), Loc, Info.Size);
   return RValue::get(Val);
 }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to