dkrupp updated this revision to Diff 514973.
dkrupp marked an inline comment as done.
dkrupp added a comment.

- Implemented an early return in getTaintedSymbols() when it is called by
isTainted(), for efficiency
- Fixed test incompatibility on Windows


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D144269/new/

https://reviews.llvm.org/D144269

Files:
  clang/include/clang/StaticAnalyzer/Checkers/Taint.h
  clang/include/clang/StaticAnalyzer/Core/BugReporter/CommonBugCategories.h
  clang/lib/StaticAnalyzer/Checkers/ArrayBoundCheckerV2.cpp
  clang/lib/StaticAnalyzer/Checkers/DivZeroChecker.cpp
  clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
  clang/lib/StaticAnalyzer/Checkers/Taint.cpp
  clang/lib/StaticAnalyzer/Checkers/VLASizeChecker.cpp
  clang/lib/StaticAnalyzer/Core/CommonBugCategories.cpp
  clang/test/Analysis/taint-diagnostic-visitor.c
  clang/test/Analysis/taint-tester.c

Index: clang/test/Analysis/taint-tester.c
===================================================================
--- clang/test/Analysis/taint-tester.c
+++ clang/test/Analysis/taint-tester.c
@@ -122,7 +122,7 @@
   fscanf(pp, "%d", &ii);
   int jj = ii;// expected-warning + {{tainted}}
 
-  fscanf(p, "%d", &ii);
+  fscanf(p, "%d", &ii);// expected-warning + {{tainted}}
   int jj2 = ii;// expected-warning + {{tainted}}
 
   ii = 3;
Index: clang/test/Analysis/taint-diagnostic-visitor.c
===================================================================
--- clang/test/Analysis/taint-diagnostic-visitor.c
+++ clang/test/Analysis/taint-diagnostic-visitor.c
@@ -2,13 +2,24 @@
 
 // This file is for testing enhanced diagnostics produced by the GenericTaintChecker
 
+typedef __typeof(sizeof(int)) size_t;
+struct _IO_FILE;
+typedef struct _IO_FILE FILE;
+
 int scanf(const char *restrict format, ...);
 int system(const char *command);
+char* getenv( const char* env_var );
+size_t strlen( const char* str );
+void *malloc(size_t size );
+void free( void *ptr );
+char *fgets(char *str, int n, FILE *stream);
+FILE *stdin;
 
 void taintDiagnostic(void)
 {
   char buf[128];
   scanf("%s", buf); // expected-note {{Taint originated here}}
+                    // expected-note@-1 {{Taint propagated to the 2nd argument}}
   system(buf); // expected-warning {{Untrusted data is passed to a system call}} // expected-note {{Untrusted data is passed to a system call (CERT/STR02-C. Sanitize data passed to complex subsystems)}}
 }
 
@@ -16,6 +27,7 @@
   int index;
   int Array[] = {1, 2, 3, 4, 5};
   scanf("%d", &index); // expected-note {{Taint originated here}}
+                       // expected-note@-1 {{Taint propagated to the 2nd argument}}
   return Array[index]; // expected-warning {{Out of bound memory access (index is tainted)}}
                        // expected-note@-1 {{Out of bound memory access (index is tainted)}}
 }
@@ -23,6 +35,7 @@
 int taintDiagnosticDivZero(int operand) {
   scanf("%d", &operand); // expected-note {{Value assigned to 'operand'}}
                          // expected-note@-1 {{Taint originated here}}
+                         // expected-note@-2 {{Taint propagated to the 2nd argument}}
   return 10 / operand; // expected-warning {{Division by a tainted value, possibly zero}}
                        // expected-note@-1 {{Division by a tainted value, possibly zero}}
 }
@@ -31,6 +44,71 @@
   int x;
   scanf("%d", &x); // expected-note {{Value assigned to 'x'}}
                    // expected-note@-1 {{Taint originated here}}
+                   // expected-note@-2 {{Taint propagated to the 2nd argument}}
   int vla[x]; // expected-warning {{Declared variable-length array (VLA) has tainted size}}
               // expected-note@-1 {{Declared variable-length array (VLA) has tainted size}}
 }
