fhahn updated this revision to Diff 261915 and added a comment: "Add missing early exit."
Repository: rG LLVM Github Monorepo. Changes since last action: https://reviews.llvm.org/D72281/new/ — Review: https://reviews.llvm.org/D72281 — Files: clang/include/clang/AST/ASTContext.h clang/include/clang/AST/RecursiveASTVisitor.h clang/include/clang/AST/Type.h clang/include/clang/AST/TypeLoc.h clang/include/clang/AST/TypeProperties.td clang/include/clang/Basic/Attr.td clang/include/clang/Basic/DiagnosticSemaKinds.td clang/include/clang/Basic/Features.def clang/include/clang/Basic/LangOptions.def clang/include/clang/Basic/TypeNodes.td clang/include/clang/Driver/Options.td clang/include/clang/Sema/Sema.h clang/include/clang/Serialization/TypeBitCodes.def clang/lib/AST/ASTContext.cpp clang/lib/AST/ASTStructuralEquivalence.cpp clang/lib/AST/ExprConstant.cpp clang/lib/AST/ItaniumMangle.cpp clang/lib/AST/MicrosoftMangle.cpp clang/lib/AST/Type.cpp clang/lib/AST/TypePrinter.cpp clang/lib/CodeGen/CGDebugInfo.cpp clang/lib/CodeGen/CGDebugInfo.h clang/lib/CodeGen/CGExpr.cpp clang/lib/CodeGen/CodeGenFunction.cpp clang/lib/CodeGen/CodeGenTypes.cpp clang/lib/CodeGen/ItaniumCXXABI.cpp clang/lib/Driver/ToolChains/Clang.cpp clang/lib/Frontend/CompilerInvocation.cpp clang/lib/Sema/SemaExpr.cpp clang/lib/Sema/SemaLookup.cpp clang/lib/Sema/SemaTemplate.cpp clang/lib/Sema/SemaTemplateDeduction.cpp clang/lib/Sema/SemaType.cpp clang/lib/Sema/TreeTransform.h clang/lib/Serialization/ASTReader.cpp clang/lib/Serialization/ASTWriter.cpp clang/test/CodeGen/debug-info-matrix-types.c clang/test/CodeGen/matrix-type.c clang/test/CodeGenCXX/matrix-type.cpp clang/test/Parser/matrix-type-disabled.c clang/test/SemaCXX/matrix-type.cpp clang/tools/libclang/CIndex.cpp
Index: clang/tools/libclang/CIndex.cpp =================================================================== --- clang/tools/libclang/CIndex.cpp +++ clang/tools/libclang/CIndex.cpp @@ -1795,6 +1795,8 @@ DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type) DEFAULT_TYPELOC_IMPL(Vector, Type) DEFAULT_TYPELOC_IMPL(ExtVector, VectorType) +DEFAULT_TYPELOC_IMPL(ConstantMatrix, MatrixType) +DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, MatrixType) DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType) DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType) DEFAULT_TYPELOC_IMPL(Record, TagType) Index: clang/test/SemaCXX/matrix-type.cpp =================================================================== --- /dev/null +++ clang/test/SemaCXX/matrix-type.cpp @@ -0,0 +1,61 @@ +// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s + +using matrix_double_t = double __attribute__((matrix_type(6, 6))); +using matrix_float_t = float __attribute__((matrix_type(6, 6))); +using matrix_int_t = int __attribute__((matrix_type(6, 6))); + +void matrix_var_dimensions(int Rows, unsigned Columns, char C) { + using matrix1_t = int __attribute__((matrix_type(Rows, 1))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix3_t = int __attribute__((matrix_type(C, C))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix4_t = int __attribute__((matrix_type(-1, 1))); // expected-error{{matrix row size too large}} + using matrix5_t = int __attribute__((matrix_type(1, -1))); // expected-error{{matrix column size too large}} + using matrix6_t = int __attribute__((matrix_type(0, 1))); // expected-error{{zero matrix size}} + using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero matrix size}} + using matrix8_t = int 
__attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}} + using matrix9_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}} +} + +struct S1 {}; + +enum TestEnum { + A, + B +}; + +void matrix_unsupported_element_type() { + using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}} + using matrix2_t = S1 __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'S1'}} + using matrix3_t = bool __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'bool'}} + using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}} +} + +template <typename T> // expected-note{{declared here}} +void matrix_template_1() { + using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}} +} + +template <class C> // expected-note{{declared here}} +void matrix_template_2() { + using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}} +} + +template <unsigned Rows, unsigned Cols> +void matrix_template_3() { + using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}} +} + +void instantiate_template_3() { + matrix_template_3<1, 10>(); + matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}} +} + +template <int Rows, unsigned Cols> +void matrix_template_4() { + using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large}} +} + +void instantiate_template_4() { + matrix_template_4<2, 10>(); + matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}} +}
Index: clang/test/Parser/matrix-type-disabled.c =================================================================== --- /dev/null +++ clang/test/Parser/matrix-type-disabled.c @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 %s -triple i686-apple-darwin -verify -fsyntax-only + +// Matrix types are disabled by default. + +#if __has_extension(matrix_types) +#error Expected extension 'matrix_types' to be disabled +#endif + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +// expected-error@-1 {{matrix types extension is disabled. Pass -fenable-matrix to enable it}} + +void load_store_double(dx5x5_t *a, dx5x5_t *b) { + *a = *b; +} Index: clang/test/CodeGenCXX/matrix-type.cpp =================================================================== --- /dev/null +++ clang/test/CodeGenCXX/matrix-type.cpp @@ -0,0 +1,176 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +typedef float fx3x4_t __attribute__((matrix_type(3, 4))); + +// CHECK: %struct.Matrix = type { i8, [12 x float], float } + +void load_store(dx5x5_t *a, dx5x5_t *b) { + // CHECK-LABEL: define void @_Z10load_storePU9Matrix5x5dS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef 
float fx3x3_t __attribute__((matrix_type(3, 3))); + +void parameter_passing(fx3x3_t a, fx3x3_t *b) { + // CHECK-LABEL: define void @_Z17parameter_passingU9Matrix3x3fPS_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4 + // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4 + // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>* + // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4 + // CHECK-NEXT: ret void + *b = a; +} + +fx3x3_t return_matrix(fx3x3_t *a) { + // CHECK-LABEL: define <9 x float> @_Z13return_matrixPU9Matrix3x3f( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>* + // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: ret <9 x float> %2 + return *a; +} + +struct Matrix { + char Tmp1; + fx3x4_t Data; + float Tmp2; +}; + +void matrix_struct_pointers(Matrix *a, Matrix *b) { + // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: 
%1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b->Data = a->Data; +} + +void matrix_struct_reference(Matrix &a, Matrix &b) { + // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b.Data = a.Data; +} + +class MatrixClass { +public: + int Tmp1; + fx3x4_t Data; + long Tmp2; +}; + +void matrix_class_reference(MatrixClass &a, MatrixClass &b) { + // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %class.MatrixClass*, align 8 + // CHECK-NEXT: %b.addr = alloca %class.MatrixClass*, align 8 + // CHECK-NEXT: store %class.MatrixClass* %a, 
%class.MatrixClass** %a.addr, align 8 + // CHECK-NEXT: store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8 + // CHECK-NEXT: %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b.Data = a.Data; +} + +template <typename Ty, unsigned Rows, unsigned Cols> +class MatrixClassTemplate { +public: + using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols))); + int Tmp1; + MatrixTy Data; + long Tmp2; +}; + +template <typename Ty, unsigned Rows, unsigned Cols> +void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) { + b.Data = a.Data; +} + +MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) { + // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data + // CHECK-NEXT: entry: + // CHECK-NEXT: %Data.addr = alloca float*, align 8 + // CHECK-NEXT: %Arg = alloca %class.MatrixClassTemplate, align 8 + // CHECK-NEXT: store float* %Data, float** %Data.addr, align 8 + // CHECK-NEXT: %0 = load float*, float** %Data.addr, align 8 + // CHECK-NEXT: %1 = bitcast float* %0 to [150 x float]* + // CHECK-NEXT: %2 = bitcast [150 x float]* %1 to <150 x float>* + // CHECK-NEXT: %3 = load <150 x float>, <150 x float>* %2, align 4 + // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, 
%class.MatrixClassTemplate* %Arg, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>* + // CHECK-NEXT: store <150 x float> %3, <150 x float>* %4, align 4 + // CHECK-NEXT: call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %class.MatrixClassTemplate*, align 8 + // CHECK-NEXT: %b.addr = alloca %class.MatrixClassTemplate*, align 8 + // CHECK-NEXT: store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8 + // CHECK-NEXT: store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8 + // CHECK-NEXT: %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [150 x float]* %Data to <150 x float>* + // CHECK-NEXT: %2 = load <150 x float>, <150 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>* + // CHECK-NEXT: store <150 x float> %2, <150 x float>* %4, align 4 + // CHECK-NEXT: ret void + + MatrixClassTemplate<float, 10, 15> Result, Arg; + Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data); + matrix_template_reference(Arg, Result); + return Result; +} Index: 
clang/test/CodeGen/matrix-type.c =================================================================== --- /dev/null +++ clang/test/CodeGen/matrix-type.c @@ -0,0 +1,158 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +#if !__has_extension(matrix_types) +#error Expected extension 'matrix_types' to be enabled +#endif + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); + +// CHECK: %struct.Matrix = type { i8, [12 x float], float } + +void load_store_double(dx5x5_t *a, dx5x5_t *b) { + // CHECK-LABEL: define void @load_store_double( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef float fx3x4_t __attribute__((matrix_type(3, 4))); +void load_store_float(fx3x4_t *a, fx3x4_t *b) { + // CHECK-LABEL: define void @load_store_float( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x float]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x float]*, align 8 + // CHECK-NEXT: store [12 x float]* %a, [12 x float]** %a.addr, align 8 + // CHECK-NEXT: store [12 x float]* %b, [12 x float]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x float]*, [12 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x float]* %0 to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x 
float>* %1, align 4 + // CHECK-NEXT: %3 = load [12 x float]*, [12 x float]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x float]* %3 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef int ix3x4_t __attribute__((matrix_type(4, 3))); +void load_store_int(ix3x4_t *a, ix3x4_t *b) { + // CHECK-LABEL: define void @load_store_int( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x i32]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x i32]*, align 8 + // CHECK-NEXT: store [12 x i32]* %a, [12 x i32]** %a.addr, align 8 + // CHECK-NEXT: store [12 x i32]* %b, [12 x i32]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x i32]*, [12 x i32]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x i32]* %0 to <12 x i32>* + // CHECK-NEXT: %2 = load <12 x i32>, <12 x i32>* %1, align 4 + // CHECK-NEXT: %3 = load [12 x i32]*, [12 x i32]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x i32]* %3 to <12 x i32>* + // CHECK-NEXT: store <12 x i32> %2, <12 x i32>* %4, align 4 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef unsigned long long ullx3x4_t __attribute__((matrix_type(4, 3))); +void load_store_ull(ullx3x4_t *a, ullx3x4_t *b) { + // CHECK-LABEL: define void @load_store_ull( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x i64]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x i64]*, align 8 + // CHECK-NEXT: store [12 x i64]* %a, [12 x i64]** %a.addr, align 8 + // CHECK-NEXT: store [12 x i64]* %b, [12 x i64]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x i64]*, [12 x i64]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x i64]* %0 to <12 x i64>* + // CHECK-NEXT: %2 = load <12 x i64>, <12 x i64>* %1, align 8 + // CHECK-NEXT: %3 = load [12 x i64]*, [12 x i64]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x i64]* %3 to <12 x i64>* + // CHECK-NEXT: store <12 x i64> %2, <12 x i64>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + 
+typedef __fp16 fp16x3x4_t __attribute__((matrix_type(4, 3))); +void load_store_fp16(fp16x3x4_t *a, fp16x3x4_t *b) { + // CHECK-LABEL: define void @load_store_fp16( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x half]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x half]*, align 8 + // CHECK-NEXT: store [12 x half]* %a, [12 x half]** %a.addr, align 8 + // CHECK-NEXT: store [12 x half]* %b, [12 x half]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x half]*, [12 x half]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x half]* %0 to <12 x half>* + // CHECK-NEXT: %2 = load <12 x half>, <12 x half>* %1, align 2 + // CHECK-NEXT: %3 = load [12 x half]*, [12 x half]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x half]* %3 to <12 x half>* + // CHECK-NEXT: store <12 x half> %2, <12 x half>* %4, align 2 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); + +void parameter_passing(fx3x3_t a, fx3x3_t *b) { + // CHECK-LABEL: define void @parameter_passing( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4 + // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4 + // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>* + // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4 + // CHECK-NEXT: ret void + *b = a; +} + +fx3x3_t return_matrix(fx3x3_t *a) { + // CHECK-LABEL: define <9 x float> @return_matrix + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8 + 
// CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>* + // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: ret <9 x float> %2 + return *a; +} + +typedef struct { + char Tmp1; + fx3x4_t Data; + float Tmp2; +} Matrix; + +void matrix_struct(Matrix *a, Matrix *b) { + // CHECK-LABEL: define void @matrix_struct( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b->Data = a->Data; +} Index: clang/test/CodeGen/debug-info-matrix-types.c =================================================================== --- /dev/null +++ clang/test/CodeGen/debug-info-matrix-types.c @@ -0,0 +1,19 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -debug-info-kind=limited -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx2x3_t __attribute__((matrix_type(2, 3))); + +void load_store_double(dx2x3_t *a, dx2x3_t *b) { + // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %a.addr, metadata [[EXPR_A:![0-9]+]] + // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %b.addr, metadata [[EXPR_B:![0-9]+]] + // CHECK: [[PTR_TY:![0-9]+]] = 
!DIDerivedType(tag: DW_TAG_pointer_type, baseType: [[TYPEDEF:![0-9]+]], size: 64) + // CHECK: [[TYPEDEF]] = !DIDerivedType(tag: DW_TAG_typedef, name: "dx2x3_t", {{.+}} baseType: [[MATRIX_TY:![0-9]+]]) + // CHECK: [[MATRIX_TY]] = !DICompositeType(tag: DW_TAG_array_type, baseType: [[ELT_TY:![0-9]+]], size: 384, elements: [[ELEMENTS:![0-9]+]]) + // CHECK: [[ELT_TY]] = !DIBasicType(name: "double", size: 64, encoding: DW_ATE_float) + // CHECK: [[ELEMENTS]] = !{[[COLS:![0-9]+]], [[ROWS:![0-9]+]]} + // CHECK: [[COLS]] = !DISubrange(count: 3) + // CHECK: [[ROWS]] = !DISubrange(count: 2) + // CHECK: [[EXPR_A]] = !DILocalVariable(name: "a", arg: 1, {{.+}} type: [[PTR_TY]]) + // CHECK: [[EXPR_B]] = !DILocalVariable(name: "b", arg: 2, {{.+}} type: [[PTR_TY]]) + + *a = *b; +} Index: clang/lib/Serialization/ASTWriter.cpp =================================================================== --- clang/lib/Serialization/ASTWriter.cpp +++ clang/lib/Serialization/ASTWriter.cpp @@ -288,6 +288,25 @@ Record.AddSourceLocation(TL.getNameLoc()); } +void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getAttrNameLoc()); + SourceRange range = TL.getAttrOperandParensRange(); + Record.AddSourceLocation(range.getBegin()); + Record.AddSourceLocation(range.getEnd()); + Record.AddStmt(TL.getAttrRowOperand()); + Record.AddStmt(TL.getAttrColumnOperand()); +} + +void TypeLocWriter::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getAttrNameLoc()); + SourceRange range = TL.getAttrOperandParensRange(); + Record.AddSourceLocation(range.getBegin()); + Record.AddSourceLocation(range.getEnd()); + Record.AddStmt(TL.getAttrRowOperand()); + Record.AddStmt(TL.getAttrColumnOperand()); +} + void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) { Record.AddSourceLocation(TL.getLocalRangeBegin()); Record.AddSourceLocation(TL.getLParenLoc()); Index: clang/lib/Serialization/ASTReader.cpp 
=================================================================== --- clang/lib/Serialization/ASTReader.cpp +++ clang/lib/Serialization/ASTReader.cpp @@ -6554,6 +6554,21 @@ TL.setNameLoc(readSourceLocation()); } +void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) { + TL.setAttrNameLoc(readSourceLocation()); + TL.setAttrOperandParensRange(Reader.readSourceRange()); + TL.setAttrRowOperand(Reader.readExpr()); + TL.setAttrColumnOperand(Reader.readExpr()); +} + +void TypeLocReader::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + TL.setAttrNameLoc(readSourceLocation()); + TL.setAttrOperandParensRange(Reader.readSourceRange()); + TL.setAttrRowOperand(Reader.readExpr()); + TL.setAttrColumnOperand(Reader.readExpr()); +} + void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) { TL.setLocalRangeBegin(readSourceLocation()); TL.setLParenLoc(readSourceLocation()); Index: clang/lib/Sema/TreeTransform.h =================================================================== --- clang/lib/Sema/TreeTransform.h +++ clang/lib/Sema/TreeTransform.h @@ -894,6 +894,16 @@ Expr *SizeExpr, SourceLocation AttributeLoc); + /// Build a new matrix type given the element type and dimensions. + QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows, + unsigned NumColumns); + + /// Build a new matrix type given the type and dependently-defined + /// dimensions. + QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttributeLoc); + /// Build a new DependentAddressSpaceType or return the pointee /// type variable with the correct address space (retrieved from /// AddrSpaceExpr) applied to it. 
The former will be returned in cases @@ -5179,6 +5189,75 @@ return Result; } +template <typename Derived> +QualType +TreeTransform<Derived>::TransformConstantMatrixType(TypeLocBuilder &TLB, + ConstantMatrixTypeLoc TL) { + const ConstantMatrixType *T = TL.getTypePtr(); + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) + return QualType(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) { + Result = getDerived().RebuildConstantMatrixType( + ElementType, T->getNumRows(), T->getNumColumns()); + if (Result.isNull()) + return QualType(); + } + + ConstantMatrixTypeLoc NewTL = TLB.push<ConstantMatrixTypeLoc>(Result); + NewTL.setAttrNameLoc(TL.getAttrNameLoc()); + NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange()); + NewTL.setAttrRowOperand(TL.getAttrRowOperand()); + NewTL.setAttrColumnOperand(TL.getAttrColumnOperand()); + + return Result; +} + +template <typename Derived> +QualType TreeTransform<Derived>::TransformDependentSizedMatrixType( + TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) { + const DependentSizedMatrixType *T = TL.getTypePtr(); + + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) { + return QualType(); + } + + EnterExpressionEvaluationContext Unevaluated( + SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated); + ExprResult Rows = getDerived().TransformExpr(T->getRowExpr()); + ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr()); + + QualType Result = TL.getType(); + // TODO: Finish this + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() || + Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) { + Result = getDerived().RebuildDependentSizedMatrixType( + ElementType, Rows.get(), Cols.get(), T->getAttributeLoc()); + + if (Result.isNull()) + return QualType(); + } + + if (auto *ResultMTy = 
dyn_cast<DependentSizedMatrixType>(Result)) { + DependentSizedMatrixTypeLoc NewTL = + TLB.push<DependentSizedMatrixTypeLoc>(Result); + NewTL.setAttrNameLoc(TL.getAttrNameLoc()); + NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange()); + NewTL.setAttrRowOperand(ResultMTy->getRowExpr()); + NewTL.setAttrColumnOperand(ResultMTy->getColumnExpr()); + } else { + MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result); + NewTL.setAttrNameLoc(TL.getAttrNameLoc()); + NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange()); + NewTL.setAttrRowOperand(TL.getAttrRowOperand()); + NewTL.setAttrColumnOperand(TL.getAttrColumnOperand()); + } + return Result; +} + template <typename Derived> QualType TreeTransform<Derived>::TransformDependentAddressSpaceType( TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) { @@ -13750,6 +13829,21 @@ return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc); } +template <typename Derived> +QualType TreeTransform<Derived>::RebuildConstantMatrixType( + QualType ElementType, unsigned NumRows, unsigned NumColumns) { + return SemaRef.Context.getConstantMatrixType(ElementType, NumRows, + NumColumns); +} + +template <typename Derived> +QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType( + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr, + SourceLocation AttributeLoc) { + return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr, + AttributeLoc); +} + template<typename Derived> QualType TreeTransform<Derived>::RebuildFunctionProtoType( QualType T, Index: clang/lib/Sema/SemaType.cpp =================================================================== --- clang/lib/Sema/SemaType.cpp +++ clang/lib/Sema/SemaType.cpp @@ -2492,14 +2492,15 @@ if (!VecSize.isIntN(61)) { // Bit size will overflow uint64.
Diag(AttrLoc, diag::err_attribute_size_too_large) - << SizeExpr->getSourceRange(); + << SizeExpr->getSourceRange() << "vector"; return QualType(); } uint64_t VectorSizeBits = VecSize.getZExtValue() * 8; unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType)); if (VectorSizeBits == 0) { - Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange(); + Diag(AttrLoc, diag::err_attribute_zero_size) + << SizeExpr->getSourceRange() << "vector"; return QualType(); } @@ -2511,7 +2512,7 @@ if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) { Diag(AttrLoc, diag::err_attribute_size_too_large) - << SizeExpr->getSourceRange(); + << SizeExpr->getSourceRange() << "vector"; return QualType(); } @@ -2549,7 +2550,7 @@ if (!vecSize.isIntN(32)) { Diag(AttrLoc, diag::err_attribute_size_too_large) - << ArraySize->getSourceRange(); + << ArraySize->getSourceRange() << "vector"; return QualType(); } // Unlike gcc's vector_size attribute, the size is specified as the @@ -2558,7 +2559,7 @@ if (vectorSize == 0) { Diag(AttrLoc, diag::err_attribute_zero_size) - << ArraySize->getSourceRange(); + << ArraySize->getSourceRange() << "vector"; return QualType(); } @@ -2568,6 +2569,82 @@ return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc); } +QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols, + SourceLocation AttrLoc) { + assert(Context.getLangOpts().MatrixTypes && + "Should never build a matrix type when it is disabled"); + + if (NumRows->isTypeDependent() || NumCols->isTypeDependent() || + NumRows->isValueDependent() || NumCols->isValueDependent()) + return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols, + AttrLoc); + + if (!MatrixType::isValidElementType(ElementTy)) { + Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy; + return QualType(); + } + + // Both row and column values can only be 20 bit wide currently.
+ llvm::APSInt ValueRows(32), ValueColumns(32); + + bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context); + bool const ColumnsIsInteger = + NumCols->isIntegerConstantExpr(ValueColumns, Context); + + auto const RowRange = NumRows->getSourceRange(); + auto const ColRange = NumCols->getSourceRange(); + + // Both the row and column expressions are invalid. + if (!RowsIsInteger && !ColumnsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange + << ColRange; + return QualType(); + } + + // Only the row expression is invalid. + if (!RowsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange; + return QualType(); + } + + // Only the column expression is invalid. + if (!ColumnsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange; + return QualType(); + } + + // Check the matrix dimensions.
+ unsigned MatrixRows = static_cast<unsigned>(ValueRows.getZExtValue()); + unsigned MatrixColumns = static_cast<unsigned>(ValueColumns.getZExtValue()); + if (MatrixRows == 0 && MatrixColumns == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) + << "matrix" << RowRange << ColRange; + return QualType(); + } + if (MatrixRows == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange; + return QualType(); + } + if (MatrixColumns == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange; + return QualType(); + } + if (!ConstantMatrixType::isDimensionValid(MatrixRows)) { + Diag(AttrLoc, diag::err_attribute_size_too_large) + << RowRange << "matrix row"; + return QualType(); + } + if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) { + Diag(AttrLoc, diag::err_attribute_size_too_large) + << ColRange << "matrix column"; + return QualType(); + } + return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns); +} + bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) { if (T->isArrayType() || T->isFunctionType()) { Diag(Loc, diag::err_func_returning_array_function) @@ -6013,6 +6090,37 @@ "no address_space attribute found at the expected location!"); } +static void fillMatrixTypeLoc(MatrixTypeLoc MTL, + const ParsedAttributesView &Attrs) { + for (const ParsedAttr &AL : Attrs) { + if (AL.getKind() == ParsedAttr::AT_MatrixType) { + MTL.setAttrNameLoc(AL.getLoc()); + MTL.setAttrRowOperand(AL.getArgAsExpr(0)); + MTL.setAttrColumnOperand(AL.getArgAsExpr(1)); + MTL.setAttrOperandParensRange(SourceRange()); + return; + } + } + + llvm_unreachable( + "no matrix_type attribute found at the expected location!"); +} + +static void fillDependentSizedMatrixTypeLoc(DependentSizedMatrixTypeLoc DMTL, + const ParsedAttributesView &Attrs) { + for (const ParsedAttr &AL : Attrs) { + if (AL.getKind() == ParsedAttr::AT_MatrixType) { + DMTL.setAttrNameLoc(AL.getLoc()); + DMTL.setAttrRowOperand(AL.getArgAsExpr(0)); +
DMTL.setAttrColumnOperand(AL.getArgAsExpr(1)); + DMTL.setAttrOperandParensRange(SourceRange()); + return; + } + } + + llvm_unreachable( + "no matrix_type attribute found at the expected location!"); +} /// Create and instantiate a TypeSourceInfo with type source information. /// /// \param T QualType referring to the type as written in source code. @@ -6061,6 +6169,12 @@ CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc(); } + if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>()) + fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs()); + if (DependentSizedMatrixTypeLoc TL = + CurrTL.getAs<DependentSizedMatrixTypeLoc>()) + fillDependentSizedMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs()); + // FIXME: Ordering here? while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>()) CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc(); @@ -7706,6 +7820,68 @@ } } +/// HandleMatrixTypeAttr - "matrix_type" attribute, like ext_vector_type +static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr, + Sema &S) { + if (!S.getLangOpts().MatrixTypes) { + S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled); + return; + } + + if (Attr.getNumArgs() != 2) { + S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments) + << Attr << 2; + return; + } + + Expr *RowsExpr = nullptr; + Expr *ColsExpr = nullptr; + + // TODO: Refactor parameter extraction into separate function + // Get the number of rows + if (Attr.isArgIdent(0)) { + CXXScopeSpec SS; + SourceLocation TemplateKeywordLoc; + UnqualifiedId id; + id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc()); + ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS, + TemplateKeywordLoc, id, false, false); + + if (Rows.isInvalid()) + // TODO: maybe a good error message would be nice here + return; + RowsExpr = Rows.get(); + } else { + assert(Attr.isArgExpr(0) && + "Argument to matrix_type should either be an identifier or expression"); + RowsExpr = Attr.getArgAsExpr(0); + } + + // Get the number of columns + if
(Attr.isArgIdent(1)) { + CXXScopeSpec SS; + SourceLocation TemplateKeywordLoc; + UnqualifiedId id; + id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc()); + ExprResult Columns = S.ActOnIdExpression( + S.getCurScope(), SS, TemplateKeywordLoc, id, false, false); + + if (Columns.isInvalid()) + // TODO: a good error message would be nice here + return; + ColsExpr = Columns.get(); // Column operand, not row: BuildMatrixType dereferences ColsExpr. + } else { + assert(Attr.isArgExpr(1) && + "Argument to matrix_type should either be an identifier or expression"); + ColsExpr = Attr.getArgAsExpr(1); + } + + // Create the matrix type. + QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc()); + if (!T.isNull()) + CurType = T; +} + static void HandleLifetimeBoundAttr(TypeProcessingState &State, QualType &CurType, ParsedAttr &Attr) { @@ -7857,6 +8033,11 @@ break; } + case ParsedAttr::AT_MatrixType: + HandleMatrixTypeAttr(type, attr, state.getSema()); + attr.setUsedAsTypeAttr(); + break; + MS_TYPE_ATTRS_CASELIST: if (!handleMSPointerTypeQualifierAttr(state, attr, type)) attr.setUsedAsTypeAttr(); Index: clang/lib/Sema/SemaTemplateDeduction.cpp =================================================================== --- clang/lib/Sema/SemaTemplateDeduction.cpp +++ clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2055,6 +2055,97 @@ return Sema::TDK_NonDeducedMismatch; } + // (clang extension) + // + // T __attribute__((matrix_type(<integral constant>, + // <integral constant>))) + case Type::ConstantMatrix: { + const ConstantMatrixType *MatrixArg = dyn_cast<ConstantMatrixType>(Arg); + if (!MatrixArg) + return Sema::TDK_NonDeducedMismatch; + + const ConstantMatrixType *MatrixParam = cast<ConstantMatrixType>(Param); + // Check that the dimensions are the same + if (MatrixParam->getNumRows() != MatrixArg->getNumRows() || + MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) { + return Sema::TDK_NonDeducedMismatch; + } + // Perform deduction on element types.
+ return DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF); + } + + case Type::DependentSizedMatrix: { + const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg); + if (!MatrixArg) + return Sema::TDK_NonDeducedMismatch; + + unsigned SubTDF = TDF & TDF_IgnoreQualifiers; + + // Check the element type of the matrixes. + const DependentSizedMatrixType *MatrixParam = + cast<DependentSizedMatrixType>(Param); + if (Sema::TemplateDeductionResult Result = + DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, SubTDF)) + return Result; + + // Determine if the number of rows and columns is something we can deduce. + NonTypeTemplateParmDecl *RowNTTP = + getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr()); + NonTypeTemplateParmDecl *ColumnNTTP = + getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr()); + if (!RowNTTP && !ColumnNTTP) + return Sema::TDK_Success; + + // Otherwise perform template argument deduction for the given non-type + // template parameters.
+ if (RowNTTP) { + auto Result = Sema::TDK_NonDeducedMismatch; + assert(RowNTTP->getDepth() == Info.getDeducedDepth() && + "saw non-type template parameter with wrong depth"); + if (const ConstantMatrixType *ConstantMatrixArg = + dyn_cast<ConstantMatrixType>(MatrixArg)) { + llvm::APSInt ArgRows(llvm::APInt(S.Context.getTypeSize(S.Context.IntTy), + ConstantMatrixArg->getNumRows()), /*isUnsigned=*/true); + Result = DeduceNonTypeTemplateArgument( + S, TemplateParams, RowNTTP, ArgRows, S.Context.getSizeType(), + /*ArrayBound=*/false, Info, Deduced); + } else if (const DependentSizedMatrixType *DepMatrixArg = + dyn_cast<DependentSizedMatrixType>(MatrixArg)) + Result = DeduceNonTypeTemplateArgument(S, TemplateParams, RowNTTP, + DepMatrixArg->getRowExpr(), + Info, Deduced); + if (Result != Sema::TDK_Success) + return Result; + } + if (ColumnNTTP) { + auto Result = Sema::TDK_NonDeducedMismatch; + assert(ColumnNTTP->getDepth() == Info.getDeducedDepth() && + "saw non-type template parameter with wrong depth"); + if (const ConstantMatrixType *ConstantMatrixArg = + dyn_cast<ConstantMatrixType>(MatrixArg)) { + // APSInt(unsigned, bool) takes a signedness flag, not a value; wrap an APInt. + llvm::APSInt ArgColumns(llvm::APInt(S.Context.getTypeSize(S.Context.IntTy), + ConstantMatrixArg->getNumColumns()), /*isUnsigned=*/true); + Result = DeduceNonTypeTemplateArgument( + S, TemplateParams, ColumnNTTP, ArgColumns, + S.Context.getSizeType(), + /*ArrayBound=*/false, Info, Deduced); + } else if (const DependentSizedMatrixType *DepMatrixArg = + dyn_cast<DependentSizedMatrixType>(MatrixArg)) + Result = DeduceNonTypeTemplateArgument(S, TemplateParams, ColumnNTTP, + DepMatrixArg->getColumnExpr(), + Info, Deduced); + + if (Result != Sema::TDK_Success) + return Result; + } + return Sema::TDK_Success; + } + // (clang extension) // // T __attribute__(((address_space(N)))) @@ -5723,6 +5814,24 @@ break; } + case Type::ConstantMatrix: { + const ConstantMatrixType *MatType = cast<ConstantMatrixType>(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth,
Used); + break; + } + + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth, Used); + MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth, + Used); + MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced, + Depth, Used); + break; + } + case Type::FunctionProto: { const FunctionProtoType *Proto = cast<FunctionProtoType>(T); MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth, Index: clang/lib/Sema/SemaTemplate.cpp =================================================================== --- clang/lib/Sema/SemaTemplate.cpp +++ clang/lib/Sema/SemaTemplate.cpp @@ -5867,6 +5867,11 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType( + const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType( const DependentAddressSpaceType *T) { return Visit(T->getPointeeType()); @@ -5885,6 +5890,11 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType( + const ConstantMatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType( const FunctionProtoType* T) { for (const auto &A : T->param_types()) { Index: clang/lib/Sema/SemaLookup.cpp =================================================================== --- clang/lib/Sema/SemaLookup.cpp +++ clang/lib/Sema/SemaLookup.cpp @@ -2966,6 +2966,7 @@ // These are fundamental types.
case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::ExtInt: break; Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -4257,6 +4257,7 @@ case Type::Complex: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Record: case Type::Enum: case Type::Elaborated: Index: clang/lib/Frontend/CompilerInvocation.cpp =================================================================== --- clang/lib/Frontend/CompilerInvocation.cpp +++ clang/lib/Frontend/CompilerInvocation.cpp @@ -3336,6 +3336,8 @@ Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers); Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj); + Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix); // Sema rejects matrix_type with err_builtin_matrix_disabled unless set. + Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags); if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) { Index: clang/lib/Driver/ToolChains/Clang.cpp =================================================================== --- clang/lib/Driver/ToolChains/Clang.cpp +++ clang/lib/Driver/ToolChains/Clang.cpp @@ -4565,6 +4565,13 @@ if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false)) CmdArgs.push_back("-fdefault-calling-conv=stdcall"); + if (Args.hasArg(options::OPT_fenable_matrix)) { + // enable-matrix is needed by both the LangOpts and by LLVM. + CmdArgs.push_back("-fenable-matrix"); + CmdArgs.push_back("-mllvm"); + CmdArgs.push_back("-enable-matrix"); + } + CodeGenOptions::FramePointerKind FPKeepKind = getFramePointerKind(Args, RawTriple); const char *FPKeepKindStr = nullptr; Index: clang/lib/CodeGen/ItaniumCXXABI.cpp =================================================================== --- clang/lib/CodeGen/ItaniumCXXABI.cpp +++ clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3223,6 +3223,7 @@ // GCC treats vector and complex types as fundamental types.
case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::Atomic: // FIXME: GCC treats block pointers as fundamental types?! @@ -3458,6 +3459,7 @@ case Type::Builtin: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::BlockPointer: // Itanium C++ ABI 2.9.5p4: Index: clang/lib/CodeGen/CodeGenTypes.cpp =================================================================== --- clang/lib/CodeGen/CodeGenTypes.cpp +++ clang/lib/CodeGen/CodeGenTypes.cpp @@ -82,6 +82,13 @@ /// a type. For example, the scalar representation for _Bool is i1, but the /// memory representation is usually i8 or i32, depending on the target. llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) { + if (T->isConstantMatrixType()) { + const Type *Ty = Context.getCanonicalType(T).getTypePtr(); + const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty); + return llvm::ArrayType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + } + llvm::Type *R = ConvertType(T); // If this is a bool type, or an ExtIntType in a bitfield representation, @@ -646,6 +653,12 @@ VT->getNumElements()); break; } + case Type::ConstantMatrix: { + const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty); + ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + break; + } case Type::FunctionNoProto: case Type::FunctionProto: ResultType = ConvertFunctionTypeInternal(T); Index: clang/lib/CodeGen/CodeGenFunction.cpp =================================================================== --- clang/lib/CodeGen/CodeGenFunction.cpp +++ clang/lib/CodeGen/CodeGenFunction.cpp @@ -247,6 +247,7 @@ case Type::MemberPointer: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Enum: @@ -2000,6 +2001,7 @@ case Type::Complex: case Type::Vector: case
Type::ExtVector: + case Type::ConstantMatrix: case Type::Record: case Type::Enum: case Type::Elaborated: Index: clang/lib/CodeGen/CGExpr.cpp =================================================================== --- clang/lib/CodeGen/CGExpr.cpp +++ clang/lib/CodeGen/CGExpr.cpp @@ -145,8 +145,19 @@ Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align, const Twine &Name, Address *Alloca) { - return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, - /*ArraySize=*/nullptr, Alloca); + Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, + /*ArraySize=*/nullptr, Alloca); + + if (Ty->isConstantMatrixType()) { // Memory type is an array; hand back a pointer to the vector value type. + auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType()); + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + Result = Address( + Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()), + Result.getAlignment()); + } + return Result; } Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align, @@ -1732,6 +1743,31 @@ return Value; } +// Convert the pointer of \p Addr to a pointer to a vector (the value type of +// MatrixType), if it points to an array (the memory type of MatrixType). +static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF) { + auto *ArrayTy = dyn_cast<llvm::ArrayType>( + cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType()); + if (ArrayTy) { + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy)); + } + return Addr; +} + +// Emit a store of a matrix LValue. This may require casting the original +// pointer to the memory type (ArrayType) to a pointer to the value type +// (VectorType).
+static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue, + bool isInit, CodeGenFunction &CGF) { + Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF); + CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(), + lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit, + lvalue.isNontemporal()); +} + void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr, bool Volatile, QualType Ty, LValueBaseInfo BaseInfo, @@ -1779,11 +1815,26 @@ void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue, bool isInit) { + if (lvalue.getType()->isConstantMatrixType()) { + EmitStoreOfMatrixScalar(value, lvalue, isInit, *this); + return; + } + EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(), lvalue.getType(), lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal()); } +// Emit a load of an LValue of matrix type. This may require casting the pointer +// to the memory type (ArrayType) to a pointer to the value type (VectorType). +static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc, + CodeGenFunction &CGF) { + assert(LV.getType()->isConstantMatrixType()); + Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF); + LV.setAddress(Addr); + return RValue::get(CGF.EmitLoadOfScalar(LV, Loc)); +} + /// EmitLoadOfLValue - Given an expression that represents a value lvalue, this /// method emits the address of the lvalue, then loads the result as an rvalue, /// returning the rvalue. @@ -1809,6 +1860,9 @@ if (LV.isSimple()) { assert(!LV.getType()->isFunctionType()); + if (LV.getType()->isConstantMatrixType()) + return EmitLoadOfMatrixLValue(LV, Loc, *this); + // Everything needs a load.
return RValue::get(EmitLoadOfScalar(LV, Loc)); } Index: clang/lib/CodeGen/CGDebugInfo.h =================================================================== --- clang/lib/CodeGen/CGDebugInfo.h +++ clang/lib/CodeGen/CGDebugInfo.h @@ -192,6 +192,7 @@ llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit); llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit); Index: clang/lib/CodeGen/CGDebugInfo.cpp =================================================================== --- clang/lib/CodeGen/CGDebugInfo.cpp +++ clang/lib/CodeGen/CGDebugInfo.cpp @@ -2736,6 +2736,23 @@ return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray); } +llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty, + llvm::DIFile *Unit) { + // FIXME: Create another debug type for matrices + // For the time being, it treats it like a nested ArrayType. + + llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit); + uint64_t Size = CGM.getContext().getTypeSize(Ty); + uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext()); + + // Create ranges for both dimensions.
+ llvm::SmallVector<llvm::Metadata *, 2> Subscripts; + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns())); + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows())); + llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts); + return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray); +} + llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) { uint64_t Size; uint32_t Align; @@ -3129,6 +3146,8 @@ case Type::ExtVector: case Type::Vector: return CreateType(cast<VectorType>(Ty), Unit); + case Type::ConstantMatrix: + return CreateType(cast<ConstantMatrixType>(Ty), Unit); case Type::ObjCObjectPointer: return CreateType(cast<ObjCObjectPointerType>(Ty), Unit); case Type::ObjCObject: Index: clang/lib/AST/TypePrinter.cpp =================================================================== --- clang/lib/AST/TypePrinter.cpp +++ clang/lib/AST/TypePrinter.cpp @@ -256,6 +256,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -720,6 +722,38 @@ OS << ")))"; } +void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, + raw_ostream &OS) { + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + OS << T->getNumRows() << ", " << T->getNumColumns(); + OS << ")))"; +} + +void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, + raw_ostream &OS) { + printAfter(T->getElementType(), OS); +} + +void TypePrinter::printDependentSizedMatrixBefore( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + if (T->getRowExpr()) { + T->getRowExpr()->printPretty(OS, nullptr, Policy); + } + OS << ", "; + if (T->getColumnExpr()) { + T->getColumnExpr()->printPretty(OS, nullptr, Policy); + } + OS << ")))";
+} + +void TypePrinter::printDependentSizedMatrixAfter( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printAfter(T->getElementType(), OS); +} + void FunctionProtoType::printExceptionSpecification(raw_ostream &OS, const PrintingPolicy &Policy) Index: clang/lib/AST/Type.cpp =================================================================== --- clang/lib/AST/Type.cpp +++ clang/lib/AST/Type.cpp @@ -282,6 +282,53 @@ AddrSpaceExpr->Profile(ID, Context, true); } +MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType, + const Expr *RowExpr, const Expr *ColumnExpr) + : Type(tc, canonType, + (RowExpr + ? (TypeDependence::Dependent | TypeDependence::Instantiation | + (matrixType->isVariablyModifiedType() + ? TypeDependence::VariablyModified + : TypeDependence::None) | + (matrixType->containsUnexpandedParameterPack() || + (RowExpr && + RowExpr->containsUnexpandedParameterPack()) || + (ColumnExpr && + ColumnExpr->containsUnexpandedParameterPack()) + ? TypeDependence::UnexpandedPack + : TypeDependence::None)) + : matrixType->getDependence())), + ElementType(matrixType) {} + +ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows, + unsigned nColumns, QualType canonType) + : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns, + canonType) {} + +ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType, + unsigned nRows, unsigned nColumns, + QualType canonType) + : MatrixType(tc, matrixType, canonType) { + ConstantMatrixTypeBits.NumRows = nRows; + ConstantMatrixTypeBits.NumColumns = nColumns; +} + +DependentSizedMatrixType::DependentSizedMatrixType( + const ASTContext &CTX, QualType ElementType, QualType CanonicalType, + Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc) + : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr, + ColumnExpr), + Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {} + +void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID, +
const ASTContext &CTX, + QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + RowExpr->Profile(ID, CTX, true); + ColumnExpr->Profile(ID, CTX, true); +} + VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType, VectorKind vecKind) : VectorType(Vector, vecType, nElements, canonType, vecKind) {} @@ -971,6 +1018,17 @@ return Ctx.getExtVectorType(elementType, T->getNumElements()); } + QualType VisitConstantMatrixType(const ConstantMatrixType *T) { + QualType elementType = recurse(T->getElementType()); + if (elementType.isNull()) + return {}; + if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr()) + return QualType(T, 0); + + return Ctx.getConstantMatrixType(elementType, T->getNumRows(), + T->getNumColumns()); + } + QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) { QualType returnType = recurse(T->getReturnType()); if (returnType.isNull()) @@ -1790,6 +1848,14 @@ return Visit(T->getElementType()); } + Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); + } + + Type *VisitConstantMatrixType(const ConstantMatrixType *T) { + return Visit(T->getElementType()); + } + Type *VisitFunctionProtoType(const FunctionProtoType *T) { if (Syntactic && T->hasTrailingReturn()) return const_cast<FunctionProtoType*>(T); @@ -3744,6 +3810,8 @@ case Type::Vector: case Type::ExtVector: return Cache::get(cast<VectorType>(T)->getElementType()); + case Type::ConstantMatrix: + return Cache::get(cast<ConstantMatrixType>(T)->getElementType()); case Type::FunctionNoProto: return Cache::get(cast<FunctionType>(T)->getReturnType()); case Type::FunctionProto: { @@ -3830,6 +3898,9 @@ case Type::Vector: case Type::ExtVector: return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType()); + case Type::ConstantMatrix: + return computeTypeLinkageInfo( + cast<ConstantMatrixType>(T)->getElementType()); case Type::FunctionNoProto:
return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType()); case Type::FunctionProto: { @@ -3993,6 +4064,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::FunctionProto: case Type::FunctionNoProto: Index: clang/lib/AST/MicrosoftMangle.cpp =================================================================== --- clang/lib/AST/MicrosoftMangle.cpp +++ clang/lib/AST/MicrosoftMangle.cpp @@ -2730,6 +2730,23 @@ << Range; } +void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T, + Qualifiers quals, SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, + "Cannot mangle this matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + +void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T, + Qualifiers quals, SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "Cannot mangle this dependent-sized matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T, Qualifiers, SourceRange Range) { DiagnosticsEngine &Diags = Context.getDiags(); Index: clang/lib/AST/ItaniumMangle.cpp =================================================================== --- clang/lib/AST/ItaniumMangle.cpp +++ clang/lib/AST/ItaniumMangle.cpp @@ -2079,6 +2079,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -3343,6 +3345,25 @@ mangleType(T->getElementType()); } +void CXXNameMangler::mangleType(const ConstantMatrixType *T) { + // Mangle matrix types using a vendor extended type
qualifier: + // U<Len>Matrix<Rows>x<Columns><element type> + std::string VendorQualifier = + (llvm::Twine("Matrix") + llvm::Twine(T->getNumRows()) + llvm::Twine("x") + + llvm::Twine(T->getNumColumns())) + .str(); + Out << "U" << VendorQualifier.size() << VendorQualifier; + mangleType(T->getElementType()); +} + +void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "Cannot mangle this dependent-sized matrix type yet"); + Diags.Report(T->getAttributeLoc(), DiagID); +} + void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) { SplitQualType split = T->getPointeeType().split(); mangleQualifiers(split.Quals, T); Index: clang/lib/AST/ExprConstant.cpp =================================================================== --- clang/lib/AST/ExprConstant.cpp +++ clang/lib/AST/ExprConstant.cpp @@ -10350,6 +10350,7 @@ case Type::BlockPointer: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::ObjCObject: case Type::ObjCInterface: case Type::ObjCObjectPointer: Index: clang/lib/AST/ASTStructuralEquivalence.cpp =================================================================== --- clang/lib/AST/ASTStructuralEquivalence.cpp +++ clang/lib/AST/ASTStructuralEquivalence.cpp @@ -617,6 +617,34 @@ break; } + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1); + const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2); + // The element types, row and column expressions must be structurally + // equivalent.
+ if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(), + Mat2->getRowExpr()) || + !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(), + Mat2->getColumnExpr()) || + !IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType())) + return false; + break; + } + + case Type::ConstantMatrix: { + const ConstantMatrixType *Mat1 = cast<ConstantMatrixType>(T1); + const ConstantMatrixType *Mat2 = cast<ConstantMatrixType>(T2); + // The element types must be structurally equivalent and the number of rows + // and columns must match. + if (!IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType()) || + Mat1->getNumRows() != Mat2->getNumRows() || + Mat1->getNumColumns() != Mat2->getNumColumns()) + return false; + break; + } + case Type::FunctionProto: { const auto *Proto1 = cast<FunctionProtoType>(T1); const auto *Proto2 = cast<FunctionProtoType>(T2); Index: clang/lib/AST/ASTContext.cpp =================================================================== --- clang/lib/AST/ASTContext.cpp +++ clang/lib/AST/ASTContext.cpp @@ -1932,6 +1932,17 @@ break; } + case Type::ConstantMatrix: { + const auto *MT = cast<ConstantMatrixType>(T); + TypeInfo ElementInfo = getTypeInfo(MT->getElementType()); + // The internal layout of a matrix value is implementation defined. + // Initially be ABI compatible with arrays with respect to alignment and + // size.
+ Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns(); + Align = ElementInfo.Align; + break; + } + case Type::Builtin: switch (cast<BuiltinType>(T)->getKind()) { default: llvm_unreachable("Unknown builtin type!"); @@ -3362,6 +3373,8 @@ case Type::DependentVector: case Type::ExtVector: case Type::DependentSizedExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::ObjCObject: case Type::ObjCInterface: @@ -3775,6 +3788,83 @@ return QualType(New, 0); } +QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows, + unsigned NumColumns) const { + llvm::FoldingSetNodeID ID; + ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns, + Type::ConstantMatrix); + + assert(MatrixType::isValidElementType(ElementTy) && + "need a valid element type"); + assert(ConstantMatrixType::isDimensionValid(NumRows) && + ConstantMatrixType::isDimensionValid(NumColumns) && + "need valid matrix dimensions"); + void *InsertPos = nullptr; + if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) + return QualType(MTP, 0); + + QualType Canonical; + if (!ElementTy.isCanonical()) { + Canonical = + getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns); + + ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!NewIP && "Matrix type shouldn't already exist in the map"); + (void)NewIP; + } + + auto *New = new (*this, TypeAlignment) + ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical); + MatrixTypes.InsertNode(New, InsertPos); + Types.push_back(New); + return QualType(New, 0); +} + +QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy, + Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttrLoc) const { + QualType CanonElementTy = getCanonicalType(ElementTy); + llvm::FoldingSetNodeID ID; + DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr, + ColumnExpr); + + void *InsertPos =
nullptr; + DependentSizedMatrixType *Canon = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + DependentSizedMatrixType *New; + if (Canon) { + // Already have a canonical version of the matrix type + // + // If it exactly matches the requested type, use it directly. + if (Canon->getElementType() == ElementTy && + Canon->getRowExpr() == RowExpr && Canon->getColumnExpr() == ColumnExpr) + New = Canon; + else + // Otherwise use Canon as the canonical type for newly-built type. + New = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, ElementTy, QualType(Canon, 0), RowExpr, ColumnExpr, AttrLoc); + } else { + if (CanonElementTy == ElementTy) { + New = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, ElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc); +#ifndef NDEBUG + DependentSizedMatrixType *CanonCheck = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!CanonCheck && "Dependent-sized matrix canonical type broken"); +#endif + DependentSizedMatrixTypes.InsertNode(New, InsertPos); + } else { + QualType Canon = getDependentSizedMatrixType( + CanonElementTy, RowExpr, ColumnExpr, SourceLocation()); + New = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, ElementTy, Canon, RowExpr, ColumnExpr, AttrLoc); + } + } + Types.push_back(New); + return QualType(New, 0); +} + QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const { @@ -7338,6 +7428,11 @@ *NotEncodedT = T; return; + case Type::ConstantMatrix: + if (NotEncodedT) + *NotEncodedT = T; + return; + // We could see an undeduced auto type here during error recovery. // Just ignore it. case Type::Auto: @@ -8217,6 +8312,16 @@ LHS->getNumElements() == RHS->getNumElements(); } +/// areCompatMatrixTypes - Return true if the two specified matrix types are +/// compatible.
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS, + const ConstantMatrixType *RHS) { + assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified()); + return LHS->getElementType() == RHS->getElementType() && + LHS->getNumRows() == RHS->getNumRows() && + LHS->getNumColumns() == RHS->getNumColumns(); +} + bool ASTContext::areCompatibleVectorTypes(QualType FirstVec, QualType SecondVec) { assert(FirstVec->isVectorType() && "FirstVec should be a vector type"); @@ -9414,6 +9519,11 @@ RHSCan->castAs<VectorType>())) return LHS; return {}; + case Type::ConstantMatrix: + if (areCompatMatrixTypes(LHSCan->castAs<ConstantMatrixType>(), + RHSCan->castAs<ConstantMatrixType>())) + return LHS; + return {}; case Type::ObjCObject: { // Check if the types are assignment compatible. // FIXME: This should be type compatibility, e.g. whether Index: clang/include/clang/Serialization/TypeBitCodes.def =================================================================== --- clang/include/clang/Serialization/TypeBitCodes.def +++ clang/include/clang/Serialization/TypeBitCodes.def @@ -60,5 +60,7 @@ TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49) TYPE_BIT_CODE(ExtInt, EXT_INT, 50) TYPE_BIT_CODE(DependentExtInt, DEPENDENT_EXT_INT, 51) +TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52) +TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53) #undef TYPE_BIT_CODE Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -1627,6 +1627,9 @@ QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc); QualType BuildExtVectorType(QualType T, Expr *ArraySize, SourceLocation AttrLoc); + QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns, + SourceLocation AttrLoc); + QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace, SourceLocation AttrLoc); Index: clang/include/clang/Driver/Options.td 
=================================================================== --- clang/include/clang/Driver/Options.td +++ clang/include/clang/Driver/Options.td @@ -2007,6 +2007,10 @@ def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>, Flags<[CC1Option]>; +def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>, + Flags<[CC1Option]>, + HelpText<"Enable matrix data type and related builtin functions">; + def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">, Group<f_Group>, Flags<[CC1Option]>, HelpText<"Treat editor placeholders as valid source code">; Index: clang/include/clang/Basic/TypeNodes.td =================================================================== --- clang/include/clang/Basic/TypeNodes.td +++ clang/include/clang/Basic/TypeNodes.td @@ -69,6 +69,9 @@ def VectorType : TypeNode<Type>; def DependentVectorType : TypeNode<Type>, AlwaysDependent; def ExtVectorType : TypeNode<VectorType>; +def MatrixType : TypeNode<Type, 1>; +def ConstantMatrixType : TypeNode<MatrixType>; +def DependentSizedMatrixType : TypeNode<MatrixType>, AlwaysDependent; def FunctionType : TypeNode<Type, 1>; def FunctionProtoType : TypeNode<FunctionType>; def FunctionNoProtoType : TypeNode<FunctionType>; Index: clang/include/clang/Basic/LangOptions.def =================================================================== --- clang/include/clang/Basic/LangOptions.def +++ clang/include/clang/Basic/LangOptions.def @@ -357,6 +357,8 @@ LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors") +LANGOPT(MatrixTypes, 1, 0, "Enable or disable the builtin matrix type") + COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0") ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None, Index: clang/include/clang/Basic/Features.def =================================================================== --- clang/include/clang/Basic/Features.def +++ 
clang/include/clang/Basic/Features.def @@ -253,6 +253,7 @@ EXTENSION(pragma_clang_attribute_external_declaration, true) EXTENSION(gnu_asm, LangOpts.GNUAsm) EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm) +EXTENSION(matrix_types, LangOpts.MatrixTypes) #undef EXTENSION #undef FEATURE Index: clang/include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- clang/include/clang/Basic/DiagnosticSemaKinds.td +++ clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -2776,6 +2776,7 @@ def err_attribute_too_few_arguments : Error< "%0 attribute takes at least %1 argument%s1">; def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">; +def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">; def err_attribute_bad_neon_vector_size : Error< "Neon vector size must be 64 or 128 bits">; def err_attribute_requires_positive_integer : Error< @@ -2879,8 +2880,8 @@ "init methods must return an object pointer type, not %0">; def err_attribute_invalid_size : Error< "vector size not an integral multiple of component size">; -def err_attribute_zero_size : Error<"zero vector size">; -def err_attribute_size_too_large : Error<"vector size too large">; +def err_attribute_zero_size : Error<"zero %0 size">; +def err_attribute_size_too_large : Error<"%0 size too large">; def err_typecheck_vector_not_convertable_implict_truncation : Error< "cannot convert between %select{scalar|vector}0 type %1 and vector type" " %2 as implicit conversion would cause truncation">; @@ -10722,6 +10723,9 @@ "%select{non-pointer|function pointer|void pointer}0 argument to " "'__builtin_launder' is not allowed">; +def err_builtin_matrix_disabled: Error< + "matrix types extension is disabled. 
Pass -fenable-matrix to enable it">; + def err_preserve_field_info_not_field : Error< "__builtin_preserve_field_info argument %0 not a field access">; def err_preserve_field_info_not_const: Error< Index: clang/include/clang/Basic/Attr.td =================================================================== --- clang/include/clang/Basic/Attr.td +++ clang/include/clang/Basic/Attr.td @@ -2460,6 +2460,15 @@ let Documentation = [Undocumented]; } +def MatrixType : TypeAttr { + let Spellings = [Clang<"matrix_type">]; + let Subjects = SubjectList<[TypedefName], ErrorDiag>; + let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">]; + let Documentation = [Undocumented]; + let ASTNode = 0; + let PragmaAttributeSupport = 0; +} + def Visibility : InheritableAttr { let Clone = 0; let Spellings = [GCC<"visibility">]; Index: clang/include/clang/AST/TypeProperties.td =================================================================== --- clang/include/clang/AST/TypeProperties.td +++ clang/include/clang/AST/TypeProperties.td @@ -224,6 +224,41 @@ }]>; } +let Class = MatrixType in { + def : Property<"elementType", QualType> { + let Read = [{ node->getElementType() }]; + } +} + +let Class = ConstantMatrixType in { + def : Property<"numRows", UInt32> { + let Read = [{ node->getNumRows() }]; + } + def : Property<"numColumns", UInt32> { + let Read = [{ node->getNumColumns() }]; + } + + def : Creator<[{ + return ctx.getConstantMatrixType(elementType, numRows, numColumns); + }]>; +} + +let Class = DependentSizedMatrixType in { + def : Property<"rows", ExprRef> { + let Read = [{ node->getRowExpr() }]; + } + def : Property<"columns", ExprRef> { + let Read = [{ node->getColumnExpr() }]; + } + def : Property<"attributeLoc", SourceLocation> { + let Read = [{ node->getAttributeLoc() }]; + } + + def : Creator<[{ + return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc); + }]>; +} + let Class = FunctionType in { def : Property<"returnType", QualType> { let Read = [{ 
node->getReturnType() }]; Index: clang/include/clang/AST/TypeLoc.h =================================================================== --- clang/include/clang/AST/TypeLoc.h +++ clang/include/clang/AST/TypeLoc.h @@ -1774,6 +1774,68 @@ DependentSizedExtVectorType> { }; +struct MatrixTypeLocInfo { + SourceLocation AttrLoc; + SourceRange OperandParens; + Expr *RowOperand; + Expr *ColumnOperand; +}; + +class MatrixTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc, MatrixTypeLoc, + MatrixType, MatrixTypeLocInfo> { +public: + /// The location of the attribute name, i.e. + /// float __attribute__((matrix_type(4, 2))) + /// ^~~~~~~~~~~~~~~~~ + SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; } + void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; } + + /// The attribute's row operand, if it has one. + /// float __attribute__((matrix_type(4, 2))) + /// ^ + Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; } + void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; } + + /// The attribute's column operand, if it has one. + /// float __attribute__((matrix_type(4, 2))) + /// ^ + Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; } + void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; } + + /// The location of the parentheses around the operand, if there is + /// an operand.
+ /// float __attribute__((matrix_type(4, 2))) + /// ^ ^ + SourceRange getAttrOperandParensRange() const { + return getLocalData()->OperandParens; + } + void setAttrOperandParensRange(SourceRange range) { + getLocalData()->OperandParens = range; + } + + SourceRange getLocalSourceRange() const { + SourceRange range(getAttrNameLoc()); + range.setEnd(getAttrOperandParensRange().getEnd()); + return range; + } + + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setAttrNameLoc(loc); + setAttrOperandParensRange(loc); + setAttrRowOperand(nullptr); + setAttrColumnOperand(nullptr); + } +}; + +class ConstantMatrixTypeLoc + : public InheritingConcreteTypeLoc<MatrixTypeLoc, ConstantMatrixTypeLoc, + ConstantMatrixType> {}; + +class DependentSizedMatrixTypeLoc + : public InheritingConcreteTypeLoc<MatrixTypeLoc, + DependentSizedMatrixTypeLoc, + DependentSizedMatrixType> {}; + // FIXME: location of the '_Complex' keyword. class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, ComplexTypeLoc, Index: clang/include/clang/AST/Type.h =================================================================== --- clang/include/clang/AST/Type.h +++ clang/include/clang/AST/Type.h @@ -1654,6 +1654,19 @@ uint32_t NumElements; }; + class ConstantMatrixTypeBitfields { + friend class ConstantMatrixType; + + unsigned : NumTypeBits; + + /// Number of rows and columns. Using 20 bits allows supporting very large + /// matrixes, while keeping 24 bits to accommodate NumTypeBits. 
+ unsigned NumRows : 20; + unsigned NumColumns : 20; + + static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1; + }; + class AttributedTypeBitfields { friend class AttributedType; @@ -1763,6 +1776,7 @@ TypeWithKeywordBitfields TypeWithKeywordBits; ElaboratedTypeBitfields ElaboratedTypeBits; VectorTypeBitfields VectorTypeBits; + ConstantMatrixTypeBitfields ConstantMatrixTypeBits; SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits; TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits; DependentTemplateSpecializationTypeBitfields @@ -2021,6 +2035,7 @@ bool isComplexIntegerType() const; // GCC _Complex integer type. bool isVectorType() const; // GCC vector type. bool isExtVectorType() const; // Extended vector type. + bool isConstantMatrixType() const; // Matrix type. bool isDependentAddressSpaceType() const; // value-dependent address space qualifier bool isObjCObjectPointerType() const; // pointer to ObjC object bool isObjCRetainableType() const; // ObjC object or block pointer @@ -3390,6 +3405,130 @@ } }; +/// Represents a matrix type, as defined in the Matrix Types clang extensions. +/// __attribute__((matrix_type(rows, columns))), where "rows" specifies +/// number of rows and "columns" specifies the number of columns. +class MatrixType : public Type, public llvm::FoldingSetNode { +protected: + friend class ASTContext; + + /// The element type of the matrix. 
+ QualType ElementType; + + MatrixType(QualType ElementTy, QualType CanonElementTy); + + MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy, + const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr); + +public: + /// Returns type of the elements being stored in the matrix + QualType getElementType() const { return ElementType; } + + /// Valid element types are the following: + /// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types + /// and _Bool + /// * the standard floating types float or double + /// * a half-precision floating point type, if one is supported on the target + static bool isValidElementType(QualType T) { + return T->isRealType() && !T->isBooleanType() && !T->isEnumeralType(); + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == ConstantMatrix || + T->getTypeClass() == DependentSizedMatrix; + } +}; + +/// Represents a concrete matrix type with constant number of rows and columns +class ConstantMatrixType final : public MatrixType { +protected: + friend class ASTContext; + + // The element type lives in MatrixType::ElementType (read by + // getElementType()); a member redeclared here would only shadow it. + + ConstantMatrixType(QualType MatrixElementType, unsigned NRows, + unsigned NColumns, QualType CanonElementType); + + ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows, + unsigned NColumns, QualType CanonElementType); + +public: + /// Returns the number of rows in the matrix. + unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; } + + /// Returns the number of columns in the matrix. + unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; } + + /// Returns the number of elements required to embed the matrix into a vector.
+ unsigned getNumElementsFlattened() const { + return getNumRows() * getNumColumns(); // Multiply as unsigned, not int. + } + + /// Returns true if \p NumElements is a valid matrix dimension. + static bool isDimensionValid(uint64_t NumElements) { + return NumElements > 0 && + NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, getElementType(), getNumRows(), getNumColumns(), + getTypeClass()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType, + unsigned NumRows, unsigned NumColumns, + TypeClass TypeClass) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + ID.AddInteger(NumRows); + ID.AddInteger(NumColumns); + ID.AddInteger(TypeClass); + } + + static bool classof(const Type *T) { + return T->getTypeClass() == ConstantMatrix; + } +}; + +/// Represents a matrix type where the type and the number of rows and columns +/// are dependent on a template. +class DependentSizedMatrixType final : public MatrixType { + friend class ASTContext; + + const ASTContext &Context; + Expr *RowExpr; + Expr *ColumnExpr; + + SourceLocation loc; + + DependentSizedMatrixType(const ASTContext &Context, QualType ElementType, + QualType CanonicalType, Expr *RowExpr, + Expr *ColumnExpr, SourceLocation loc); + +public: + QualType getElementType() const { return ElementType; } + Expr *getRowExpr() const { return RowExpr; } + Expr *getColumnExpr() const { return ColumnExpr; } + SourceLocation getAttributeLoc() const { return loc; } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == DependentSizedMatrix; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context, + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr); +}; + 
/// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base /// class of FunctionNoProtoType and FunctionProtoType. class FunctionType : public Type { @@ -6605,6 +6744,10 @@ return isa<ExtVectorType>(CanonicalType); } +inline bool Type::isConstantMatrixType() const { + return isa<ConstantMatrixType>(CanonicalType); +} + inline bool Type::isDependentAddressSpaceType() const { return isa<DependentAddressSpaceType>(CanonicalType); } Index: clang/include/clang/AST/RecursiveASTVisitor.h =================================================================== --- clang/include/clang/AST/RecursiveASTVisitor.h +++ clang/include/clang/AST/RecursiveASTVisitor.h @@ -1006,6 +1006,17 @@ DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); }) +DEF_TRAVERSE_TYPE(ConstantMatrixType, + { TRY_TO(TraverseType(T->getElementType())); }) + +DEF_TRAVERSE_TYPE(DependentSizedMatrixType, { + if (T->getRowExpr()) + TRY_TO(TraverseStmt(T->getRowExpr())); + if (T->getColumnExpr()) + TRY_TO(TraverseStmt(T->getColumnExpr())); + TRY_TO(TraverseType(T->getElementType())); +}) + DEF_TRAVERSE_TYPE(FunctionNoProtoType, { TRY_TO(TraverseType(T->getReturnType())); }) @@ -1258,6 +1269,18 @@ TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); }) +DEF_TRAVERSE_TYPELOC(ConstantMatrixType, { + TRY_TO(TraverseStmt(TL.getAttrRowOperand())); + TRY_TO(TraverseStmt(TL.getAttrColumnOperand())); + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + +DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, { + TRY_TO(TraverseStmt(TL.getAttrRowOperand())); + TRY_TO(TraverseStmt(TL.getAttrColumnOperand())); + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + DEF_TRAVERSE_TYPELOC(FunctionNoProtoType, { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); }) Index: clang/include/clang/AST/ASTContext.h =================================================================== --- clang/include/clang/AST/ASTContext.h +++ clang/include/clang/AST/ASTContext.h @@ -194,6 +194,8 @@ 
DependentAddressSpaceTypes; mutable llvm::FoldingSet<VectorType> VectorTypes; mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes; + mutable llvm::FoldingSet<ConstantMatrixType> MatrixTypes; + mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes; mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes; mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&> FunctionProtoTypes; @@ -1326,6 +1328,20 @@ Expr *SizeExpr, SourceLocation AttrLoc) const; + /// Return the unique reference to the matrix type of the specified element + /// type and size. + /// + /// \pre \p ElementType must be a valid matrix element type (see + /// MatrixType::isValidElementType). + QualType getConstantMatrixType(QualType ElementType, unsigned NumRows, + unsigned NumColumns) const; + + /// Return a matrix type with the specified element type whose dimensions + /// are given by the (possibly value-dependent) \p RowExpr and \p ColumnExpr. + QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttrLoc) const; + QualType getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits