beanz updated this revision to Diff 441759.
beanz added a comment.

Updates based on review feedback.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D128012/new/

https://reviews.llvm.org/D128012

Files:
  clang/include/clang/Sema/HLSLExternalSemaSource.h
  clang/lib/Frontend/FrontendAction.cpp
  clang/lib/Headers/hlsl/hlsl_basic_types.h
  clang/lib/Sema/CMakeLists.txt
  clang/lib/Sema/HLSLExternalSemaSource.cpp
  clang/test/AST/HLSL/vector-alias.hlsl

Index: clang/test/AST/HLSL/vector-alias.hlsl
===================================================================
--- /dev/null
+++ clang/test/AST/HLSL/vector-alias.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'element __attribute__((ext_vector_type(element_count)))'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'element __attribute__((ext_vector_type(element_count)))' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::vector<float, 2> Vec2 = {1.0, 2.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:24:3, col:43>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:42> col:26 Vec2 'hlsl::vector<float, 2>':'float __attribute__((ext_vector_type(2)))' cinit
+
+  // Verify that you don't need to specify the namespace.
+  vector<int, 2> Vec2a = {1, 2};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:30:3, col:32>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:31> col:18 Vec2a 'vector<int, 2>':'int __attribute__((ext_vector_type(2)))' cinit
+
+  // Build a bigger vector.
+  vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:36:3, col:48>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:47> col:21 used Vec4 'vector<double, 4>':'double __attribute__((ext_vector_type(4)))' cinit
+
+  // Verify that swizzles still work.
+  vector<double, 3> Vec3 = Vec4.xyz;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:42:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:33> col:21 Vec3 'vector<double, 3>':'double __attribute__((ext_vector_type(3)))' cinit
+
+  // Verify that the implicit arguments generate the correct type.
+  vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:48:3, col:42>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:41> col:12 ImpVec4 'vector<>':'float __attribute__((ext_vector_type(4)))' cinit
+  return 1;
+}
Index: clang/lib/Sema/HLSLExternalSemaSource.cpp
===================================================================
--- /dev/null
+++ clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -0,0 +1,98 @@
+//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Sema/HLSLExternalSemaSource.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/Basic/AttrKinds.h"
+#include "clang/Sema/Sema.h"
+
+using namespace clang;
+
+char HLSLExternalSemaSource::ID;
+
+HLSLExternalSemaSource::~HLSLExternalSemaSource() = default;
+
+void HLSLExternalSemaSource::InitializeSema(Sema &S) {
+  SemaPtr = &S; // Reset to nullptr in ForgetSema().
+  ASTContext &AST = SemaPtr->getASTContext();
+  IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
+  HLSLNamespace =
+      NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false,
+                            SourceLocation(), SourceLocation(), &HLSL, nullptr);
+  HLSLNamespace->setImplicit(true);
+  AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
+  defineHLSLVectorAlias();
+
+  // Add a `using namespace hlsl` directive. DXC does not put HLSL's built-in
+  // types inside a namespace, but we are planning to change that in the near
+  // future. To remain source compatible, older versions of HLSL will need to
+  // implicitly use the hlsl namespace. For now, clang adds everything to the
+  // namespace, and we can remove the using directive for future language
+  // versions to match HLSL's evolution.
+  auto *UsingDecl = UsingDirectiveDecl::Create(
+      AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
+      NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
+      AST.getTranslationUnitDecl());
+
+  AST.getTranslationUnitDecl()->addDecl(UsingDecl);
+}
+
+void HLSLExternalSemaSource::defineHLSLVectorAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+  // Define the implicit alias template hlsl::vector<element, element_count>.
+  llvm::SmallVector<NamedDecl *> TemplateParams;
+  // Parameter 0: type parameter `element`, defaulting to float.
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
+
+  TemplateParams.emplace_back(TypeParam);
+  // Parameter 1: int non-type parameter `element_count`, defaulting to 4.
+  auto *SizeParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  Expr *LiteralExpr =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
+                             AST.IntTy, SourceLocation());
+  SizeParam->setDefaultArgument(LiteralExpr);
+  TemplateParams.emplace_back(SizeParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
+  // Aliased type: element __attribute__((ext_vector_type(element_count))).
+  QualType AliasType = AST.getDependentSizedExtVectorType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
+          DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+  // Despite its name, Record is the TypeAliasDecl that the template wraps.
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+  // Link alias and template, then publish the template in the namespace.
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
Index: clang/lib/Sema/CMakeLists.txt
===================================================================
--- clang/lib/Sema/CMakeLists.txt
+++ clang/lib/Sema/CMakeLists.txt
@@ -15,6 +15,7 @@
   CodeCompleteConsumer.cpp
   DeclSpec.cpp
   DelayedDiagnostic.cpp
+  HLSLExternalSemaSource.cpp
   IdentifierResolver.cpp
   JumpDiagnostics.cpp
   MultiplexExternalSemaSource.cpp
Index: clang/lib/Headers/hlsl/hlsl_basic_types.h
===================================================================
--- clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -27,38 +27,38 @@
 // built-in vector data types:
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef int16_t int16_t2 __attribute__((ext_vector_type(2)));
-typedef int16_t int16_t3 __attribute__((ext_vector_type(3)));
-typedef int16_t int16_t4 __attribute__((ext_vector_type(4)));
-typedef uint16_t uint16_t2 __attribute__((ext_vector_type(2)));
-typedef uint16_t uint16_t3 __attribute__((ext_vector_type(3)));
-typedef uint16_t uint16_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int16_t, 2> int16_t2;
+typedef vector<int16_t, 3> int16_t3;
+typedef vector<int16_t, 4> int16_t4;
+typedef vector<uint16_t, 2> uint16_t2;
+typedef vector<uint16_t, 3> uint16_t3;
+typedef vector<uint16_t, 4> uint16_t4;
 #endif
 
-typedef int int2 __attribute__((ext_vector_type(2)));
-typedef int int3 __attribute__((ext_vector_type(3)));
-typedef int int4 __attribute__((ext_vector_type(4)));
-typedef uint uint2 __attribute__((ext_vector_type(2)));
-typedef uint uint3 __attribute__((ext_vector_type(3)));
-typedef uint uint4 __attribute__((ext_vector_type(4)));
-typedef int64_t int64_t2 __attribute__((ext_vector_type(2)));
-typedef int64_t int64_t3 __attribute__((ext_vector_type(3)));
-typedef int64_t int64_t4 __attribute__((ext_vector_type(4)));
-typedef uint64_t uint64_t2 __attribute__((ext_vector_type(2)));
-typedef uint64_t uint64_t3 __attribute__((ext_vector_type(3)));
-typedef uint64_t uint64_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int, 2> int2;
+typedef vector<int, 3> int3;
+typedef vector<int, 4> int4;
+typedef vector<uint, 2> uint2;
+typedef vector<uint, 3> uint3;
+typedef vector<uint, 4> uint4;
+typedef vector<int64_t, 2> int64_t2;
+typedef vector<int64_t, 3> int64_t3;
+typedef vector<int64_t, 4> int64_t4;
+typedef vector<uint64_t, 2> uint64_t2;
+typedef vector<uint64_t, 3> uint64_t3;
+typedef vector<uint64_t, 4> uint64_t4;
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef half half2 __attribute__((ext_vector_type(2)));
-typedef half half3 __attribute__((ext_vector_type(3)));
-typedef half half4 __attribute__((ext_vector_type(4)));
+typedef vector<half, 2> half2;
+typedef vector<half, 3> half3;
+typedef vector<half, 4> half4;
 #endif
 
-typedef float float2 __attribute__((ext_vector_type(2)));
-typedef float float3 __attribute__((ext_vector_type(3)));
-typedef float float4 __attribute__((ext_vector_type(4)));
-typedef double double2 __attribute__((ext_vector_type(2)));
-typedef double double3 __attribute__((ext_vector_type(3)));
-typedef double double4 __attribute__((ext_vector_type(4)));
+typedef vector<float, 2> float2;
+typedef vector<float, 3> float3;
+typedef vector<float, 4> float4;
+typedef vector<double, 2> double2;
+typedef vector<double, 3> double3;
+typedef vector<double, 4> double4;
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
Index: clang/lib/Frontend/FrontendAction.cpp
===================================================================
--- clang/lib/Frontend/FrontendAction.cpp
+++ clang/lib/Frontend/FrontendAction.cpp
@@ -24,6 +24,7 @@
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Parse/ParseAST.h"
+#include "clang/Sema/HLSLExternalSemaSource.h"
 #include "clang/Serialization/ASTDeserializationListener.h"
 #include "clang/Serialization/ASTReader.h"
 #include "clang/Serialization/GlobalModuleIndex.h"
@@ -1014,6 +1015,13 @@
     CI.getASTContext().setExternalSource(Override);
   }
 