+
+
+// Tests that the taint origin note is placed correctly even if the taint
+// propagates through variables and expressions.
+char* taintDiagnosticPropagation(){
+  char *pathbuf;
+  char *pathlist=getenv("PATH"); // expected-note {{Taint originated here}}
+                                 // expected-note@-1 {{Taint propagated to the return value}}
+  if (pathlist){ // expected-note {{Assuming 'pathlist' is non-null}}
+                 // expected-note@-1 {{Taking true branch}}
+    pathbuf=(char*) malloc(strlen(pathlist)+1); // expected-warning{{Untrusted data is used to specify the buffer size}}
+                                                // expected-note@-1{{Untrusted data is used to specify the buffer size}}
+                                                // expected-note@-2 {{Taint propagated to the return value}}
+    return pathbuf;
+  }
+  return 0;
+}
+
+// The taint origin should be marked correctly even if there are multiple
+// taint sources in the function.
+char* taintDiagnosticPropagation2(){
+  char *pathbuf;
+  char *user_env2=getenv("USER_ENV_VAR2"); // unrelated taint source
+  char *pathlist=getenv("PATH"); // expected-note {{Taint originated here}}
+                                 // expected-note@-1 {{Taint propagated to the return value}}
+  char *user_env=getenv("USER_ENV_VAR"); // unrelated taint source
+  if (pathlist){ // expected-note {{Assuming 'pathlist' is non-null}}
+                 // expected-note@-1 {{Taking true branch}}
+    pathbuf=(char*) malloc(strlen(pathlist)+1); // expected-warning{{Untrusted data is used to specify the buffer size}}
+                                                // expected-note@-1{{Untrusted data is used to specify the buffer size}}
+                                                // expected-note@-2 {{Taint propagated to the return value}}
+    return pathbuf;
+  }
+  return 0;
+}
+
+void testReadStdIn(){
+  char buf[1024];
+  fgets(buf, sizeof(buf), stdin);// expected-note {{Taint originated here}}
+                                 // expected-note@-1 {{Taint propagated to the 1st argument}}
+  system(buf);// expected-warning {{Untrusted data is passed to a system call}} // expected-note {{Untrusted data is passed to a system call (CERT/STR02-C. Sanitize data passed to complex subsystems)}}
+
+}
+
+void multipleTaintSources(void) {
+  int x,y,z;
+  scanf("%d", &x); // expected-note {{Taint originated here}}
+                   // expected-note@-1 {{Taint propagated to the 2nd argument}}
+  scanf("%d", &y); // expected-note {{Taint originated here}}
+                   // expected-note@-1 {{Taint propagated to the 2nd argument}}
+  scanf("%d", &z);
+  int* ptr = (int*) malloc(y + x); // expected-warning {{Untrusted data is used to specify the buffer size}}
+                                   // expected-note@-1{{Untrusted data is used to specify the buffer size}}
+  free (ptr);
+}
+
+void multipleTaintedArgs(void) {
+  int x,y;
+  scanf("%d %d", &x, &y); // expected-note {{Taint originated here}}
+                          // expected-note@-1 {{Taint propagated to the 2nd argument, 3rd argument}}
+  int* ptr = (int*) malloc(x + y); // expected-warning {{Untrusted data is used to specify the buffer size}}
+                                   // expected-note@-1{{Untrusted data is used to specify the buffer size}}
+  free (ptr);
+}
Index: clang/lib/StaticAnalyzer/Core/CommonBugCategories.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/CommonBugCategories.cpp
+++ clang/lib/StaticAnalyzer/Core/CommonBugCategories.cpp
@@ -23,6 +23,7 @@
 const char *const CXXMoveSemantics = "C++ move semantics";
 const char *const SecurityError = "Security error";
 const char *const UnusedCode = "Unused code";
+const char *const TaintedData = "Tainted data used";
 } // namespace categories
 } // namespace ento
 } // namespace clang
Index: clang/lib/StaticAnalyzer/Checkers/VLASizeChecker.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/VLASizeChecker.cpp
+++ clang/lib/StaticAnalyzer/Checkers/VLASizeChecker.cpp
@@ -35,13 +35,8 @@
     : public Checker<check::PreStmt<DeclStmt>,
                      check::PreStmt<UnaryExprOrTypeTraitExpr>> {
   mutable std::unique_ptr<BugType> BT;
-  enum VLASize_Kind {
-    VLA_Garbage,
-    VLA_Zero,
-    VLA_Tainted,
-    VLA_Negative,
-    VLA_Overflow
-  };
+  mutable std::unique_ptr<BugType> TaintBT;
+  enum VLASize_Kind { VLA_Garbage, VLA_Zero, VLA_Negative, VLA_Overflow };
 
   /// Check a VLA for validity.
   /// Every dimension of the array and the total size is checked for validity.
@@ -55,8 +50,10 @@
                                     const Expr *SizeE) const;
 
   void reportBug(VLASize_Kind Kind, const Expr *SizeE, ProgramStateRef State,
-                 CheckerContext &C,
-                 std::unique_ptr<BugReporterVisitor> Visitor = nullptr) const;
+                 CheckerContext &C) const;
+
+  void reportTaintBug(const Expr *SizeE, ProgramStateRef State,
+                      CheckerContext &C, SVal TaintedSVal) const;
 
 public:
   void checkPreStmt(const DeclStmt *DS, CheckerContext &C) const;
@@ -167,8 +164,7 @@
 
   // Check if the size is tainted.
   if (isTainted(State, SizeV)) {
-    reportBug(VLA_Tainted, SizeE, nullptr, C,
-              std::make_unique<TaintBugVisitor>(SizeV));
+    reportTaintBug(SizeE, State, C, SizeV);
     return nullptr;
   }
 
@@ -209,17 +205,45 @@
   return State;
 }
 
