xazax.hun created this revision.
xazax.hun added reviewers: zaks.anna, dcoughlin, krememek, jordan_rose.
xazax.hun added a subscriber: cfe-commits.

This is the first implementation of a checker that is supposed to catch
nullability errors.
Unfortunately, the nullability qualifiers do not have a well-defined meaning;
one cannot assume that nonnull implies that the pointer cannot be null.
In fact, the contract is that when the nullability preconditions on the
parameters are not violated, the nullability postconditions of the return value
must not be violated either.
Right now the checker only checks simple rules: for example, nullable pointers
must be checked before being dereferenced, null or unchecked nullable pointers
should not be passed to nonnull parameters, and null or unchecked nullable
pointers should not be returned from a nonnull-returning function. The check
will probably be relaxed: if one of the (non-nullable) parameters is known to
be null, it will be OK to return null or an unchecked nullable pointer from a
nonnull-returning function.

Some details are still being worked out, such as how to define the nullability
rules (in terms of this checker) so that it can both discover real issues and
avoid false positives, while making it possible for users to suppress warnings
(possibly by using explicit casts). Once the rules are clear, supporting
documentation will be provided.

http://reviews.llvm.org/D11468

Files:
  lib/StaticAnalyzer/Checkers/CMakeLists.txt
  lib/StaticAnalyzer/Checkers/Checkers.td
  lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
  test/Analysis/nullability.mm