+  // Setup HLSL External Sema Source
+  if (CI.getLangOpts().HLSL && CI.hasASTContext()) {
+    IntrusiveRefCntPtr<ExternalASTSource> HLSLSema(
+        new HLSLExternalSemaSource());
+    CI.getASTContext().setExternalSource(HLSLSema);
+  }
+
   FailureCleanup.release();
   return true;
 }
Index: clang/include/clang/Sema/HLSLExternalSemaSource.h
===================================================================
--- /dev/null
+++ clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -0,0 +1,43 @@
+//===--- HLSLExternalSemaSource.h - HLSL Sema Source ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the HLSLExternalSemaSource interface.
+//
+//===----------------------------------------------------------------------===//
+#ifndef CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+#define CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+
+#include "clang/Sema/ExternalSemaSource.h"
+
+namespace clang {
+class NamespaceDecl;
+class Sema;
+
+class HLSLExternalSemaSource : public ExternalSemaSource {
+  static char ID;
+
+  Sema *SemaPtr = nullptr;
+  NamespaceDecl *HLSLNamespace = nullptr; // Created in InitializeSema().
+  /// Creates the implicit `vector` alias template inside HLSLNamespace.
+  void defineHLSLVectorAlias();
+
+public:
+  ~HLSLExternalSemaSource() override;
+
+  /// Record the Sema instance used for semantic analysis and add the
+  /// implicit `hlsl` namespace, its `vector` alias template, and a
+  /// `using namespace hlsl` directive to the translation unit.
+  void InitializeSema(Sema &S) override;
+
+  /// Inform the semantic consumer that Sema is no longer available.
+  void ForgetSema() override { SemaPtr = nullptr; }
+};
+
+} // namespace clang
+
+#endif // CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to