-void VLASizeChecker::reportBug(
-    VLASize_Kind Kind, const Expr *SizeE, ProgramStateRef State,
-    CheckerContext &C, std::unique_ptr<BugReporterVisitor> Visitor) const {
+void VLASizeChecker::reportTaintBug(const Expr *SizeE, ProgramStateRef State,
+                                    CheckerContext &C, SVal TaintedSVal) const {
+  // A tainted VLA size is reported on a sink (error) node.
+  ExplodedNode *N = C.generateErrorNode(State);
+  if (!N)
+    return;
+
+  if (!TaintBT)
+    TaintBT = std::make_unique<BugType>(
+        this, "Dangerous variable-length array (VLA) declaration",
+        categories::TaintedData);
+
+  // Unlike reportBug(), there is a single fixed message here, so no stream
+  // is needed to assemble it.
+  const StringRef Msg =
+      "Declared variable-length array (VLA) has tainted size";
+
+  auto Report = std::make_unique<PathSensitiveBugReport>(*TaintBT, Msg, N);
+  Report->addRange(SizeE->getSourceRange());
+  bugreporter::trackExpressionValue(N, SizeE, *Report);
+  // The VLA size may be a complex expression in which several symbols are
+  // tainted. Each of them is marked interesting, presumably so that taint
+  // note tags elsewhere can follow them back to their origin.
+  for (SymbolRef Sym : getTaintedSymbols(State, TaintedSVal))
+    Report->markInteresting(Sym);
+  C.emitReport(std::move(Report));
+}
+
+void VLASizeChecker::reportBug(VLASize_Kind Kind, const Expr *SizeE,
+                               ProgramStateRef State, CheckerContext &C) const {
   // Generate an error node.
   ExplodedNode *N = C.generateErrorNode(State);
   if (!N)
     return;
 
   if (!BT)
-    BT.reset(new BuiltinBug(
-        this, "Dangerous variable-length array (VLA) declaration"));
+    BT.reset(new BugType(this,
+                         "Dangerous variable-length array (VLA) declaration",
+                         categories::LogicError));
 
   SmallString<256> buf;
   llvm::raw_svector_ostream os(buf);
@@ -231,9 +255,6 @@
   case VLA_Zero:
     os << "has zero size";
     break;
-  case VLA_Tainted:
-    os << "has tainted size";
-    break;
   case VLA_Negative:
     os << "has negative size";
     break;
@@ -243,7 +264,6 @@
   }
 
   auto report = std::make_unique<PathSensitiveBugReport>(*BT, os.str(), N);
-  report->addVisitor(std::move(Visitor));
   report->addRange(SizeE->getSourceRange());
   bugreporter::trackExpressionValue(N, SizeE, *report);
   C.emitReport(std::move(report));
Index: clang/lib/StaticAnalyzer/Checkers/Taint.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/Taint.cpp
+++ clang/lib/StaticAnalyzer/Checkers/Taint.cpp
@@ -146,41 +146,93 @@
 
 bool taint::isTainted(ProgramStateRef State, const Stmt *S,
                       const LocationContext *LCtx, TaintTagType Kind) {
-  SVal val = State->getSVal(S, LCtx);
-  return isTainted(State, val, Kind);
+  return !getTaintedSymbols(State, S, LCtx, Kind, true).empty();
 }
 
 bool taint::isTainted(ProgramStateRef State, SVal V, TaintTagType Kind) {
-  if (SymbolRef Sym = V.getAsSymbol())
-    return isTainted(State, Sym, Kind);
-  if (const MemRegion *Reg = V.getAsRegion())
-    return isTainted(State, Reg, Kind);
-  return false;
+  return !getTaintedSymbols(State, V, Kind, /*returnFirstOnly=*/true).empty();
 }
 
 bool taint::isTainted(ProgramStateRef State, const MemRegion *Reg,
                       TaintTagType K) {
-  if (!Reg)
-    return false;
+  return !getTaintedSymbols(State, Reg, K, /*returnFirstOnly=*/true).empty();
+}
+
+bool taint::isTainted(ProgramStateRef State, SymbolRef Sym, TaintTagType Kind) {
+  return !getTaintedSymbols(State, Sym, Kind, /*returnFirstOnly=*/true).empty();
+}
+
+std::vector<SymbolRef> taint::getTaintedSymbols(ProgramStateRef State,
+                                                const Stmt *S,
+                                                const LocationContext *LCtx,
+                                                TaintTagType Kind,
+                                                bool returnFirstOnly) {
+  SVal val = State->getSVal(S, LCtx); // forward the early-exit flag below
+  return getTaintedSymbols(State, val, Kind, returnFirstOnly);
+}
+
+std::vector<SymbolRef> taint::getTaintedSymbols(ProgramStateRef State, SVal V,
+                                                TaintTagType Kind,
+                                                bool returnFirstOnly) {
+  if (SymbolRef Sym = V.getAsSymbol())
+    return getTaintedSymbols(State, Sym, Kind, returnFirstOnly);
+  if (const MemRegion *Reg = V.getAsRegion())
+    return getTaintedSymbols(State, Reg, Kind, returnFirstOnly);
+  return {};
+}
 
+std::vector<SymbolRef> taint::getTaintedSymbols(ProgramStateRef State,
+                                                const MemRegion *Reg,
+                                                TaintTagType K,
+                                                bool returnFirstOnly) {
+  std::vector<SymbolRef> TaintedSymbols;
+  if (!Reg)
+    return TaintedSymbols;
   // Element region (array element) is tainted if either the base or the offset
   // are tainted.
-  if (const ElementRegion *ER = dyn_cast<ElementRegion>(Reg))
-    return isTainted(State, ER->getSuperRegion(), K) ||
-           isTainted(State, ER->getIndex(), K);
+  if (const ElementRegion *ER = dyn_cast<ElementRegion>(Reg)) {
+    std::vector<SymbolRef> TaintedIndex =
+        getTaintedSymbols(State, ER->getIndex(), K, returnFirstOnly);
+    TaintedSymbols.insert(TaintedSymbols.begin(), TaintedIndex.begin(),
+                          TaintedIndex.end());
+    if (returnFirstOnly && !TaintedSymbols.empty())
+      return TaintedSymbols; // return early if needed
+    std::vector<SymbolRef> TaintedSuperRegion =
+        getTaintedSymbols(State, ER->getSuperRegion(), K, returnFirstOnly);
+    TaintedSymbols.insert(TaintedSymbols.begin(), TaintedSuperRegion.begin(),
+                          TaintedSuperRegion.end());
+    // Done: the generic SubRegion case below would re-query the super region.
+    return TaintedSymbols;
+  }
 
-  if (const SymbolicRegion *SR = dyn_cast<SymbolicRegion>(Reg))
-    return isTainted(State, SR->getSymbol(), K);
+  if (const SymbolicRegion *SR = dyn_cast<SymbolicRegion>(Reg)) {
+    std::vector<SymbolRef> TaintedRegions =
+        getTaintedSymbols(State, SR->getSymbol(), K, returnFirstOnly);
+    TaintedSymbols.insert(TaintedSymbols.begin(), TaintedRegions.begin(),
+                          TaintedRegions.end());
+    // A symbolic region is tainted exactly when its symbol is.
+    return TaintedSymbols;
+  }
 
-  if (const SubRegion *ER = dyn_cast<SubRegion>(Reg))
-    return isTainted(State, ER->getSuperRegion(), K);
+  if (const SubRegion *ER = dyn_cast<SubRegion>(Reg)) {
+    std::vector<SymbolRef> TaintedSubRegions =
+        getTaintedSymbols(State, ER->getSuperRegion(), K, returnFirstOnly);
+    TaintedSymbols.insert(TaintedSymbols.begin(), TaintedSubRegions.begin(),
+                          TaintedSubRegions.end());
+    if (returnFirstOnly && !TaintedSymbols.empty())
+      return TaintedSymbols; // return early if needed
+  }
 
-  return false;
+  return TaintedSymbols;
 }
 
-bool taint::isTainted(ProgramStateRef State, SymbolRef Sym, TaintTagType Kind) {
+std::vector<SymbolRef> taint::getTaintedSymbols(ProgramStateRef State,
+                                                SymbolRef Sym,
+                                                TaintTagType Kind,
+                                                bool returnFirstOnly) {
+  std::vector<SymbolRef> TaintedSymbols;
   if (!Sym)
-    return false;
+    return TaintedSymbols;
 
   // Traverse all the symbols this symbol depends on to see if any are tainted.
   for (SymExpr::symbol_iterator SI = Sym->symbol_begin(),
@@ -190,18 +242,25 @@
       continue;
 
     if (const TaintTagType *Tag = State->get<TaintMap>(*SI)) {
-      if (*Tag == Kind)
-        return true;
+      if (*Tag == Kind) {
+        TaintedSymbols.push_back(*SI);
+        if (returnFirstOnly)
+          return TaintedSymbols; // return early if needed
+      }
     }
 
     if (const auto *SD = dyn_cast<SymbolDerived>(*SI)) {
       // If this is a SymbolDerived with a tainted parent, it's also tainted.
-      if (isTainted(State, SD->getParentSymbol(), Kind))
-        return true;
+      std::vector<SymbolRef> TaintedParents = getTaintedSymbols(
+          State, SD->getParentSymbol(), Kind, returnFirstOnly);
+      TaintedSymbols.insert(TaintedSymbols.begin(), TaintedParents.begin(),
+                            TaintedParents.end());
+      if (returnFirstOnly && !TaintedSymbols.empty())
+        return TaintedSymbols; // return early if needed
 
       // If this is a SymbolDerived with the same parent symbol as another
-      // tainted SymbolDerived and a region that's a sub-region of that tainted
-      // symbol, it's also tainted.
+      // tainted SymbolDerived and a region that's a sub-region of that
+      // tainted symbol, it's also tainted.
       if (const TaintedSubRegions *Regs =
               State->get<DerivedSymTaint>(SD->getParentSymbol())) {
         const TypedValueRegion *R = SD->getRegion();
@@ -210,46 +269,34 @@
           // complete. For example, this would not currently identify
           // overlapping fields in a union as tainted. To identify this we can
           // check for overlapping/nested byte offsets.
-          if (Kind == I.second && R->isSubRegionOf(I.first))
-            return true;
+          if (Kind == I.second && R->isSubRegionOf(I.first)) {
+            TaintedSymbols.push_back(SD->getParentSymbol());
+            if (returnFirstOnly)
+              return TaintedSymbols; // return early if needed
+          }
         }
       }
     }
 
     // If memory region is tainted, data is also tainted.
     if (const auto *SRV = dyn_cast<SymbolRegionValue>(*SI)) {
-      if (isTainted(State, SRV->getRegion(), Kind))
-        return true;
+      std::vector<SymbolRef> TaintedRegions =
+          getTaintedSymbols(State, SRV->getRegion(), Kind, returnFirstOnly);
+      TaintedSymbols.insert(TaintedSymbols.begin(), TaintedRegions.begin(),
+                            TaintedRegions.end());
+      if (returnFirstOnly && !TaintedSymbols.empty())
+        return TaintedSymbols; // return early if needed
     }
 
     // If this is a SymbolCast from a tainted value, it's also tainted.
     if (const auto *SC = dyn_cast<SymbolCast>(*SI)) {
-      if (isTainted(State, SC->getOperand(), Kind))
-        return true;
+      std::vector<SymbolRef> TaintedCasts =
+          getTaintedSymbols(State, SC->getOperand(), Kind, returnFirstOnly);
+      TaintedSymbols.insert(TaintedSymbols.begin(), TaintedCasts.begin(),
+                            TaintedCasts.end());
+      if (returnFirstOnly && !TaintedSymbols.empty())
+        return TaintedSymbols; // return early if needed
     }
   }
-
-  return false;
-}
-
-PathDiagnosticPieceRef TaintBugVisitor::VisitNode(const ExplodedNode *N,
-                                                  BugReporterContext &BRC,
-                                                  PathSensitiveBugReport &BR) {
-
-  // Find the ExplodedNode where the taint was first introduced
-  if (!isTainted(N->getState(), V) ||
-      isTainted(N->getFirstPred()->getState(), V))
-    return nullptr;
-
-  const Stmt *S = N->getStmtForDiagnostics();
-  if (!S)
-    return nullptr;
-
-  const LocationContext *NCtx = N->getLocationContext();
-  PathDiagnosticLocation L =
-      PathDiagnosticLocation::createBegin(S, BRC.getSourceManager(), NCtx);
-  if (!L.isValid() || !L.asLocation().isValid())
-    return nullptr;
-
-  return std::make_shared<PathDiagnosticEventPiece>(L, "Taint originated here");
+  return TaintedSymbols;
 }
Index: clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
+++ clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
@@ -26,12 +26,14 @@
 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/ProgramStateTrait.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/YAMLTraits.h"
 
 #include <limits>
 #include <memory>
 #include <optional>
 #include <utility>
+#include <vector>
 
 #define DEBUG_TYPE "taint-checker"
 
@@ -114,47 +116,110 @@
   return false;
 }
 
