AntonBikineev updated this revision to Diff 62317.

http://reviews.llvm.org/D21834

Files:
  include/clang/AST/Stmt.h
  include/clang/Sema/Sema.h
  lib/AST/ASTImporter.cpp
  lib/AST/Stmt.cpp
  lib/Analysis/BodyFarm.cpp
  lib/CodeGen/CGStmt.cpp
  lib/Sema/SemaStmt.cpp
  lib/Sema/TreeTransform.h
  lib/Serialization/ASTReaderStmt.cpp
  lib/Serialization/ASTWriterStmt.cpp

Index: lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- lib/Serialization/ASTWriterStmt.cpp
+++ lib/Serialization/ASTWriterStmt.cpp
@@ -129,6 +129,7 @@
 void ASTStmtWriter::VisitIfStmt(IfStmt *S) {
   VisitStmt(S);
   Record.push_back(S->isConstexpr());
+  Record.AddStmt(S->getInit());
   Record.AddDeclRef(S->getConditionVariable());
   Record.AddStmt(S->getCond());
   Record.AddStmt(S->getThen());
@@ -140,6 +141,7 @@
 
 void ASTStmtWriter::VisitSwitchStmt(SwitchStmt *S) {
   VisitStmt(S);
+  Record.AddStmt(S->getInit());
   Record.AddDeclRef(S->getConditionVariable());
   Record.AddStmt(S->getCond());
   Record.AddStmt(S->getBody());
Index: lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- lib/Serialization/ASTReaderStmt.cpp
+++ lib/Serialization/ASTReaderStmt.cpp
@@ -185,6 +185,7 @@
 void ASTStmtReader::VisitIfStmt(IfStmt *S) {
   VisitStmt(S);
   S->setConstexpr(Record[Idx++]);
+  S->setInit(Reader.ReadSubStmt());
   S->setConditionVariable(Reader.getContext(),
                           ReadDeclAs<VarDecl>(Record, Idx));
   S->setCond(Reader.ReadSubExpr());
@@ -196,6 +197,7 @@
 
 void ASTStmtReader::VisitSwitchStmt(SwitchStmt *S) {
   VisitStmt(S);
+  S->setInit(Reader.ReadSubStmt());
   S->setConditionVariable(Reader.getContext(),
                           ReadDeclAs<VarDecl>(Record, Idx));
   S->setCond(Reader.ReadSubExpr());
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1174,9 +1174,9 @@
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
   StmtResult RebuildIfStmt(SourceLocation IfLoc, bool IsConstexpr,
-                           Sema::ConditionResult Cond, Stmt *Then,
+                           Sema::ConditionResult Cond, Stmt *Init, Stmt *Then,
                            SourceLocation ElseLoc, Stmt *Else) {
-    return getSema().ActOnIfStmt(IfLoc, IsConstexpr, nullptr, Cond, Then,
+    return getSema().ActOnIfStmt(IfLoc, IsConstexpr, Init, Cond, Then,
                                  ElseLoc, Else);
   }
 
@@ -1185,8 +1185,9 @@
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
   StmtResult RebuildSwitchStmtStart(SourceLocation SwitchLoc,
+                                    Stmt *Init,
                                     Sema::ConditionResult Cond) {
-    return getSema().ActOnStartOfSwitchStmt(SwitchLoc, nullptr, Cond);
+    return getSema().ActOnStartOfSwitchStmt(SwitchLoc, Init, Cond);
   }
 
   /// \brief Attach the body to the switch statement.
@@ -6236,6 +6237,11 @@
 template<typename Derived>
 StmtResult
 TreeTransform<Derived>::TransformIfStmt(IfStmt *S) {
+  // Transform the initialization statement
+  StmtResult Init = getDerived().TransformStmt(S->getInit());
+  if (Init.isInvalid())
+    return StmtError();
+
   // Transform the condition
   Sema::ConditionResult Cond = getDerived().TransformCondition(
       S->getIfLoc(), S->getConditionVariable(), S->getCond(),
@@ -6268,6 +6274,7 @@
   }
 
   if (!getDerived().AlwaysRebuild() &&
+      Init.get() == S->getInit() &&
       Cond.get() == std::make_pair(S->getConditionVariable(), S->getCond()) &&
       Then.get() == S->getThen() &&
       Else.get() == S->getElse())
@@ -6274,12 +6281,17 @@
     return S;
 
   return getDerived().RebuildIfStmt(S->getIfLoc(), S->isConstexpr(), Cond,
-                                    Then.get(), S->getElseLoc(), Else.get());
+                                    Init.get(), Then.get(), S->getElseLoc(), Else.get());
 }
 
 template<typename Derived>
 StmtResult
 TreeTransform<Derived>::TransformSwitchStmt(SwitchStmt *S) {
+  // Transform the initialization statement
+  StmtResult Init = getDerived().TransformStmt(S->getInit());
+  if (Init.isInvalid())
+    return StmtError();
+
   // Transform the condition.
   Sema::ConditionResult Cond = getDerived().TransformCondition(
       S->getSwitchLoc(), S->getConditionVariable(), S->getCond(),
@@ -6289,7 +6301,8 @@
 
   // Rebuild the switch statement.
   StmtResult Switch
-    = getDerived().RebuildSwitchStmtStart(S->getSwitchLoc(), Cond);
+    = getDerived().RebuildSwitchStmtStart(S->getSwitchLoc(),
+                                          Init.get(), Cond);
   if (Switch.isInvalid())
     return StmtError();
 
Index: lib/Sema/SemaStmt.cpp
===================================================================
--- lib/Sema/SemaStmt.cpp
+++ lib/Sema/SemaStmt.cpp
@@ -508,9 +508,6 @@
                   ConditionResult Cond,
                   Stmt *thenStmt, SourceLocation ElseLoc,
                   Stmt *elseStmt) {
-  if (InitStmt)
-    Diag(InitStmt->getLocStart(), diag::err_init_stmt_not_supported);
-
   if (Cond.isInvalid())
     Cond = ConditionResult(
         *this, nullptr,
@@ -528,10 +525,10 @@
     DiagnoseEmptyStmtBody(CondExpr->getLocEnd(), thenStmt,
                           diag::warn_empty_if_body);
 
-  return BuildIfStmt(IfLoc, IsConstexpr, Cond, thenStmt, ElseLoc, elseStmt);
+  return BuildIfStmt(IfLoc, IsConstexpr, InitStmt, Cond, thenStmt, ElseLoc, elseStmt);
 }
 
-StmtResult Sema::BuildIfStmt(SourceLocation IfLoc, bool IsConstexpr,
+StmtResult Sema::BuildIfStmt(SourceLocation IfLoc, bool IsConstexpr, Stmt *InitStmt,
                              ConditionResult Cond, Stmt *thenStmt,
                              SourceLocation ElseLoc, Stmt *elseStmt) {
   if (Cond.isInvalid())
@@ -543,7 +540,7 @@
   DiagnoseUnusedExprResult(thenStmt);
   DiagnoseUnusedExprResult(elseStmt);
 
-  return new (Context) IfStmt(Context, IfLoc, IsConstexpr, Cond.get().first,
+  return new (Context) IfStmt(Context, IfLoc, IsConstexpr, InitStmt, Cond.get().first,
                               Cond.get().second, thenStmt, ElseLoc, elseStmt);
 }
 
@@ -668,13 +665,10 @@
   if (Cond.isInvalid())
     return StmtError();
 
-  if (InitStmt)
-    Diag(InitStmt->getLocStart(), diag::err_init_stmt_not_supported);
-
   getCurFunction()->setHasBranchIntoScope();
 
   SwitchStmt *SS =
-      new (Context) SwitchStmt(Context, Cond.get().first, Cond.get().second);
+      new (Context) SwitchStmt(Context, InitStmt, Cond.get().first, Cond.get().second);
   getCurFunction()->SwitchStack.push_back(SS);
   return SS;
 }
Index: lib/CodeGen/CGStmt.cpp
===================================================================
--- lib/CodeGen/CGStmt.cpp
+++ lib/CodeGen/CGStmt.cpp
@@ -561,6 +561,9 @@
   // unequal to 0.  The condition must be a scalar type.
   LexicalScope ConditionScope(*this, S.getCond()->getSourceRange());
 
+  if (S.getInit())
+    EmitStmt(S.getInit());
+
   if (S.getConditionVariable())
     EmitAutoVarDecl(*S.getConditionVariable());
 
@@ -1465,6 +1468,9 @@
         incrementProfileCounter(Case);
       RunCleanupsScope ExecutedScope(*this);
 
+      if (S.getInit())
+        EmitStmt(S.getInit());
+
       // Emit the condition variable if needed inside the entire cleanup scope
       // used by this special case for constant folded switches.
       if (S.getConditionVariable())
@@ -1492,6 +1498,10 @@
   JumpDest SwitchExit = getJumpDestInCurrentScope("sw.epilog");
 
   RunCleanupsScope ConditionScope(*this);
+
+  if (S.getInit())
+    EmitStmt(S.getInit());
+
   if (S.getConditionVariable())
     EmitAutoVarDecl(*S.getConditionVariable());
   llvm::Value *CondV = EmitScalarExpr(S.getCond());
Index: lib/Analysis/BodyFarm.cpp
===================================================================
--- lib/Analysis/BodyFarm.cpp
+++ lib/Analysis/BodyFarm.cpp
@@ -239,7 +239,7 @@
                                            SourceLocation());
   
   // (5) Create the 'if' statement.
-  IfStmt *If = new (C) IfStmt(C, SourceLocation(), false, nullptr, UO, CS);
+  IfStmt *If = new (C) IfStmt(C, SourceLocation(), false, nullptr, nullptr, UO, CS);
   return If;
 }
 
@@ -343,7 +343,7 @@
   
   /// Construct the If.
   Stmt *If =
-    new (C) IfStmt(C, SourceLocation(), false, nullptr, Comparison, Body,
+    new (C) IfStmt(C, SourceLocation(), false, nullptr, nullptr, Comparison, Body,
                    SourceLocation(), Else);
 
   return If;  
Index: lib/AST/Stmt.cpp
===================================================================
--- lib/AST/Stmt.cpp
+++ lib/AST/Stmt.cpp
@@ -764,11 +764,12 @@
 }
 
 IfStmt::IfStmt(const ASTContext &C, SourceLocation IL, bool IsConstexpr,
-               VarDecl *var, Expr *cond, Stmt *then, SourceLocation EL,
+               Stmt *init, VarDecl *var, Expr *cond, Stmt *then, SourceLocation EL,
                Stmt *elsev)
     : Stmt(IfStmtClass), IfLoc(IL), ElseLoc(EL) {
   setConstexpr(IsConstexpr);
   setConditionVariable(C, var);
+  SubExprs[INIT] = init;
   SubExprs[COND] = cond;
   SubExprs[THEN] = then;
   SubExprs[ELSE] = elsev;
@@ -824,9 +825,10 @@
                                        VarRange.getEnd());
 }
 
-SwitchStmt::SwitchStmt(const ASTContext &C, VarDecl *Var, Expr *cond)
+SwitchStmt::SwitchStmt(const ASTContext &C, Stmt *init, VarDecl *Var, Expr *cond)
     : Stmt(SwitchStmtClass), FirstCase(nullptr, false) {
   setConditionVariable(C, Var);
+  SubExprs[INIT] = init;
   SubExprs[COND] = cond;
   SubExprs[BODY] = nullptr;
 }
Index: lib/AST/ASTImporter.cpp
===================================================================
--- lib/AST/ASTImporter.cpp
+++ lib/AST/ASTImporter.cpp
@@ -4963,6 +4963,9 @@
 
 Stmt *ASTNodeImporter::VisitIfStmt(IfStmt *S) {
   SourceLocation ToIfLoc = Importer.Import(S->getIfLoc());
+  Stmt *ToInit = Importer.Import(S->getInit());
+  if (!ToInit && S->getInit())
+    return nullptr;
   VarDecl *ToConditionVariable = nullptr;
   if (VarDecl *FromConditionVariable = S->getConditionVariable()) {
     ToConditionVariable =
@@ -4982,6 +4985,7 @@
     return nullptr;
   return new (Importer.getToContext()) IfStmt(Importer.getToContext(),
                                               ToIfLoc, S->isConstexpr(),
+                                              ToInit,
                                               ToConditionVariable,
                                               ToCondition, ToThenStmt,
                                               ToElseLoc, ToElseStmt);
@@ -4988,6 +4992,9 @@
 }
 
 Stmt *ASTNodeImporter::VisitSwitchStmt(SwitchStmt *S) {
+  Stmt *ToInit = Importer.Import(S->getInit());
+  if (!ToInit && S->getInit())
+    return nullptr;
   VarDecl *ToConditionVariable = nullptr;
   if (VarDecl *FromConditionVariable = S->getConditionVariable()) {
     ToConditionVariable =
@@ -4999,8 +5006,8 @@
   if (!ToCondition && S->getCond())
     return nullptr;
   SwitchStmt *ToStmt = new (Importer.getToContext()) SwitchStmt(
-                         Importer.getToContext(), ToConditionVariable,
-                         ToCondition);
+                         Importer.getToContext(), ToInit,
+                         ToConditionVariable, ToCondition);
   Stmt *ToBody = Importer.Import(S->getBody());
   if (!ToBody && S->getBody())
     return nullptr;
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -3401,6 +3401,7 @@
                          ConditionResult Cond, Stmt *ThenVal,
                          SourceLocation ElseLoc, Stmt *ElseVal);
   StmtResult BuildIfStmt(SourceLocation IfLoc, bool IsConstexpr,
+                         Stmt *InitStmt,
                          ConditionResult Cond, Stmt *ThenVal,
                          SourceLocation ElseLoc, Stmt *ElseVal);
   StmtResult ActOnStartOfSwitchStmt(SourceLocation SwitchLoc,
Index: include/clang/AST/Stmt.h
===================================================================
--- include/clang/AST/Stmt.h
+++ include/clang/AST/Stmt.h
@@ -879,7 +879,7 @@
 /// IfStmt - This represents an if/then/else.
 ///
 class IfStmt : public Stmt {
-  enum { VAR, COND, THEN, ELSE, END_EXPR };
+  enum { INIT, VAR, COND, THEN, ELSE, END_EXPR };
   Stmt* SubExprs[END_EXPR];
 
   SourceLocation IfLoc;
@@ -887,7 +887,7 @@
 
 public:
   IfStmt(const ASTContext &C, SourceLocation IL,
-         bool IsConstexpr, VarDecl *var, Expr *cond,
+         bool IsConstexpr, Stmt *init, VarDecl *var, Expr *cond,
          Stmt *then, SourceLocation EL = SourceLocation(),
          Stmt *elsev = nullptr);
 
@@ -911,6 +911,9 @@
     return reinterpret_cast<DeclStmt*>(SubExprs[VAR]);
   }
 
+  Stmt *getInit() { return SubExprs[INIT]; }
+  const Stmt *getInit() const { return SubExprs[INIT]; }
+  void setInit(Stmt *S) { SubExprs[INIT] = S; }
   const Expr *getCond() const { return reinterpret_cast<Expr*>(SubExprs[COND]);}
   void setCond(Expr *E) { SubExprs[COND] = reinterpret_cast<Stmt *>(E); }
   const Stmt *getThen() const { return SubExprs[THEN]; }
@@ -953,7 +956,7 @@
 ///
 class SwitchStmt : public Stmt {
   SourceLocation SwitchLoc;
-  enum { VAR, COND, BODY, END_EXPR };
+  enum { INIT, VAR, COND, BODY, END_EXPR };
   Stmt* SubExprs[END_EXPR];
   // This points to a linked list of case and default statements and, if the
   // SwitchStmt is a switch on an enum value, records whether all the enum
@@ -962,7 +965,7 @@
   llvm::PointerIntPair<SwitchCase *, 1, bool> FirstCase;
 
 public:
-  SwitchStmt(const ASTContext &C, VarDecl *Var, Expr *cond);
+  SwitchStmt(const ASTContext &C, Stmt *Init, VarDecl *Var, Expr *cond);
 
   /// \brief Build a empty switch statement.
   explicit SwitchStmt(EmptyShell Empty) : Stmt(SwitchStmtClass, Empty) { }
@@ -985,6 +988,9 @@
     return reinterpret_cast<DeclStmt*>(SubExprs[VAR]);
   }
 
+  Stmt *getInit() { return SubExprs[INIT]; }
+  const Stmt *getInit() const { return SubExprs[INIT]; }
+  void setInit(Stmt *S) { SubExprs[INIT] = S; }
   const Expr *getCond() const { return reinterpret_cast<Expr*>(SubExprs[COND]);}
   const Stmt *getBody() const { return SubExprs[BODY]; }
   const SwitchCase *getSwitchCaseList() const { return FirstCase.getPointer(); }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to