Index: test/Analysis/nullability.mm
===================================================================
--- /dev/null
+++ test/Analysis/nullability.mm
@@ -0,0 +1,115 @@
+// RUN: %clang_cc1 -analyze -analyzer-checker=core,alpha.core.Nullability -verify %s
+
+#define nil 0 // Stand-in for the Foundation macro; no SDK headers are used.
+
+@protocol NSObject // Minimal stand-ins for the Foundation root protocol/class.
++ (id)alloc;
+- (id)init;
+@end
+
+@protocol NSCopying // Declared for completeness; not used by the tests below.
+@end
+
+__attribute__((objc_root_class))
+@interface NSObject <NSObject>
+@end
+
+@interface TestObject : NSObject // A returning and a taking method per kind.
+- (int * _Nonnull)returnsNonnull;
+- (int * _Nullable)returnsNullable;
+- (int *) returnsUnspecified;
+- (void)takesNonnull:(int * _Nonnull)p;
+- (void)takesNullable:(int * _Nullable)p;
+- (void)takesUnspecified:(int *)p;
+@end
+
+TestObject *getUnspecifiedTestObject(); // Receivers of each nullability kind.
+TestObject * _Nonnull getNonnullTestObject();
+TestObject * _Nullable getNullableTestObject();
+
+int getRandom(); // Opaque value; lets the tests fork independent paths.
+
+typedef struct Dummy { // Plain pointee type for the C-style tests.
+  int val;
+} Dummy;
+
+void takesNullable(Dummy * _Nullable);
+void takesNonnull(Dummy * _Nonnull);
+void takesUnspecified(Dummy *);
+
+Dummy * _Nullable returnsNullable();
+Dummy * _Nonnull returnsNonnull();
+Dummy * returnsUnspecified();
+int * _Nullable returnsNullableInt();
+
+template <class T> // Identity; meant to hide static nullability.
+T *eraseNullab(T *ptr) {
+  return ptr;
+}
+
+void testBasicRules() { // Core rules for nullable and null pointers.
+  Dummy *p = returnsNullable();
+  int *ptr = returnsNullableInt();
+  // Make every dereference a different path to avoid nonnull assumptions.
+  switch(getRandom()) {
+    case 0: { Dummy &r = *p; } break; // expected-warning {{}}
+    case 1: { int b = p->val; } break; // expected-warning {{}}
+    case 2: { int stuff = *ptr; } break; // expected-warning {{}}
+    case 3: takesNonnull(p); break; // expected-warning {{}}
+    default: { Dummy d = *p; } break; // expected-warning {{}}
+  }
+  if (p) { // Checked: no diagnostics are expected inside this branch.
+    takesNonnull(p);
+    if (getRandom()) {
+      Dummy &r = *p;
+    } else {
+      int b = p->val;
+    }
+  }
+  Dummy *q = 0; // A known null value.
+  takesNullable(q);
+  takesNonnull(q); // expected-warning {{}}
+  Dummy a;
+  Dummy * _Nonnull nonnull = &a;
+  nonnull = q; // expected-warning {{}}
+  q = &a; // q is now provably nonnull, so the calls below are fine.
+  takesNullable(q);
+  takesNonnull(q);
+}
+
+void testArgumentTracking(Dummy * _Nonnull nonnull, Dummy * _Nullable nullable) {
+  Dummy *p = nullable; // Copying into an unannotated local keeps it nullable.
+  nonnull = p; // expected-warning {{}}
+  p = 0;
+  Dummy *q = nonnull;
+  q = p; // Null into an unannotated local: no warning expected.
+}
+
+Dummy * _Nonnull testNullableReturn(Dummy * _Nullable a) {
+  Dummy *p = a; // Still nullable after the copy.
+  return p; // expected-warning {{}}
+}
+
+Dummy * _Nonnull testNullReturn() {
+  Dummy *p = 0; // Known null.
+  return p; // expected-warning {{}}
+}
+
+void testObjCMessageResultNullability() {
+  // The expected result: the most nullable of self and method return type.
+  TestObject *o = getUnspecifiedTestObject();
+  int *shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsNonnull];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsUnspecified];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNonnullTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getUnspecifiedTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [o returnsNullable]; // Receiver held in a local variable.
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  int * shouldBeNonnull = [eraseNullab(getNonnullTestObject()) returnsNonnull];
+  [o takesNonnull: shouldBeNonnull];
+}
Index: lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
===================================================================
--- /dev/null
+++ lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
@@ -0,0 +1,573 @@
+//== NullabilityChecker.cpp - Nullability checker ---------------*- C++ -*--==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This checker tries to find nullability violations. The checker assumes
+// that the user runs it only after all of the nullability warnings emitted
+// by the compiler have been fixed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ClangSACheckers.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+
+using namespace clang;
+using namespace ento;
+
+namespace {
+enum class Nullability : char {
+  Contradicted, // Tracked nullability is contradicted by an explicit cast.
+  Nullable,     // Declaration order matters: lower values are more nullable.
+  Unspecified,  // Optimization: most pointers are expected to be unspecified,
+                // so a region with no entry in the state implicitly has this
+                // value.
+  Nonnull
+};
+
+// Returns the human-readable name of Nullab, as used in path notes.
+static const char *getNullabilityString(Nullability Nullab) {
+  switch (Nullab) {
+  case Nullability::Contradicted:
+    return "contradicted";
+  case Nullability::Nullable:
+    return "nullable";
+  case Nullability::Unspecified:
+    return "unspecified";
+  case Nullability::Nonnull:
+    return "nonnull";
+  }
+  llvm_unreachable("Unknown nullability value.");
+}
+
+static Nullability getMostNullable(Nullability Lhs, Nullability Rhs) {
+  // Lower enumerators count as more nullable, so the smaller value wins.
+  return Rhs < Lhs ? Rhs : Lhs;
+}
+
+enum class ErrorKind { // Diagnostic kinds; reportBug() maps them to messages.
+  NilAssignedToNonnull,        // Null bound to a nonnull location.
+  NilPassedToNonnull,          // Null passed as a nonnull argument.
+  NilReturnedToNonnull,        // Null returned from a nonnull function.
+  NullableAssignedToNonnull,   // Unchecked nullable bound to nonnull.
+  NullableReturnedToNonnull,   // Unchecked nullable returned as nonnull.
+  NullableDereferenced,        // Unchecked nullable dereferenced.
+  NullableAssignedToReference, // Unchecked nullable passed to a reference.
+  NullablePassedToNonnull      // Unchecked nullable passed as nonnull.
+};
+
+class NullabilityChecker
+    : public Checker<check::Bind, check::PreCall, check::PreStmt<ReturnStmt>,
+                     check::PostStmt<ExplicitCastExpr>, check::PostObjCMessage,
+                     check::DeadSymbols, check::Event<ImplicitNullDerefEvent>> {
+  mutable std::unique_ptr<BugType> BT;
+
+public:
+  void checkBind(SVal L, SVal V, const Stmt *S, CheckerContext &C) const;
+  void checkPostStmt(const ExplicitCastExpr *CE, CheckerContext &C) const;
+  void checkPreStmt(const ReturnStmt *S, CheckerContext &C) const;
+  void checkPostObjCMessage(const ObjCMethodCall &M, CheckerContext &C) const;
+  void checkPreCall(const CallEvent &Call, CheckerContext &C) const;
+  void checkDeadSymbols(SymbolReaper &SR, CheckerContext &C) const;
+  void checkEvent(ImplicitNullDerefEvent Event) const;
+
+  // Adds a path note wherever the tracked nullability of Region is set.
+  class NullabilityBugVisitor
+      : public BugReporterVisitorImpl<NullabilityBugVisitor> {
+  public:
+    NullabilityBugVisitor(const MemRegion *M) : Region(M) {}
+
+    void Profile(llvm::FoldingSetNodeID &ID) const override {
+      static int X = 0;
+      ID.AddPointer(&X);
+      ID.AddPointer(Region);
+    }
+
+    PathDiagnosticPiece *VisitNode(const ExplodedNode *N,
+                                   const ExplodedNode *PrevN,
+                                   BugReporterContext &BRC,
+                                   BugReport &BR) override;
+
+  private:
+    const MemRegion *Region; // The tracked region.
+  };
+
+  // Reports Error at N; a non-null Region is marked interesting and tracked.
+  void reportBug(ErrorKind Error, ExplodedNode *N, const MemRegion *Region,
+                 BugReporter &BR) const {
+    if (!BT)
+      BT.reset(new BugType(this, "Nullability", "Memory error"));
+    const char *Msg = nullptr;
+    switch (Error) {
+    case ErrorKind::NilAssignedToNonnull:
+      Msg = "Null pointer is assigned to nonnull pointer";
+      break;
+    case ErrorKind::NilPassedToNonnull:
+      Msg = "Null pointer is passed to a nonnull parameter";
+      break;
+    case ErrorKind::NilReturnedToNonnull:
+      Msg = "Null pointer is returned from a nonnull returning function";
+      break;
+    case ErrorKind::NullableAssignedToNonnull:
+      Msg = "Nullable pointer is assigned to nonnull without a defensive check";
+      break;
+    case ErrorKind::NullableReturnedToNonnull:
+      Msg = "Nullable pointer is returned from a nonnull returning function "
+            "without a defensive check";
+      break;
+    case ErrorKind::NullableDereferenced:
+      Msg = "Nullable pointer is dereferenced without a defensive check";
+      break;
+    case ErrorKind::NullableAssignedToReference:
+      Msg = "Nullable pointer is assigned to a reference without a defensive "
+            "check";
+      break;
+    case ErrorKind::NullablePassedToNonnull:
+      Msg = "Nullable pointer is passed to a nonnull parameter without a "
+            "defensive check";
+      break;
+    }
+    assert(Msg);
+    auto R = llvm::make_unique<BugReport>(*BT, Msg, N);
+    if (Region) {
+      R->markInteresting(Region);
+      R->addVisitor(llvm::make_unique<NullabilityBugVisitor>(Region));
+    }
+    BR.emitReport(std::move(R));
+  }
+};
+
+class NullabilityState {
+public:
+  // Implicit on purpose: lets a Nullability be stored in NullabilityMap as is.
+  NullabilityState(Nullability V) : Value(V) {}
+  Nullability getValue() const { return Value; }
+
+  // Folds the wrapped value into ID so the state map can unique its entries.
+  void Profile(llvm::FoldingSetNodeID &ID) const {
+    ID.AddInteger(static_cast<char>(Value));
+  }
+private:
+  Nullability Value;
+};
+
+bool operator==(NullabilityState A, NullabilityState B) {
+  return B.getValue() == A.getValue(); // Equal iff same wrapped nullability.
+}
+
+} // end anonymous namespace
+
+REGISTER_MAP_WITH_PROGRAMSTATE(NullabilityMap, const MemRegion *,
+                               NullabilityState) // No entry: Unspecified.
+
+// Emits a path note at the node where the tracked region's nullability first
+// appears or changes, so the report shows where that nullability came from.
+PathDiagnosticPiece *NullabilityChecker::NullabilityBugVisitor::VisitNode(
+    const ExplodedNode *N, const ExplodedNode *PrevN, BugReporterContext &BRC,
+    BugReport &BR) {
+  ProgramStateRef State = N->getState();
+  ProgramStateRef StatePrev = PrevN->getState();
+
+  const NullabilityState *TrackedNullab = State->get<NullabilityMap>(Region);
+  if (!TrackedNullab)
+    return nullptr;
+
+  // Only annotate the node where the nullability was set or changed.
+  const NullabilityState *TrackedNullabPrev =
+      StatePrev->get<NullabilityMap>(Region);
+  if (TrackedNullabPrev &&
+      TrackedNullabPrev->getValue() == TrackedNullab->getValue())
+    return nullptr;
+
+  // Retrieve the associated statement.
+  const Stmt *S = nullptr;
+  if (Optional<StmtPoint> SP = N->getLocation().getAs<StmtPoint>())
+    S = SP->getStmt();
+  if (!S)
+    return nullptr;
+
+  std::string InfoText = (llvm::Twine("Nullability '") +
+                          getNullabilityString(TrackedNullab->getValue()) +
+                          "' is inferred from this context")
+                             .str();
+
+  // Generate the extra diagnostic.
+  PathDiagnosticLocation Pos(S, BRC.getSourceManager(),
+                             N->getLocationContext());
+  return new PathDiagnosticEventPiece(Pos, InfoText, true, nullptr);
+}
+
+static Nullability getNullability(QualType Type) {
+  // Reads the nullability attribute sugar of Type. Anything other than
+  // _Nullable/_Nonnull (including no attribute at all) is Unspecified.
+  const auto *AttrType = Type->getAs<AttributedType>();
+  if (AttrType && AttrType->getAttrKind() == AttributedType::attr_nullable)
+    return Nullability::Nullable;
+  if (AttrType && AttrType->getAttrKind() == AttributedType::attr_nonnull)
+    return Nullability::Nonnull;
+  return Nullability::Unspecified;
+}
+
+void NullabilityChecker::checkDeadSymbols(SymbolReaper &SR,
+                                          CheckerContext &C) const {
+  // Forget the nullability of regions that are no longer live.
+  ProgramStateRef State = C.getState();
+  NullabilityMapTy Nullabilities = State->get<NullabilityMap>();
+  for (const auto &Entry : Nullabilities)
+    if (!SR.isLiveRegion(Entry.first))
+      State = State->remove<NullabilityMap>(Entry.first);
+  // Publish the pruned state; without a transition the removals are lost and
+  // the map keeps accumulating dead regions.
+  C.addTransition(State);
+}
+
+// Handles ImplicitNullDerefEvent (presumably a dereference of a pointer that
+// may be null): report it if the pointer or its parent region is nullable.
+void NullabilityChecker::checkEvent(ImplicitNullDerefEvent Event) const {
+  auto RegionSVal = Event.Location.getAs<loc::MemRegionVal>();
+  if (!RegionSVal)
+    return;
+
+  ProgramStateRef State = Event.SinkNode->getState();
+  const MemRegion *Region = RegionSVal->getRegion();
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+
+  if (!TrackedNullability) {
+    // Maybe a field or an element of a nullable pointer is accessed.
+    // getAs<SubRegion>() yields null for non-subregions, so check it first.
+    if (const auto *SubReg = Region->getAs<SubRegion>())
+      TrackedNullability = State->get<NullabilityMap>(SubReg->getSuperRegion());
+    if (!TrackedNullability)
+      return;
+  }
+
+  // Nonnull, Unspecified and Contradicted (cast-away) pointers are not flagged.
+  if (TrackedNullability->getValue() == Nullability::Nullable) {
+    reportBug(ErrorKind::NullableDereferenced, Event.SinkNode, Region,
+              *Event.BR);
+  }
+}
+
+void NullabilityChecker::checkPreStmt(const ReturnStmt *S,
+                                      CheckerContext &C) const {
+  // Checks a returned pointer against the declared return nullability.
+  auto RetExpr = S->getRetValue();
+  if (!RetExpr)
+    return;
+
+  QualType RetExprType = RetExpr->getType();
+  // FIXME: What about references?
+  if (!RetExprType->isPointerType() && !RetExprType->isObjCObjectPointerType())
+    return;
+
+  ProgramStateRef State = C.getState();
+  SVal RetSVal = State->getSVal(S, C.getLocationContext());
+  if (RetSVal.isUndef())
+    return;
+
+  // Look up the declared type of the enclosing function, if it has one.
+  AnalysisDeclContext *DeclCtxt =
+      C.getLocationContext()->getAnalysisDeclContext();
+
+  const FunctionType *FuncType = DeclCtxt->getDecl()->getFunctionType();
+  if (!FuncType)
+    return;
+
+  QualType ReturnType = FuncType->getReturnType();
+  Nullability StaticNullability = getNullability(ReturnType);
+
+  DefinedOrUnknownSVal ReturnValue = RetSVal.castAs<DefinedOrUnknownSVal>();
+  // A value that is definitely null must not leave a nonnull function.
+  ProgramStateRef StNonNull, StNull;
+  std::tie(StNonNull, StNull) = State->assume(ReturnValue);
+  bool IsNotNull = !StNull && StNonNull;
+  bool IsNull = StNull && !StNonNull;
+  if (IsNull && StaticNullability == Nullability::Nonnull) {
+    ExplodedNode *N = C.addTransition();
+    reportBug(ErrorKind::NilReturnedToNonnull, N, nullptr, C.getBugReporter());
+    return;
+  }
+
+  auto RetRegionSVal = ReturnValue.getAs<loc::MemRegionVal>();
+  if (!RetRegionSVal)
+    return;
+  // Report an unchecked nullable value, or record the declared nullability.
+  const MemRegion *Region = RetRegionSVal->getRegion();
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+  if (TrackedNullability) {
+    Nullability TrackedNullabValue = TrackedNullability->getValue();
+    if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+        StaticNullability == Nullability::Nonnull) {
+      ExplodedNode *N = C.addTransition();
+      reportBug(ErrorKind::NullableReturnedToNonnull, N, Region,
+                C.getBugReporter());
+      return;
+    }
+  } else if (StaticNullability != Nullability::Unspecified) {
+    State = State->set<NullabilityMap>(Region, StaticNullability);
+    C.addTransition(State);
+  }
+}
+
+// Validates each pointer or reference argument of a call against the
+// nullability of its parameter and records statically known nullability.
+void NullabilityChecker::checkPreCall(const CallEvent &Call,
+                                      CheckerContext &C) const {
+  const Decl *FD = Call.getDecl();
+  if (!FD)
+    return;
+
+  ProgramStateRef State = C.getState();
+  ProgramStateRef OrigState = State;
+
+  unsigned Idx = 0;
+  for (const ParmVarDecl *Param : Call.parameters()) {
+    if (Param->isParameterPack())
+      break;
+
+    // Stop once the call runs out of arguments so getArgSVal() is never asked
+    // for an index past getNumArgs().
+    if (Idx >= Call.getNumArgs())
+      break;
+    const Expr *ArgExpr = Call.getArgExpr(Idx);
+    SVal ArgSVal = Call.getArgSVal(Idx++);
+    if (ArgSVal.isUndef())
+      continue;
+
+    QualType ParamTy = Param->getType();
+    if (!ParamTy->isPointerType() && !ParamTy->isReferenceType() &&
+        !ParamTy->isObjCObjectPointerType())
+      continue;
+
+    ProgramStateRef StNonNull, StNull;
+    DefinedOrUnknownSVal DefArgSVal = ArgSVal.castAs<DefinedOrUnknownSVal>();
+    std::tie(StNonNull, StNull) = State->assume(DefArgSVal);
+    bool IsNotNull = !StNull && StNonNull;
+    bool IsNull = StNull && !StNonNull;
+
+    // An unannotated parameter falls back to the argument's static type.
+    Nullability StaticNullability = getNullability(ParamTy);
+    if (StaticNullability == Nullability::Unspecified && ArgExpr)
+      StaticNullability = getNullability(ArgExpr->getType());
+
+    if (IsNull && StaticNullability == Nullability::Nonnull) {
+      ExplodedNode *N = C.addTransition();
+      reportBug(ErrorKind::NilPassedToNonnull, N, nullptr, C.getBugReporter());
+      return;
+    }
+
+    auto ArgRegionSVal = ArgSVal.getAs<loc::MemRegionVal>();
+    if (!ArgRegionSVal)
+      continue;
+
+    const MemRegion *Region = ArgRegionSVal->getRegion();
+    const NullabilityState *TrackedNullability =
+        State->get<NullabilityMap>(Region);
+    if (TrackedNullability) {
+      Nullability TrackedNullabValue = TrackedNullability->getValue();
+      if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+          StaticNullability == Nullability::Nonnull) {
+        ExplodedNode *N = C.addTransition();
+        reportBug(ErrorKind::NullablePassedToNonnull, N, Region,
+                  C.getBugReporter());
+        return; // FIXME: What if multiple parameters should be reported?
+      } else if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+                 ParamTy->isReferenceType()) {
+        ExplodedNode *N = C.addTransition();
+        reportBug(ErrorKind::NullableAssignedToReference, N, Region,
+                  C.getBugReporter());
+        return;
+      }
+    } else if (StaticNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(Region, StaticNullability);
+    }
+  }
+  if (State != OrigState)
+    C.addTransition(State);
+}
+
+void NullabilityChecker::checkPostObjCMessage(const ObjCMethodCall &M,
+                                              CheckerContext &C) const {
+  auto Decl = M.getDecl();
+  if (!Decl)
+    return;
+  QualType RetType = Decl->getReturnType();
+  if (!RetType->isPointerType() && !RetType->isObjCObjectPointerType())
+    return;
+  // The result gets the more nullable of receiver and return nullability.
+  const ObjCMessageExpr *Message = M.getOriginExpr();
+
+  ProgramStateRef State = C.getState();
+  SVal ResultSVal = M.getReturnValue();
+  auto MemRegVal = ResultSVal.getAs<loc::MemRegionVal>();
+  if (!MemRegVal)
+    return;
+  // Messages to super have a nonnull receiver; others use tracked state.
+  Nullability SelfNullability = Nullability::Unspecified;
+  if (Message->getReceiverKind() == ObjCMessageExpr::SuperClass ||
+      Message->getReceiverKind() == ObjCMessageExpr::SuperInstance) {
+    SelfNullability = Nullability::Nonnull;
+  } else {
+    SVal Receiver = M.getReceiverSVal();
+    auto ValueRegionSVal = Receiver.getAs<loc::MemRegionVal>();
+    if (ValueRegionSVal) {
+      const MemRegion *SelfRegion = ValueRegionSVal->getRegion();
+      assert(SelfRegion);
+
+      const NullabilityState *TrackedSelfNullability =
+          State->get<NullabilityMap>(SelfRegion);
+      if (TrackedSelfNullability) {
+        SelfNullability = TrackedSelfNullability->getValue();
+      }
+    }
+  }
+  // Merge with what is already tracked for the result, else use static info.
+  const MemRegion *ReturnRegion = MemRegVal->getRegion();
+  assert(ReturnRegion);
+
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(ReturnRegion);
+  if (TrackedNullability) {
+    Nullability RetValTracked = TrackedNullability->getValue();
+    Nullability NewNullability =
+        getMostNullable(RetValTracked, SelfNullability);
+    if (NewNullability != RetValTracked &&
+        NewNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(ReturnRegion, NewNullability);
+      C.addTransition(State);
+    }
+  } else {
+    // Use static type information for return value.
+    Nullability RetNullability = getNullability(RetType);
+    RetNullability = getMostNullable(RetNullability, SelfNullability);
+    if (RetNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(ReturnRegion, RetNullability);
+      C.addTransition(State);
+    }
+  }
+}
+
+// An explicit cast to a _Nullable/_Nonnull pointer type updates the tracked
+// nullability of the cast value; conflicts are recorded as Contradicted.
+void NullabilityChecker::checkPostStmt(const ExplicitCastExpr *CE,
+                                       CheckerContext &C) const {
+  QualType OriginType = CE->getSubExpr()->getType();
+  QualType DestType = CE->getType();
+  if (!OriginType->isPointerType() && !OriginType->isObjCObjectPointerType())
+    return;
+  if (!DestType->isPointerType() && !DestType->isObjCObjectPointerType())
+    return;
+
+  Nullability DestNullability = getNullability(DestType);
+
+  if (DestNullability == Nullability::Unspecified)
+    return;
+
+  ProgramStateRef State = C.getState();
+  SVal ExprSVal = State->getSVal(CE, C.getLocationContext());
+
+  // Key the map by the region the pointer points to, as every other callback
+  // of this checker does, rather than by the region the value was loaded from.
+  auto RegionSVal = ExprSVal.getAs<loc::MemRegionVal>();
+  if (!RegionSVal)
+    return;
+  const MemRegion *Region = RegionSVal->getRegion();
+
+  // A value known to be null that is cast to nonnull contradicts the cast.
+  if (DestNullability == Nullability::Nonnull) {
+    ProgramStateRef StNonNull, StNull;
+    std::tie(StNonNull, StNull) = State->assume(*RegionSVal);
+    if (StNull && !StNonNull) {
+      State = State->set<NullabilityMap>(Region, Nullability::Contradicted);
+      C.addTransition(State);
+      return;
+    }
+  }
+
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+
+  if (!TrackedNullability) {
+    State = State->set<NullabilityMap>(Region, DestNullability);
+    C.addTransition(State);
+  } else if (TrackedNullability->getValue() != DestNullability) {
+    // Do not add redundant transitions.
+    if (TrackedNullability->getValue() == Nullability::Contradicted)
+      return;
+    State = State->set<NullabilityMap>(Region, Nullability::Contradicted);
+    C.addTransition(State);
+  }
+}
+
+// Checks a value being bound to a pointer or reference location against the
+// location's nullability, and propagates nullability to the bound value.
+void NullabilityChecker::checkBind(SVal L, SVal V, const Stmt *S,
+                                   CheckerContext &C) const {
+  const auto *TVR = dyn_cast_or_null<TypedValueRegion>(L.getAsRegion());
+  if (!TVR)
+    return;
+
+  QualType LocType = TVR->getValueType();
+  if (!LocType->isPointerType() && !LocType->isReferenceType() &&
+      !LocType->isObjCObjectPointerType())
+    return;
+
+  // Nothing can be assumed about an undefined value (and castAs would assert).
+  auto ValDefOrUnknown = V.getAs<DefinedOrUnknownSVal>();
+  if (!ValDefOrUnknown)
+    return;
+  Nullability LocNullability = getNullability(LocType);
+  ProgramStateRef State = C.getState();
+  ProgramStateRef StNonNull, StNull;
+  std::tie(StNonNull, StNull) = State->assume(*ValDefOrUnknown);
+  bool RhsIsNull = !StNonNull && StNull;
+  bool RhsIsNotNull = StNonNull && !StNull;
+
+  // Binding a null pointer to a reference is handled by another checker.
+  if (RhsIsNull && LocNullability == Nullability::Nonnull) {
+    ExplodedNode *N = C.addTransition();
+    reportBug(ErrorKind::NilAssignedToNonnull, N, nullptr, C.getBugReporter());
+    return;
+  }
+
+  auto ValueRegionSVal = V.getAs<loc::MemRegionVal>();
+  if (!ValueRegionSVal)
+    return;
+  const MemRegion *ValueRegion = ValueRegionSVal->getRegion();
+  assert(ValueRegion);
+
+  Nullability ValNullability = Nullability::Unspecified;
+  if (SymbolRef Sym = V.getAsSymbol())
+    ValNullability = getNullability(Sym->getType());
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(ValueRegion);
+  if (TrackedNullability) {
+    ValNullability = TrackedNullability->getValue();
+    if (!RhsIsNotNull && ValNullability == Nullability::Nullable &&
+        LocNullability == Nullability::Nonnull) {
+      ExplodedNode *N = C.addTransition();
+      reportBug(ErrorKind::NullableAssignedToNonnull, N, ValueRegion,
+                C.getBugReporter());
+    }
+  } else if (ValNullability != Nullability::Unspecified) {
+    // Trust the value's static nullability over the location's.
+    State = State->set<NullabilityMap>(ValueRegion, ValNullability);
+    C.addTransition(State);
+  } else if (LocNullability != Nullability::Unspecified) {
+    State = State->set<NullabilityMap>(ValueRegion, LocNullability);
+    C.addTransition(State);
+  }
+}
+
+void ento::registerNullabilityChecker(CheckerManager &Mgr) {
+  Mgr.registerChecker<NullabilityChecker>(); // alpha.core.Nullability
+}
Index: lib/StaticAnalyzer/Checkers/Checkers.td
===================================================================
--- lib/StaticAnalyzer/Checkers/Checkers.td
+++ lib/StaticAnalyzer/Checkers/Checkers.td
@@ -128,6 +128,10 @@
   HelpText<"Check for division by variable that is later compared against 0. Either the comparison is useless or there is division by zero.">,
   DescFile<"TestAfterDivZeroChecker.cpp">;
 
+def NullabilityChecker : Checker<"Nullability">,
+  HelpText<"Warn about nullability misuses">,
+  DescFile<"NullabilityChecker.cpp">;
+
 } // end "alpha.core"
 
 //===----------------------------------------------------------------------===//
Index: lib/StaticAnalyzer/Checkers/CMakeLists.txt
===================================================================
--- lib/StaticAnalyzer/Checkers/CMakeLists.txt
+++ lib/StaticAnalyzer/Checkers/CMakeLists.txt
@@ -48,6 +48,7 @@
   NSErrorChecker.cpp
   NoReturnFunctionChecker.cpp
   NonNullParamChecker.cpp
+  NullabilityChecker.cpp
   ObjCAtSyncChecker.cpp
   ObjCContainersASTChecker.cpp
   ObjCContainersChecker.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@cs.uiuc.edu
http://lists.cs.uiuc.edu/mailman/listinfo/cfe-commits

Reply via email to