-SVal getPointeeOf(const CheckerContext &C, Loc LValue) {
-  const QualType ArgTy = LValue.getType(C.getASTContext());
+SVal getPointeeOf(ProgramStateRef State, Loc LValue) {
+  const QualType ArgTy = LValue.getType(State->getStateManager().getContext());
   if (!ArgTy->isPointerType() || !ArgTy->getPointeeType()->isVoidType())
-    return C.getState()->getSVal(LValue);
+    return State->getSVal(LValue);
 
   // Do not dereference void pointers. Treat them as byte pointers instead.
   // FIXME: we might want to consider more than just the first byte.
-  return C.getState()->getSVal(LValue, C.getASTContext().CharTy);
+  return State->getSVal(LValue, State->getStateManager().getContext().CharTy);
 }
 
 /// Given a pointer/reference argument, return the value it refers to.
-std::optional<SVal> getPointeeOf(const CheckerContext &C, SVal Arg) {
+std::optional<SVal> getPointeeOf(ProgramStateRef State, SVal Arg) {
   if (auto LValue = Arg.getAs<Loc>())
-    return getPointeeOf(C, *LValue);
+    return getPointeeOf(State, *LValue);
   return std::nullopt;
 }
 
-/// Given a pointer, return the SVal of its pointee or if it is tainted,
+/// Given a pointer, return the SVal of its pointee if it is tainted,
 /// otherwise return the pointer's SVal if tainted.
-/// Also considers stdin as a taint source.
+/// Stdin is not special-cased here; it is returned only if it is tainted.
-std::optional<SVal> getTaintedPointeeOrPointer(const CheckerContext &C,
+std::optional<SVal> getTaintedPointeeOrPointer(ProgramStateRef State,
                                                SVal Arg) {
-  const ProgramStateRef State = C.getState();
-
-  if (auto Pointee = getPointeeOf(C, Arg))
+  if (auto Pointee = getPointeeOf(State, Arg))
     if (isTainted(State, *Pointee)) // FIXME: isTainted(...) ? Pointee : None;
       return Pointee;
 
   if (isTainted(State, Arg))
     return Arg;
+  return std::nullopt;
+}
 
-  // FIXME: This should be done by the isTainted() API.
-  if (isStdin(Arg, C.getASTContext()))
-    return Arg;
+bool isTaintedOrPointsToTainted(ProgramStateRef State, SVal ExprSVal) {
+  return getTaintedPointeeOrPointer(State, ExprSVal).has_value();
+}
 
-  return std::nullopt;
+/// Helps in printing taint diagnostics.
+/// If the call is interesting in a taint report, marks the tainted symbols of
+/// its arguments interesting, or (if there are none) names it the taint origin.
+const NoteTag *taintOriginTrackerTag(CheckerContext &C,
+                                     std::vector<SymbolRef> TaintedSymbols,
+                                     std::vector<ArgIdxTy> TaintedArgs,
+                                     const LocationContext *CallLocation) {
+  return C.getNoteTag([TaintedSymbols = std::move(TaintedSymbols),
+                       TaintedArgs = std::move(TaintedArgs), CallLocation](
+                          PathSensitiveBugReport &BR) -> std::string {
+    // Only "Taint originated here" is ever emitted by this tag.
+    // We give diagnostics only for taint related reports
+    if (!BR.isInteresting(CallLocation) ||
+        BR.getBugType().getCategory() != categories::TaintedData) {
+      return "";
+    }
+    if (TaintedSymbols.empty())
+      return "Taint originated here";
+
+    for (auto Sym : TaintedSymbols) {
+      BR.markInteresting(Sym);
+    }
+    for (auto Arg : TaintedArgs) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Taint Propagated from argument " << Arg + 1 << "\n");
+    }
+    return "";
+  });
 }
 
-bool isTaintedOrPointsToTainted(const Expr *E, const ProgramStateRef &State,
-                                CheckerContext &C) {
-  return getTaintedPointeeOrPointer(C, C.getSVal(E)).has_value();
+/// Helps in printing taint diagnostics.
+/// If a symbol tainted by this call is interesting in a taint report, marks
+/// the call interesting and notes where the taint was propagated to.
+const NoteTag *taintPropagationExplainerTag(
+    CheckerContext &C, std::vector<SymbolRef> TaintedSymbols,
+    std::vector<ArgIdxTy> TaintedArgs, const LocationContext *CallLocation) {
+  assert(TaintedSymbols.size() == TaintedArgs.size());
+  return C.getNoteTag([TaintedSymbols = std::move(TaintedSymbols),
+                       TaintedArgs = std::move(TaintedArgs), CallLocation](
+                          PathSensitiveBugReport &BR) -> std::string {
+    SmallString<256> Msg;
+    llvm::raw_svector_ostream Out(Msg);
+    // We give diagnostics only for taint related reports
+    if (TaintedSymbols.empty() ||
+        BR.getBugType().getCategory() != categories::TaintedData) {
+      return "";
+    }
+    // Number of destinations (arguments or the return value) already written
+    // to Out. The first one is prefixed with "Taint propagated to the ", every
+    // later one with ", ", also when the return value is among them.
+    int nofTaintedArgs = 0;
+    for (auto [Idx, Sym] : llvm::enumerate(TaintedSymbols)) {
+      if (BR.isInteresting(Sym)) {
+        BR.markInteresting(CallLocation);
+        Out << (nofTaintedArgs == 0 ? "Taint propagated to the " : ", ");
+        ++nofTaintedArgs;
+        if (TaintedArgs[Idx] != ReturnValueIndex) {
+          LLVM_DEBUG(llvm::dbgs() << "Taint Propagated to argument "
+                                  << TaintedArgs[Idx] + 1 << "\n");
+          Out << TaintedArgs[Idx] + 1
+              << llvm::getOrdinalSuffix(TaintedArgs[Idx] + 1) << " argument";
+        } else {
+          LLVM_DEBUG(llvm::dbgs() << "Taint Propagated to return value.\n");
+          Out << "return value";
+        }
+      }
+    }
+    return std::string(Out.str());
+  });
 }
 
 /// ArgSet is used to describe arguments relevant for taint detection or
@@ -193,7 +258,7 @@
   ArgSet SinkArgs;
   /// Arguments which should be sanitized on function return.
   ArgSet FilterArgs;
-  /// Arguments which can participate in taint propagationa. If any of the
+  /// Arguments which can participate in taint propagation. If any of the
   /// arguments in PropSrcArgs is tainted, all arguments in  PropDstArgs should
   /// be tainted.
   ArgSet PropSrcArgs;
@@ -343,7 +408,7 @@
                                CheckerContext &C) const;
 
 private:
-  const BugType BT{this, "Use of Untrusted Data", "Untrusted Data"};
+  const BugType BT{this, "Use of Untrusted Data", categories::TaintedData};
 
   bool checkUncontrolledFormatString(const CallEvent &Call,
                                      CheckerContext &C) const;
@@ -351,7 +416,7 @@
   void taintUnsafeSocketProtocol(const CallEvent &Call,
                                  CheckerContext &C) const;
 
-  /// Default taint rules are initilized with the help of a CheckerContext to
+  /// Default taint rules are initialized with the help of a CheckerContext to
   /// access the names of built-in functions like memcpy.
   void initTaintRules(CheckerContext &C) const;
 
@@ -788,22 +853,39 @@
     llvm::dbgs() << "> actually wants to taint arg index: " << I << '\n';
   });
 
+  const NoteTag *InjectionTag = nullptr;
+  std::vector<SymbolRef> TaintedSymbols;
+  std::vector<ArgIdxTy> TaintedIndexes;
   for (ArgIdxTy ArgNum : *TaintArgs) {
     // Special handling for the tainted return value.
     if (ArgNum == ReturnValueIndex) {
       State = addTaint(State, Call.getReturnValue());
+      std::vector<SymbolRef> TaintedSyms =
+          getTaintedSymbols(State, Call.getReturnValue());
+      if (!TaintedSyms.empty()) {
+        TaintedSymbols.push_back(TaintedSyms[0]);
+        TaintedIndexes.push_back(ArgNum);
+      }
       continue;
     }
-
     // The arguments are pointer arguments. The data they are pointing at is
     // tainted after the call.
-    if (auto V = getPointeeOf(C, Call.getArgSVal(ArgNum)))
+    if (auto V = getPointeeOf(State, Call.getArgSVal(ArgNum))) {
       State = addTaint(State, *V);
+      std::vector<SymbolRef> TaintedSyms = getTaintedSymbols(State, *V);
+      if (!TaintedSyms.empty()) {
+        TaintedSymbols.push_back(TaintedSyms[0]);
+        TaintedIndexes.push_back(ArgNum);
+      }
+    }
   }
-
+  // Create a NoteTag callback, which prints to the user where the taintedness
+  // was propagated to.
+  InjectionTag = taintPropagationExplainerTag(C, TaintedSymbols, TaintedIndexes,
+                                              Call.getCalleeStackFrame(0));
   // Clear up the taint info from the state.
   State = State->remove<TaintArgsOnPostVisit>(CurrentFrame);
-  C.addTransition(State);
+  C.addTransition(State, InjectionTag);
 }
 
 void GenericTaintChecker::printState(raw_ostream &Out, ProgramStateRef State,
@@ -826,7 +908,11 @@
 
   /// Check for taint sinks.
   ForEachCallArg([this, &Checker, &C, &State](ArgIdxTy I, const Expr *E, SVal) {
-    if (SinkArgs.contains(I) && isTaintedOrPointsToTainted(E, State, C))
+    // Add taintedness to stdin parameters
+    if (isStdin(C.getSVal(E), C.getASTContext())) {
+      State = addTaint(State, C.getSVal(E));
+    }
+    if (SinkArgs.contains(I) && isTaintedOrPointsToTainted(State, C.getSVal(E)))
       Checker.generateReportIfTainted(E, SinkMsg.value_or(MsgCustomSink), C);
   });
 
@@ -834,7 +920,7 @@
   ForEachCallArg([this, &C, &State](ArgIdxTy I, const Expr *E, SVal S) {
     if (FilterArgs.contains(I)) {
       State = removeTaint(State, S);
-      if (auto P = getPointeeOf(C, S))
+      if (auto P = getPointeeOf(State, S))
         State = removeTaint(State, *P);
     }
   });
@@ -843,11 +929,27 @@
   /// A rule is relevant if PropSrcArgs is empty, or if any of its signified
   /// args are tainted in context of the current CallEvent.
   bool IsMatching = PropSrcArgs.isEmpty();
-  ForEachCallArg(
-      [this, &C, &IsMatching, &State](ArgIdxTy I, const Expr *E, SVal) {
-        IsMatching = IsMatching || (PropSrcArgs.contains(I) &&
-                                    isTaintedOrPointsToTainted(E, State, C));
-      });
+  std::vector<SymbolRef> TaintedSymbols;
+  std::vector<ArgIdxTy> TaintedIndexes;
+  ForEachCallArg([this, &C, &IsMatching, &State, &TaintedSymbols,
+                  &TaintedIndexes](ArgIdxTy I, const Expr *E, SVal) {
+    IsMatching =
+        IsMatching || (PropSrcArgs.contains(I) &&
+                       isTaintedOrPointsToTainted(State, C.getSVal(E)));
+    std::optional<SVal> TaintedSVal =
+        getTaintedPointeeOrPointer(State, C.getSVal(E));
+
+    // We track back tainted arguments except for stdin
+    if (TaintedSVal && !isStdin(*TaintedSVal, C.getASTContext())) {
+      std::vector<SymbolRef> TaintedArgSyms =
+          getTaintedSymbols(State, *TaintedSVal);
+      if (!TaintedArgSyms.empty()) {
+        TaintedSymbols.insert(TaintedSymbols.begin(), TaintedArgSyms.begin(),
+                              TaintedArgSyms.end());
+        TaintedIndexes.push_back(I);
+      }
+    }
+  });
 
   if (!IsMatching)
     return;
@@ -890,7 +992,9 @@
 
   if (!Result.isEmpty())
     State = State->set<TaintArgsOnPostVisit>(C.getStackFrame(), Result);
-  C.addTransition(State);
+  const NoteTag *InjectionTag = taintOriginTrackerTag(
+      C, TaintedSymbols, TaintedIndexes, Call.getCalleeStackFrame(0));
+  C.addTransition(State, InjectionTag);
 }
 
 bool GenericTaintRule::UntrustedEnv(CheckerContext &C) {
@@ -902,7 +1006,8 @@
 bool GenericTaintChecker::generateReportIfTainted(const Expr *E, StringRef Msg,
                                                   CheckerContext &C) const {
   assert(E);
-  std::optional<SVal> TaintedSVal{getTaintedPointeeOrPointer(C, C.getSVal(E))};
+  std::optional<SVal> TaintedSVal{
+      getTaintedPointeeOrPointer(C.getState(), C.getSVal(E))};
 
   if (!TaintedSVal)
     return false;
@@ -911,7 +1016,12 @@
   if (ExplodedNode *N = C.generateNonFatalErrorNode()) {
     auto report = std::make_unique<PathSensitiveBugReport>(BT, Msg, N);
     report->addRange(E->getSourceRange());
-    report->addVisitor(std::make_unique<TaintBugVisitor>(*TaintedSVal));
+    std::vector<SymbolRef> TaintedSyms =
+        getTaintedSymbols(C.getState(), *TaintedSVal);
+    for (auto TaintedSym : TaintedSyms) {
+      report->markInteresting(TaintedSym);
+    }
+
     C.emitReport(std::move(report));
     return true;
   }
Index: clang/lib/StaticAnalyzer/Checkers/DivZeroChecker.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/DivZeroChecker.cpp
+++ clang/lib/StaticAnalyzer/Checkers/DivZeroChecker.cpp
@@ -25,9 +25,13 @@
 
 namespace {
 class DivZeroChecker : public Checker< check::PreStmt<BinaryOperator> > {
-  mutable std::unique_ptr<BuiltinBug> BT;
-  void reportBug(const char *Msg, ProgramStateRef StateZero, CheckerContext &C,
-                 std::unique_ptr<BugReporterVisitor> Visitor = nullptr) const;
+  mutable std::unique_ptr<BugType> BT;
+  mutable std::unique_ptr<BugType> TaintBT;
+  void reportBug(StringRef Msg, ProgramStateRef StateZero,
+                 CheckerContext &C) const;
+  void reportTaintBug(StringRef Msg, ProgramStateRef StateZero,
+                      CheckerContext &C,
+                      std::vector<SymbolRef> TaintedSyms) const;
 
 public:
   void checkPreStmt(const BinaryOperator *B, CheckerContext &C) const;
@@ -41,20 +45,34 @@
   return nullptr;
 }
 
-void DivZeroChecker::reportBug(
-    const char *Msg, ProgramStateRef StateZero, CheckerContext &C,
-    std::unique_ptr<BugReporterVisitor> Visitor) const {
+void DivZeroChecker::reportBug(StringRef Msg, ProgramStateRef StateZero,
+                               CheckerContext &C) const {
   if (ExplodedNode *N = C.generateErrorNode(StateZero)) {
     if (!BT)
-      BT.reset(new BuiltinBug(this, "Division by zero"));
+      BT.reset(new BugType(this, "Division by zero", categories::LogicError));
 
     auto R = std::make_unique<PathSensitiveBugReport>(*BT, Msg, N);
-    R->addVisitor(std::move(Visitor));
     bugreporter::trackExpressionValue(N, getDenomExpr(N), *R);
     C.emitReport(std::move(R));
   }
 }
 
+void DivZeroChecker::reportTaintBug(StringRef Msg, ProgramStateRef StateZero,
+                                    CheckerContext &C,
+                                    std::vector<SymbolRef> TaintedSyms) const {
+  ExplodedNode *N = C.generateErrorNode(StateZero);
+  if (!N)
+    return;
+  if (!TaintBT) // The tainted-data bug type is created on first use.
+    TaintBT.reset(
+        new BugType(this, "Division by zero", categories::TaintedData));
+  auto R = std::make_unique<PathSensitiveBugReport>(*TaintBT, Msg, N);
+  bugreporter::trackExpressionValue(N, getDenomExpr(N), *R);
+  for (SymbolRef Sym : TaintedSyms) // Lets path notes trace the taint source.
+    R->markInteresting(Sym);
+  C.emitReport(std::move(R));
+}
+
 void DivZeroChecker::checkPreStmt(const BinaryOperator *B,
                                   CheckerContext &C) const {
   BinaryOperator::Opcode Op = B->getOpcode();
@@ -86,11 +104,13 @@
     return;
   }
 
-  bool TaintedD = isTainted(C.getState(), *DV);
-  if ((stateNotZero && stateZero && TaintedD)) {
-    reportBug("Division by a tainted value, possibly zero", stateZero, C,
-              std::make_unique<taint::TaintBugVisitor>(*DV));
-    return;
+  if ((stateNotZero && stateZero)) {
+    std::vector<SymbolRef> taintedSyms = getTaintedSymbols(C.getState(), *DV);
+    if (!taintedSyms.empty()) {
+      reportTaintBug("Division by a tainted value, possibly zero", stateZero, C,
+                     taintedSyms);
+      return;
+    }
   }
 
   // If we get here, then the denom should not be zero. We abandon the implicit
Index: clang/lib/StaticAnalyzer/Checkers/ArrayBoundCheckerV2.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/ArrayBoundCheckerV2.cpp
+++ clang/lib/StaticAnalyzer/Checkers/ArrayBoundCheckerV2.cpp
@@ -33,11 +33,14 @@
 class ArrayBoundCheckerV2 :
     public Checker<check::Location> {
   mutable std::unique_ptr<BuiltinBug> BT;
+  mutable std::unique_ptr<BugType> TaintBT;
 
-  enum OOB_Kind { OOB_Precedes, OOB_Excedes, OOB_Tainted };
+  enum OOB_Kind { OOB_Precedes, OOB_Excedes };
 
-  void reportOOB(CheckerContext &C, ProgramStateRef errorState, OOB_Kind kind,
-                 std::unique_ptr<BugReporterVisitor> Visitor = nullptr) const;
+  void reportOOB(CheckerContext &C, ProgramStateRef errorState,
+                 OOB_Kind kind) const;
+  void reportTaintOOB(CheckerContext &C, ProgramStateRef errorState,
+                      SVal TaintedSVal) const;
 
 public:
   void checkLocation(SVal l, bool isLoad, const Stmt*S,
@@ -207,8 +210,7 @@
     if (state_exceedsUpperBound && state_withinUpperBound) {
       SVal ByteOffset = rawOffset.getByteOffset();
       if (isTainted(state, ByteOffset)) {
-        reportOOB(checkerContext, state_exceedsUpperBound, OOB_Tainted,
-                  std::make_unique<TaintBugVisitor>(ByteOffset));
+        reportTaintOOB(checkerContext, state_exceedsUpperBound, ByteOffset);
         return;
       }
     } else if (state_exceedsUpperBound) {
@@ -226,10 +228,37 @@
 
   checkerContext.addTransition(state);
 }
+
+/// Report an out-of-bound memory access whose index (offset) is tainted.
+///
+/// All tainted symbols found in \p TaintedSVal are marked interesting so
+/// that the path diagnostic can show where the taintedness originated.
+void ArrayBoundCheckerV2::reportTaintOOB(CheckerContext &checkerContext,
+                                         ProgramStateRef errorState,
+                                         SVal TaintedSVal) const {
+  ExplodedNode *errorNode = checkerContext.generateErrorNode(errorState);
+  if (!errorNode)
+    return;
+
+  // The tainted-data bug type is created lazily, on first use.
+  if (!TaintBT)
+    TaintBT.reset(
+        new BugType(this, "Out-of-bound access", categories::TaintedData));
 
-void ArrayBoundCheckerV2::reportOOB(
-    CheckerContext &checkerContext, ProgramStateRef errorState, OOB_Kind kind,
-    std::unique_ptr<BugReporterVisitor> Visitor) const {
+  auto BR = std::make_unique<PathSensitiveBugReport>(
+      *TaintBT, "Out of bound memory access (index is tainted)", errorNode);
+
+  // Mark all tainted symbols interesting to track back the propagation of
+  // taintedness.
+  for (SymbolRef Sym : getTaintedSymbols(errorState, TaintedSVal))
+    BR->markInteresting(Sym);
+
+  checkerContext.emitReport(std::move(BR));
+}
+
+void ArrayBoundCheckerV2::reportOOB(CheckerContext &checkerContext,
+                                    ProgramStateRef errorState,
+                                    OOB_Kind kind) const {
 
   ExplodedNode *errorNode = checkerContext.generateErrorNode(errorState);
   if (!errorNode)
@@ -251,13 +280,8 @@
   case OOB_Excedes:
     os << "(access exceeds upper limit of memory block)";
     break;
-  case OOB_Tainted:
-    os << "(index is tainted)";
-    break;
   }
-
   auto BR = std::make_unique<PathSensitiveBugReport>(*BT, os.str(), errorNode);
-  BR->addVisitor(std::move(Visitor));
   checkerContext.emitReport(std::move(BR));
 }
 
Index: clang/include/clang/StaticAnalyzer/Core/BugReporter/CommonBugCategories.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Core/BugReporter/CommonBugCategories.h
+++ clang/include/clang/StaticAnalyzer/Core/BugReporter/CommonBugCategories.h
@@ -22,6 +22,7 @@
 extern const char *const CXXMoveSemantics;
 extern const char *const SecurityError;
 extern const char *const UnusedCode;
+extern const char *const TaintedData;
 } // namespace categories
 } // namespace ento
 } // namespace clang
Index: clang/include/clang/StaticAnalyzer/Checkers/Taint.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Checkers/Taint.h
+++ clang/include/clang/StaticAnalyzer/Checkers/Taint.h
@@ -79,26 +79,33 @@
 bool isTainted(ProgramStateRef State, const MemRegion *Reg,
                TaintTagType Kind = TaintTagGeneric);
 
+/// Returns the tainted Symbols for a given Statement and state.
+std::vector<SymbolRef> getTaintedSymbols(ProgramStateRef State, const Stmt *S,
+                                         const LocationContext *LCtx,
+                                         TaintTagType Kind = TaintTagGeneric,
+                                         bool returnFirstOnly = false);
+
+/// Returns the tainted Symbols for a given SVal and state.
+std::vector<SymbolRef> getTaintedSymbols(ProgramStateRef State, SVal V,
+                                         TaintTagType Kind = TaintTagGeneric,
+                                         bool returnFirstOnly = false);
+
+/// Returns the tainted Symbols for a SymbolRef and state.
+std::vector<SymbolRef> getTaintedSymbols(ProgramStateRef State, SymbolRef Sym,
+                                         TaintTagType Kind = TaintTagGeneric,
+                                         bool returnFirstOnly = false);
+
+/// Returns the tainted (index, super/sub region, symbolic region) symbols
+/// for a given memory region.
+std::vector<SymbolRef> getTaintedSymbols(ProgramStateRef State,
+                                         const MemRegion *Reg,
+                                         TaintTagType Kind = TaintTagGeneric,
+                                         bool returnFirstOnly = false);
+
 void printTaint(ProgramStateRef State, raw_ostream &Out, const char *nl = "\n",
                 const char *sep = "");
 
 LLVM_DUMP_METHOD void dumpTaint(ProgramStateRef State);
-
-/// The bug visitor prints a diagnostic message at the location where a given
-/// variable was tainted.
-class TaintBugVisitor final : public BugReporterVisitor {
-private:
-  const SVal V;
-
-public:
-  TaintBugVisitor(const SVal V) : V(V) {}
-  void Profile(llvm::FoldingSetNodeID &ID) const override { ID.Add(V); }
-
-  PathDiagnosticPieceRef VisitNode(const ExplodedNode *N,
-                                   BugReporterContext &BRC,
-                                   PathSensitiveBugReport &BR) override;
-};
-
 } // namespace taint
 } // namespace ento
 } // namespace clang
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to