fhahn updated this revision to Diff 252634.
fhahn edited the summary of this revision.
fhahn added a comment.

Include columns in the structural equivalence check, fix the type-printing TODO, and rebase.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D72281/new/

https://reviews.llvm.org/D72281

Files:
  clang/include/clang/AST/ASTContext.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/Type.h
  clang/include/clang/AST/TypeLoc.h
  clang/include/clang/AST/TypeProperties.td
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/LangOptions.def
  clang/include/clang/Basic/TypeNodes.td
  clang/include/clang/Driver/Options.td
  clang/include/clang/Sema/Sema.h
  clang/include/clang/Serialization/TypeBitCodes.def
  clang/lib/AST/ASTContext.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/lib/AST/ExprConstant.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/MicrosoftMangle.cpp
  clang/lib/AST/Type.cpp
  clang/lib/AST/TypePrinter.cpp
  clang/lib/Basic/Targets/OSTargets.cpp
  clang/lib/CodeGen/CGDebugInfo.cpp
  clang/lib/CodeGen/CGDebugInfo.h
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/lib/CodeGen/CodeGenTypes.cpp
  clang/lib/CodeGen/ItaniumCXXABI.cpp
  clang/lib/Driver/ToolChains/Clang.cpp
  clang/lib/Frontend/CompilerInvocation.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaLookup.cpp
  clang/lib/Sema/SemaTemplate.cpp
  clang/lib/Sema/SemaTemplateDeduction.cpp
  clang/lib/Sema/SemaType.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGen/matrix-type.c
  clang/test/CodeGenCXX/matrix-type.cpp
  clang/test/SemaCXX/matrix-type.cpp
  clang/tools/libclang/CIndex.cpp

Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -1786,6 +1786,8 @@
 DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
 DEFAULT_TYPELOC_IMPL(Vector, Type)
 DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(Matrix, Type) // Default handling, like Vector above.
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, Type) // Same, dependent sizes.
 DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(Record, TagType)
Index: clang/test/SemaCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,56 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+// Sema checks for the matrix_type attribute: non-constant, zero and
+// out-of-range dimensions, invalid element types, and template instantiation.
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+  using matrix1_t = int __attribute__((matrix_type(Rows, 1)));    // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix3_t = int __attribute__((matrix_type(C, C)));       // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix4_t = int __attribute__((matrix_type(-1, 1)));      // expected-error{{vector size too large}}
+  using matrix5_t = int __attribute__((matrix_type(1, -1)));      // expected-error{{vector size too large}}
+  using matrix6_t = int __attribute__((matrix_type(0, 1)));       // expected-error{{zero vector size}}
+  using matrix7_t = int __attribute__((matrix_type(1, 0)));       // expected-error{{zero vector size}}
+  using matrix8_t = int __attribute__((matrix_type(char, 0)));    // expected-error{{expected '(' for function-style cast or type construction}}
+}
+
+struct S1 {};
+
+void matrix_unsupported_element_type() {
+  using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}}
+  using matrix2_t = S1 __attribute__((matrix_type(1, 1)));    // expected-error{{invalid matrix element type 'S1'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+  using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+  using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero vector size}}
+}
+
+void instantiate_template_3() {
+  matrix_template_3<1, 10>();
+  matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{vector size too large}}
+}
+
+void instantiate_template_4() {
+  matrix_template_4<2, 10>();
+  matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
Index: clang/test/CodeGenCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,176 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5))); // In memory: [25 x double].
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));  // In memory: [12 x float].
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) { // Whole-matrix copy: one <25 x double> load + store.
+  // CHECK-LABEL:  define void @_Z10load_storePDm5_5_dS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3))); // In memory: [9 x float].
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) { // By-value 'a' arrives as <9 x float>.
+  // CHECK-LABEL: define void @_Z17parameter_passingDm3_3_fPS_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) { // Returned directly as a <9 x float> value.
+  // CHECK-LABEL: define <9 x float> @_Z13return_matrixPDm3_3_f(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+struct Matrix { // The matrix member is field 1 of %struct.Matrix.
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+  // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+class MatrixClass { // Like Matrix, but a class with int/long neighbors.
+public:
+  int Tmp1;
+  fx3x4_t Data;
+  long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+  // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+  using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols))); // Sized on instantiation.
+  int Tmp1;
+  MatrixTy Data;
+  long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+  b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+  // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Data.addr = alloca float*, align 8
+  // CHECK-NEXT:    %Arg = alloca %class.MatrixClassTemplate, align 8
+  // CHECK-NEXT:    store float* %Data, float** %Data.addr, align 8
+  // CHECK-NEXT:    %0 = load float*, float** %Data.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast float* %0 to [150 x float]*
+  // CHECK-NEXT:    %2 = bitcast [150 x float]* %1 to <150 x float>*
+  // CHECK-NEXT:    %3 = load <150 x float>, <150 x float>* %2, align 4
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %3, <150 x float>* %4, align 4
+  // CHECK-NEXT:    call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [150 x float]* %Data to <150 x float>*
+  // CHECK-NEXT:    %2 = load <150 x float>, <150 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %2, <150 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+
+  MatrixClassTemplate<float, 10, 15> Result, Arg; // Result is returned via sret (%agg.result).
+  Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+  matrix_template_reference(Arg, Result);
+  return Result;
+}
Index: clang/test/CodeGen/matrix-type.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,79 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5))); // In memory: [25 x double].
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));  // In memory: [12 x float].
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) { // Whole-matrix copy: one <25 x double> load + store.
+  // CHECK-LABEL:  define void @load_store(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3))); // In memory: [9 x float].
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) { // By-value 'a' arrives as <9 x float>.
+  // CHECK-LABEL: define void @parameter_passing(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) { // Returned directly as a <9 x float> value.
+  // CHECK-LABEL: define <9 x float> @return_matrix
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+typedef struct { // The matrix member is field 1 of %struct.Matrix.
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+} Matrix;
+
+void matrix_struct(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @matrix_struct(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,15 @@
   Record.AddSourceLocation(TL.getNameLoc());
 }
 
+void TypeLocWriter::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getNameLoc()); // Sole field; read back by TypeLocReader.
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getNameLoc()); // Sole field; read back by TypeLocReader.
+}
+
 void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   Record.AddSourceLocation(TL.getLocalRangeBegin());
   Record.AddSourceLocation(TL.getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -6525,6 +6525,15 @@
   TL.setNameLoc(readSourceLocation());
 }
 
+void TypeLocReader::VisitMatrixTypeLoc(MatrixTypeLoc TL) {
+  TL.setNameLoc(readSourceLocation()); // Sole field written by TypeLocWriter.
+}
+
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  TL.setNameLoc(readSourceLocation()); // Sole field written by TypeLocWriter.
+}
+
 void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   TL.setLocalRangeBegin(readSourceLocation());
   TL.setLParenLoc(readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
                                               Expr *SizeExpr,
                                               SourceLocation AttributeLoc);
 
+  /// Build a constant-sized \p NumRows x \p NumColumns matrix of \p ElementType.
+  QualType RebuildMatrixType(QualType ElementType, unsigned NumRows,
+                             unsigned NumColumns);
+
+  /// Build a new matrix type from (possibly dependent) row/column expressions.
+  /// By default forwards to Sema::BuildMatrixType; a null type means error.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
   /// Build a new DependentAddressSpaceType or return the pointee
   /// type variable with the correct address space (retrieved from
   /// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5136,6 +5146,74 @@
   return Result;
 }
 
+/// Transform a constant-sized matrix type. The row and column counts are
+/// kept as-is; only the element type is transformed.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformMatrixType(TypeLocBuilder &TLB,
+                                                     MatrixTypeLoc TL) {
+  const MatrixType *T = TL.getTypePtr();
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) {
+    Result = getDerived().RebuildMatrixType(ElementType, T->getNumRows(),
+                                            T->getNumColumns());
+    if (Result.isNull())
+      return QualType();
+  }
+
+  MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+  NewTL.setNameLoc(TL.getNameLoc());
+
+  return Result;
+}
+
+/// Transform a matrix type with dependent dimensions. The element type and
+/// both dimension expressions are transformed; if the rebuilt type is no
+/// longer dependent-sized, a MatrixTypeLoc is pushed for it instead.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *T = TL.getTypePtr();
+
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  // The dimensions are integer constant expressions, so transform them in a
+  // constant-evaluated context.
+  EnterExpressionEvaluationContext EvalContext(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  ExprResult Rows = getDerived().TransformExpr(T->getRowExpr());
+  ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr());
+  // An invalid ExprResult carries no usable expression, and the default
+  // rebuild path (Sema::BuildMatrixType) dereferences both dimensions.
+  if (Rows.isInvalid() || Cols.isInvalid())
+    return QualType();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+      Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        ElementType, Rows.get(), Cols.get(), T->getAttributeLoc());
+
+    if (Result.isNull())
+      return QualType();
+  }
+
+  if (isa<DependentSizedMatrixType>(Result)) {
+    DependentSizedMatrixTypeLoc NewTL =
+        TLB.push<DependentSizedMatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  } else {
+    MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+    NewTL.setNameLoc(TL.getNameLoc());
+  }
+  return Result;
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
     TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13546,6 +13615,28 @@
   return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
 }
 
+/// Rebuild a constant-sized matrix type. The dimensions are already plain
+/// integers, so the type is requested directly from the ASTContext.
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildMatrixType(QualType ElementType,
+                                                   unsigned NumRows,
+                                                   unsigned NumColumns) {
+  ASTContext &Ctx = SemaRef.Context;
+  return Ctx.getMatrixType(ElementType, NumRows, NumColumns);
+}
+
+/// Rebuild a matrix type whose dimensions are expressions. Validation is
+/// delegated to Sema::BuildMatrixType, which produces a matrix type, a
+/// dependent-sized matrix type, or a null type after reporting an error.
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  QualType Built =
+      SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr, AttributeLoc);
+  return Built;
+}
+
 template<typename Derived>
 QualType TreeTransform<Derived>::RebuildFunctionProtoType(
     QualType T,
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2505,6 +2505,105 @@
   return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
 }
 
+/// Build a matrix type with element type \p T and the dimensions given by
+/// \p NumRows and \p NumCols, diagnosing invalid arguments at \p AttrLoc.
+///
+/// If either dimension expression is type- or value-dependent, a
+/// DependentSizedMatrixType is returned and the dimensions are checked when it
+/// is rebuilt. Otherwise both dimensions must be non-zero integer constants
+/// that are not too large. The element type must be an integer or real
+/// floating type (or dependent).
+///
+/// \returns the matrix type, or a null QualType after emitting a diagnostic.
+QualType Sema::BuildMatrixType(QualType T, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().EnableMatrix &&
+         "Should never build a matrix type when it is disabled");
+
+  // Check the element type before the dimensions, so that an invalid
+  // non-dependent element type is diagnosed even if the sizes are dependent.
+  if (!(T->isIntegerType() || T->isRealFloatingType() ||
+        T->isDependentType())) {
+    Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << T;
+    return QualType();
+  }
+
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent())
+    return Context.getDependentSizedMatrixType(T, NumRows, NumCols, AttrLoc);
+
+  llvm::APSInt ValueRows(32), ValueColumns(32);
+  bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+  bool const ColumnsIsInteger =
+      NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+  auto const RowRange = NumRows->getSourceRange();
+  auto const ColRange = NumCols->getSourceRange();
+
+  // Both dimensions must be integer constant expressions; highlight whichever
+  // of them is not.
+  if (!RowsIsInteger && !ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+        << ColRange;
+    return QualType();
+  }
+  if (!RowsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+    return QualType();
+  }
+  if (!ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+    return QualType();
+  }
+
+  // Narrow the constants to unsigned. A value that needs more than 32 bits
+  // (e.g. a large or negative 64-bit constant) must not be silently truncated
+  // to a small, seemingly valid size; map it to ~0U so that the size checks
+  // below reject it as too large.
+  auto const ToDimension = [](const llvm::APSInt &Value) -> unsigned {
+    if (Value.getActiveBits() > 32)
+      return ~0U;
+    return static_cast<unsigned>(Value.getZExtValue());
+  };
+  unsigned const MatrixRows = ToDimension(ValueRows);
+  unsigned const MatrixColumns = ToDimension(ValueColumns);
+
+  // Check the matrix size.
+  if (MatrixRows == 0 && MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (MatrixRows == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+    return QualType();
+  }
+
+  if (VectorType::isVectorSizeTooLarge(MatrixRows) &&
+      VectorType::isVectorSizeTooLarge(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (VectorType::isVectorSizeTooLarge(MatrixRows)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (VectorType::isVectorSizeTooLarge(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large) << "matrix" << ColRange;
+    return QualType();
+  }
+
+  return Context.getMatrixType(T, MatrixRows, MatrixColumns);
+}
+
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
   if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
@@ -7632,6 +7727,60 @@
   }
 }
 
+/// Convert argument \p Idx of a matrix_type attribute to an expression.
+///
+/// Identifier arguments (for example template parameter names) are resolved
+/// as id-expressions in the current scope. Returns nullptr if that lookup
+/// fails; ActOnIdExpression is expected to have diagnosed the failure
+/// (e.g. "'T' does not refer to a value").
+static Expr *getMatrixTypeAttrArgAsExpr(const ParsedAttr &Attr, unsigned Idx,
+                                        Sema &S) {
+  if (!Attr.isArgIdent(Idx)) {
+    assert(Attr.isArgExpr(Idx) &&
+           "Argument should either be an identifier or an expression");
+    return Attr.getArgAsExpr(Idx);
+  }
+
+  IdentifierLoc *IdLoc = Attr.getArgAsIdent(Idx);
+  CXXScopeSpec SS;
+  SourceLocation TemplateKeywordLoc;
+  UnqualifiedId Id;
+  // Use the identifier's own location so diagnostics point at the argument.
+  Id.setIdentifier(IdLoc->Ident, IdLoc->Loc);
+  ExprResult Res = S.ActOnIdExpression(S.getCurScope(), SS, TemplateKeywordLoc,
+                                       Id, /*HasTrailingLParen=*/false,
+                                       /*IsAddressOfOperand=*/false);
+  return Res.isInvalid() ? nullptr : Res.get();
+}
+
+/// HandleMatrixTypeAttr - Process the "matrix_type(rows, cols)" type
+/// attribute. On success \p CurType is replaced by the matrix type built by
+/// Sema::BuildMatrixType; on any error \p CurType is left unchanged.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().EnableMatrix) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *RowsExpr = getMatrixTypeAttrArgAsExpr(Attr, 0, S);
+  if (!RowsExpr)
+    return;
+  Expr *ColsExpr = getMatrixTypeAttrArgAsExpr(Attr, 1, S);
+  if (!ColsExpr)
+    return;
+
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
 static void HandleLifetimeBoundAttr(TypeProcessingState &State,
                                     QualType &CurType,
                                     ParsedAttr &Attr) {
@@ -7783,6 +7943,11 @@
       break;
     }
 
+    case ParsedAttr::AT_MatrixType:
+      HandleMatrixTypeAttr(type, attr, state.getSema()); // 'type' unchanged on error.
+      attr.setUsedAsTypeAttr(); // Marked as used even when it was rejected.
+      break;
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2054,6 +2054,89 @@
       return Sema::TDK_NonDeducedMismatch;
     }
 
+    //     (clang extension)
+    //
+    //     T __attribute__((matrix_type(<integral constant>, <integral
+    //     constant>)))
+    //     TODO: Allow deduction from matrix type to vector type
+    //     TODO: Decide on deduction from vector type to matrix type
+    case Type::Matrix: {
+      const MatrixType *MatrixParam = cast<MatrixType>(Param);
+      // Matrix-DepSizedMatrix deduction
+      if (const DependentSizedMatrixType *MatrixArg =
+              dyn_cast<DependentSizedMatrixType>(Arg)) {
+        // can't check number of elements since the argument is dependent
+        return DeduceTemplateArgumentsByTypeMatch(
+            S, TemplateParams, MatrixParam->getElementType(),
+            MatrixArg->getElementType(), Info, Deduced, TDF);
+      }
+      // Matrix-Matrix deduction
+      if (const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg)) {
+        // Check that the dimensions are the same
+        if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+            MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+          return Sema::TDK_NonDeducedMismatch;
+        }
+        // Perform deduction on element types
+        return DeduceTemplateArgumentsByTypeMatch(
+            S, TemplateParams, MatrixParam->getElementType(),
+            MatrixArg->getElementType(), Info, Deduced, TDF);
+      }
+      return Sema::TDK_NonDeducedMismatch;
+    }
+
+    case Type::DependentSizedMatrix: {
+      const DependentSizedMatrixType *MatrixParam =
+          cast<DependentSizedMatrixType>(Param);
+      // DepSizedMatrix - DepSizedMatrix deduction
+      // DepSizedMatrix - Matrix deduction
+      if (const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg)) {
+        // Do deduction on the element types
+        if (Sema::TemplateDeductionResult Result =
+                DeduceTemplateArgumentsByTypeMatch(
+                    S, TemplateParams, MatrixParam->getElementType(),
+                    MatrixArg->getElementType(), Info, Deduced, TDF)) {
+          return Result;
+        }
+
+        // Deduce matrix size if possible
+        NonTypeTemplateParmDecl *RowExprTemplateParam =
+            getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr());
+        NonTypeTemplateParmDecl *ColumnExprTemplateParam =
+            getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr());
+
+        // TODO: Allow one to fail and the other to succeed in the deduction
+        // Can't deduce either rows or columns, just say everything is fine
+        if (!RowExprTemplateParam || !ColumnExprTemplateParam) {
+          return Sema::TDK_Success;
+        }
+
+        // Unsigned might make more sense
+        llvm::APSInt ArgRows(S.Context.getTypeSize(S.Context.IntTy));
+        ArgRows = MatrixArg->getNumRows();
+
+        // Deduce Rows
+        {
+          Sema::TemplateDeductionResult Res = DeduceNonTypeTemplateArgument(
+              S, TemplateParams, RowExprTemplateParam, ArgRows, S.Context.IntTy,
+              true, Info, Deduced);
+          if (Res != Sema::TDK_Success) {
+            return Res;
+          }
+        }
+
+        // Deduce Columns
+        llvm::APSInt ArgColumns(S.Context.getTypeSize(S.Context.IntTy));
+        ArgColumns = MatrixArg->getNumColumns();
+
+        // Deduce columns
+        return DeduceNonTypeTemplateArgument(
+            S, TemplateParams, ColumnExprTemplateParam, ArgColumns,
+            S.Context.IntTy, true, Info, Deduced);
+      }
+      return Sema::TDK_NonDeducedMismatch;
+    }
+
     //     (clang extension)
     //
     //     T __attribute__(((address_space(N))))
@@ -5695,6 +5778,24 @@
     break;
   }
 
+  case Type::Matrix: {
+    const MatrixType *MatType = cast<MatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+                               Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
   case Type::FunctionProto: {
     const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
     MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -5829,6 +5829,14 @@
   return Visit(T->getElementType());
 }
 
+// Only the element type of a dependent-sized matrix is inspected; the row and
+// column expressions are not visited.
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+    const DependentSizedMatrixType *T) {
+  QualType ElementTy = T->getElementType();
+  return Visit(ElementTy);
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
     const DependentAddressSpaceType *T) {
   return Visit(T->getPointeeType());
@@ -5847,6 +5855,12 @@
   return Visit(T->getElementType());
 }
 
+// Likewise, a constant-sized matrix is checked through its element type alone.
+bool UnnamedLocalNoLinkageFinder::VisitMatrixType(const MatrixType *T) {
+  QualType ElementTy = T->getElementType();
+  return Visit(ElementTy);
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
                                                   const FunctionProtoType* T) {
   for (const auto &A : T->param_types()) {
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@
     // These are fundamental types.
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Complex:
       break;
 
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4248,6 +4248,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -3346,6 +3346,8 @@
   Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
   Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
 
+  Opts.EnableMatrix = Args.hasArg(OPT_fenable_matrix);
+
   Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
 }
 
@@ -3567,7 +3569,7 @@
   InputArgList Args = Opts.ParseArgs(CommandLineArgs, MissingArgIndex,
                                      MissingArgCount, IncludedFlagsBitmask);
   LangOptions &LangOpts = *Res.getLangOpts();
-
+  //
   // Check for missing argument error.
   if (MissingArgCount) {
     Diags.Report(diag::err_drv_missing_argument)
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -4553,6 +4553,13 @@
   if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
     CmdArgs.push_back("-fdefault-calling-conv=stdcall");
 
+  if (Args.hasArg(options::OPT_fenable_matrix)) {
+    // enable-matrix is needed by both the LangOpts and by LLVM.
+    CmdArgs.push_back("-fenable-matrix");
+    CmdArgs.push_back("-mllvm");
+    CmdArgs.push_back("-enable-matrix");
+  }
+
   CodeGenOptions::FramePointerKind FPKeepKind =
                   getFramePointerKind(Args, RawTriple);
   const char *FPKeepKindStr = nullptr;
Index: clang/lib/CodeGen/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3222,6 +3222,7 @@
   // GCC treats vector and complex types as fundamental types.
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::Complex:
   case Type::Atomic:
   // FIXME: GCC treats block pointers as fundamental types?!
@@ -3457,6 +3458,7 @@
   case Type::Builtin:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::Complex:
   case Type::BlockPointer:
     // Itanium C++ ABI 2.9.5p4:
Index: clang/lib/CodeGen/CodeGenTypes.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenTypes.cpp
+++ clang/lib/CodeGen/CodeGenTypes.cpp
@@ -84,6 +84,13 @@
 /// a type.  For example, the scalar representation for _Bool is i1, but the
 /// memory representation is usually i8 or i32, depending on the target.
 llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
+  if (T->isMatrixType()) {
+    const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+    const MatrixType *MT = cast<MatrixType>(Ty);
+    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+                                MT->getNumRows() * MT->getNumColumns());
+  }
+
   llvm::Type *R = ConvertType(T);
 
   // If this is a non-bool type, don't map it.
@@ -630,6 +637,12 @@
                                        VT->getNumElements());
     break;
   }
+  case Type::Matrix: {
+    const MatrixType *MT = cast<MatrixType>(Ty);
+    ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+                                       MT->getNumRows() * MT->getNumColumns());
+    break;
+  }
   case Type::FunctionNoProto:
   case Type::FunctionProto:
     ResultType = ConvertFunctionTypeInternal(T);
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -268,6 +268,7 @@
     case Type::MemberPointer:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Enum:
@@ -2019,6 +2020,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@
 
 Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
                                        const Twine &Name, Address *Alloca) {
-  return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
-                          /*ArraySize=*/nullptr, Alloca);
+  Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+                                    /*ArraySize=*/nullptr, Alloca);
+  if (!Ty->isMatrixType())
+    return Result;
+
+  // Matrices occupy memory as an array of their elements but are loaded and
+  // stored as an LLVM vector, so return a vector-typed view of the alloca.
+  auto *MemTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+  auto *VecTy =
+      llvm::VectorType::get(MemTy->getElementType(), MemTy->getNumElements());
+  return Address(
+      Builder.CreateBitCast(Result.getPointer(), VecTy->getPointerTo()),
+      Result.getAlignment());
 }
 
 Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1759,6 +1770,20 @@
     }
   }
 
+  if (Ty->isMatrixType()) {
+    auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+        cast<llvm::PointerType>(Addr.getPointer()->getType())
+            ->getElementType());
+    if (ArrayTy) {
+      auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                             ArrayTy->getNumElements());
+
+      Addr = Address(
+          Builder.CreateBitCast(Addr.getPointer(), VectorTy->getPointerTo()),
+          Addr.getAlignment());
+    }
+  }
+
   Value = EmitToMemory(Value, Ty);
 
   LValue AtomicLValue =
@@ -1812,6 +1837,20 @@
   if (LV.isSimple()) {
     assert(!LV.getType()->isFunctionType());
 
+    if (LV.getType()->isMatrixType()) {
+      auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+          cast<llvm::PointerType>(LV.getPointer(*this)->getType())
+              ->getElementType());
+      if (ArrayTy) {
+        auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                               ArrayTy->getNumElements());
+
+        LV.setAddress(Address(Builder.CreateBitCast(LV.getPointer(*this),
+                                                    VectorTy->getPointerTo()),
+                              LV.getAlignment()));
+      }
+    }
+
     // Everything needs a load.
     return RValue::get(EmitLoadOfScalar(LV, Loc));
   }
Index: clang/lib/CodeGen/CGDebugInfo.h
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.h
+++ clang/lib/CodeGen/CGDebugInfo.h
@@ -190,6 +190,7 @@
   llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
 
   llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+  llvm::DIType *CreateType(const MatrixType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
Index: clang/lib/CodeGen/CGDebugInfo.cpp
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.cpp
+++ clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2704,6 +2704,23 @@
   return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
 }
 
+llvm::DIType *CGDebugInfo::CreateType(const MatrixType *Ty,
+                                      llvm::DIFile *Unit) {
+  // FIXME: Create another debug type for matrices. For the time being, a
+  // matrix is described as a two-dimensional array of its element type.
+  llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+  ASTContext &Ctx = CGM.getContext();
+  uint64_t Size = Ctx.getTypeSize(Ty);
+  uint32_t Align = getTypeAlignIfRequired(Ty, Ctx);
+
+  // Subscripts: the number of columns first, followed by the number of rows.
+  llvm::Metadata *Subscripts[] = {
+      DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()),
+      DBuilder.getOrCreateSubrange(0, Ty->getNumRows())};
+  llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+  return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
 llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
   uint64_t Size;
   uint32_t Align;
@@ -3097,6 +3114,8 @@
   case Type::ExtVector:
   case Type::Vector:
     return CreateType(cast<VectorType>(Ty), Unit);
+  case Type::Matrix:
+    return CreateType(cast<MatrixType>(Ty), Unit);
   case Type::ObjCObjectPointer:
     return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
   case Type::ObjCObject:
Index: clang/lib/Basic/Targets/OSTargets.cpp
===================================================================
--- clang/lib/Basic/Targets/OSTargets.cpp
+++ clang/lib/Basic/Targets/OSTargets.cpp
@@ -133,6 +133,9 @@
     Builder.defineMacro("__MACH__");
 
   PlatformMinVersion = VersionTuple(Maj, Min, Rev);
+
+  if (Opts.EnableMatrix)
+    Builder.defineMacro("__MATRIX_EXTENSION__", "1");
 }
 
 static void addMinGWDefines(const llvm::Triple &Triple, const LangOptions &Opts,
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -254,6 +254,8 @@
     case Type::DependentSizedExtVector:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::Matrix:
+    case Type::DependentSizedMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Paren:
@@ -718,6 +720,37 @@
   OS << ")))";
 }
 
+void TypePrinter::printMatrixBefore(const MatrixType *T, raw_ostream &OS) {
+  // Print the element type, followed by the attribute that made it a matrix.
+  // TODO: Fix the spacing between the element type and the __attribute__
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(" << T->getNumRows() << ", "
+     << T->getNumColumns() << ")))";
+}
+
+void TypePrinter::printMatrixAfter(const MatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
+void TypePrinter::printDependentSizedMatrixBefore(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  // Same spelling as printMatrixBefore, but the dimensions are expressions.
+  // Either may be missing, in which case its slot is left empty.
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  if (const Expr *Rows = T->getRowExpr())
+    Rows->printPretty(OS, nullptr, Policy);
+  OS << ", ";
+  if (const Expr *Columns = T->getColumnExpr())
+    Columns->printPretty(OS, nullptr, Policy);
+  OS << ")))";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
 void
 FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
                                                const PrintingPolicy &Policy)
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -282,6 +282,56 @@
   AddrSpaceExpr->Profile(ID, Context, true);
 }
 
+// Convenience constructor for the plain Matrix type class.
+MatrixType::MatrixType(QualType elementType, unsigned nRows, unsigned nColumns,
+                       QualType canonType)
+    : MatrixType(Matrix, elementType, nRows, nColumns, canonType) {}
+
+// The matrix takes on the dependence of its element type; the dimensions are
+// stored in MatrixTypeBits.
+MatrixType::MatrixType(TypeClass tc, QualType elementType, unsigned nRows,
+                       unsigned nColumns, QualType canonType)
+    : Type(tc, canonType, elementType->getDependence()),
+      ElementType(elementType) {
+  MatrixTypeBits.NumRows = nRows;
+  MatrixTypeBits.NumColumns = nColumns;
+}
+
+// A dependent-sized matrix is always dependent and instantiation-dependent.
+// RowExpr and ColumnExpr are null-checked here, so either may be absent.
+DependentSizedMatrixType::DependentSizedMatrixType(
+    const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+    Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+    : Type(DependentSizedMatrix, CanonicalType,
+           TypeDependence::Dependent | TypeDependence::Instantiation |
+               (ElementType->isVariablyModifiedType()
+                    ? TypeDependence::VariablyModified
+                    : TypeDependence::None) |
+               (ElementType->containsUnexpandedParameterPack() ||
+                        (RowExpr &&
+                         RowExpr->containsUnexpandedParameterPack()) ||
+                        (ColumnExpr &&
+                         ColumnExpr->containsUnexpandedParameterPack())
+                    ? TypeDependence::UnexpandedPack
+                    : TypeDependence::None)),
+      Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr),
+      ElementType(ElementType), loc(loc) {}
+
+// Profile a dependent-sized matrix by its element type and the profiles of
+// its dimension expressions. Because the constructor accepts null dimension
+// expressions, their presence is recorded instead of dereferencing blindly.
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+                                       const ASTContext &CTX,
+                                       QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr) {
+  ID.AddPointer(ElementType.getAsOpaquePtr());
+  for (const Expr *Dim : {RowExpr, ColumnExpr}) {
+    ID.AddBoolean(Dim != nullptr);
+    if (Dim)
+      Dim->Profile(ID, CTX, true);
+  }
+}
+
 VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
                        VectorKind vecKind)
     : VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -938,6 +977,16 @@
     return Ctx.getExtVectorType(elementType, T->getNumElements());
   }
 
+  QualType VisitMatrixType(const MatrixType *T) {
+    QualType OldElt = T->getElementType();
+    QualType NewElt = recurse(OldElt);
+    if (NewElt.isNull()) // Propagate failure to transform the element type.
+      return {};
+    if (NewElt.getAsOpaquePtr() != OldElt.getAsOpaquePtr())
+      return Ctx.getMatrixType(NewElt, T->getNumRows(), T->getNumColumns());
+    return QualType(T, 0); // Unchanged element type: keep the original node.
+  }
+
   QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
     QualType returnType = recurse(T->getReturnType());
     if (returnType.isNull())
@@ -1757,6 +1806,18 @@
       return Visit(T->getElementType());
     }
 
+    // Matrices are searched only through their element type; the dimension
+    // expressions of a dependent-sized matrix are not visited.
+    Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+      QualType ElementTy = T->getElementType();
+      return Visit(ElementTy);
+    }
+
+    Type *VisitMatrixType(const MatrixType *T) {
+      QualType ElementTy = T->getElementType();
+      return Visit(ElementTy);
+    }
+
     Type *VisitFunctionProtoType(const FunctionProtoType *T) {
       if (Syntactic && T->hasTrailingReturn())
         return const_cast<FunctionProtoType*>(T);
@@ -3675,6 +3732,8 @@
   case Type::Vector:
   case Type::ExtVector:
     return Cache::get(cast<VectorType>(T)->getElementType());
+  case Type::Matrix:
+    return Cache::get(cast<MatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return Cache::get(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3760,6 +3819,8 @@
   case Type::Vector:
   case Type::ExtVector:
     return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+  case Type::Matrix:
+    return computeTypeLinkageInfo(cast<MatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3921,6 +3982,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -2755,6 +2755,27 @@
     << Range;
 }
 
+void MicrosoftCXXNameMangler::mangleType(const MatrixType *T, Qualifiers quals,
+                                         SourceRange Range) {
+  // There is no Microsoft mangling for matrix types yet, so report an error.
+  DiagnosticsEngine &Diags = Context.getDiags();
+  Diags.Report(Range.getBegin(),
+               Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                     "Cannot mangle this matrix type yet"))
+      << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  // Likewise for matrices whose dimensions are dependent expressions.
+  DiagnosticsEngine &Diags = Context.getDiags();
+  Diags.Report(Range.getBegin(),
+               Diags.getCustomDiagID(
+                   DiagnosticsEngine::Error,
+                   "Cannot mangle this dependent-sized matrix type yet"))
+      << Range;
+}
+
 void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
                                          Qualifiers, SourceRange Range) {
   DiagnosticsEngine &Diags = Context.getDiags();
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -2065,6 +2065,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
   case Type::Paren:
@@ -3327,6 +3329,22 @@
   mangleType(T->getElementType());
 }
 
+void CXXNameMangler::mangleType(const MatrixType *T) {
+  // Mangled as: Dm <rows> _ <columns> _ <element type>
+  Out << "Dm" << T->getNumRows() << '_' << T->getNumColumns() << '_';
+  mangleType(T->getElementType());
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+  // Same layout, with each dimension mangled as an expression.
+  Out << "Dm";
+  for (const Expr *Dim : {T->getRowExpr(), T->getColumnExpr()}) {
+    mangleExpression(Dim);
+    Out << '_';
+  }
+  mangleType(T->getElementType());
+}
+
 void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
   SplitQualType split = T->getPointeeType().split();
   mangleQualifiers(split.Quals, T);
Index: clang/lib/AST/ExprConstant.cpp
===================================================================
--- clang/lib/AST/ExprConstant.cpp
+++ clang/lib/AST/ExprConstant.cpp
@@ -10286,6 +10286,7 @@
   case Type::BlockPointer:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::Matrix:
   case Type::ObjCObject:
   case Type::ObjCInterface:
   case Type::ObjCObjectPointer:
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,39 @@
     break;
   }
 
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+    const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+    // Rows
+    if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+                                  Mat2->getRowExpr())) {
+      return false;
+    }
+    // Columns
+    if (!IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+                                  Mat2->getColumnExpr())) {
+      return false;
+    }
+    // Element Type
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType())) {
+      return false;
+    }
+    return true;
+  }
+
+  case Type::Matrix: {
+    const MatrixType *Mat1 = cast<MatrixType>(T1);
+    const MatrixType *Mat2 = cast<MatrixType>(T2);
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()))
+      return false;
+    if (Mat1->getNumRows() != Mat2->getNumRows() ||
+        Mat1->getNumColumns() != Mat2->getNumColumns())
+      return false;
+    break;
+  }
+
   case Type::FunctionProto: {
     const auto *Proto1 = cast<FunctionProtoType>(T1);
     const auto *Proto2 = cast<FunctionProtoType>(T2);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -1926,6 +1926,18 @@
     break;
   }
 
+  case Type::Matrix: {
+    const auto *MT = cast<MatrixType>(T);
+    TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+    // The matrix type is intended to be ABI compatible with arrays with respect
+    // to alignment and size. We use LLVM's array type for storage, so, like an
+    // array, a matrix spans NumRows * NumColumns elements and is aligned as its
+    // element type (whose alignment may differ from its width).
+    Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+    Align = ElementInfo.Align;
+    break;
+  }
+
   case Type::Builtin:
     switch (cast<BuiltinType>(T)->getKind()) {
     default: llvm_unreachable("Unknown builtin type!");
@@ -3342,6 +3354,8 @@
   case Type::DependentVector:
   case Type::ExtVector:
   case Type::DependentSizedExtVector:
+  case Type::Matrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::ObjCObject:
   case Type::ObjCInterface:
@@ -3753,6 +3767,76 @@
   return QualType(New, 0);
 }
 
+/// getMatrixType - Return the unique reference to a matrix type of the
+/// specified element type and dimensions, creating it on first request.
+/// ElementTy must be a built-in integer or floating point type (unchecked).
+QualType ASTContext::getMatrixType(QualType ElementTy, unsigned NumRows,
+                                   unsigned NumColumns) const {
+  llvm::FoldingSetNodeID ID;
+  MatrixType::Profile(ID, ElementTy, NumRows, NumColumns, Type::Matrix);
+  // Reuse the existing node if this matrix type was already built.
+  void *InsertPos = nullptr;
+  if (MatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) {
+    return QualType(MTP, 0);
+  }
+  // Canonical is left null when ElementTy is already canonical.
+  QualType Canonical;
+  if (!ElementTy.isCanonical()) {
+    Canonical = getMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+    // The recursive call may have changed MatrixTypes; refresh InsertPos.
+    MatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!NewIP && "Matrix type shouldn't already exist in the map");
+    (void)NewIP;
+  }
+
+  auto *New = new (*this, TypeAlignment)
+      MatrixType(ElementTy, NumRows, NumColumns, Canonical);
+  MatrixTypes.InsertNode(New, InsertPos);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
+// getDependentSizedMatrixType - Return a matrix type whose dimensions are
+// given by RowExpr and ColumnExpr (presumably value-dependent expressions).
+QualType ASTContext::getDependentSizedMatrixType(QualType MatrixElementType,
+                                                 Expr *RowExpr,
+                                                 Expr *ColumnExpr,
+                                                 SourceLocation AttrLoc) const {
+  llvm::FoldingSetNodeID ID;
+  DependentSizedMatrixType::Profile(
+      ID, *this, getCanonicalType(MatrixElementType), RowExpr, ColumnExpr);
+  // Look up by canonical element type; only canonical nodes are inserted.
+  void *InsertPos = nullptr;
+  DependentSizedMatrixType *Canon =
+      DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+  DependentSizedMatrixType *New;
+  if (Canon) {
+    // A canonical version already exists: use it as the canonical type of a
+    // freshly allocated node (such non-canonical nodes are not uniqued).
+    New = new (*this, TypeAlignment)
+        DependentSizedMatrixType(*this, MatrixElementType, QualType(Canon, 0),
+                                 RowExpr, ColumnExpr, AttrLoc);
+  } else {
+    QualType CanonicalMatrixElementType = getCanonicalType(MatrixElementType);
+    if (CanonicalMatrixElementType == MatrixElementType) {
+      New = new (*this, TypeAlignment) DependentSizedMatrixType(
+          *this, MatrixElementType, QualType(), RowExpr, ColumnExpr, AttrLoc);
+      DependentSizedMatrixType *CanonCheck =
+          DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+      assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+      (void)CanonCheck;
+      DependentSizedMatrixTypes.InsertNode(New, InsertPos);
+    } else {
+      QualType Canon = getDependentSizedMatrixType(
+          CanonicalMatrixElementType, RowExpr, ColumnExpr, SourceLocation());
+      New = new (*this, TypeAlignment) DependentSizedMatrixType(
+          *this, MatrixElementType, Canon, RowExpr, ColumnExpr, AttrLoc);
+    }
+  }
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
 QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
                                                   Expr *AddrSpaceExpr,
                                                   SourceLocation AttrLoc) const {
@@ -7267,6 +7351,11 @@
       *NotEncodedT = T;
     return;
 
+  case Type::Matrix:
+    if (NotEncodedT)
+      *NotEncodedT = T;
+    return;
+
   // We could see an undeduced auto type here during error recovery.
   // Just ignore it.
   case Type::Auto:
@@ -8092,6 +8181,15 @@
          LHS->getNumElements() == RHS->getNumElements();
 }
 
+/// areCompatMatrixTypes - Return true if the two canonical, unqualified
+/// matrix types are compatible: same element type, rows and columns.
+static bool areCompatMatrixTypes(const MatrixType *LHS, const MatrixType *RHS) {
+  assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+  return LHS->getElementType() == RHS->getElementType() &&
+         LHS->getNumRows() == RHS->getNumRows() &&
+         LHS->getNumColumns() == RHS->getNumColumns();
+}
+
 bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
                                           QualType SecondVec) {
   assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9288,6 +9386,11 @@
                              RHSCan->castAs<VectorType>()))
       return LHS;
     return {};
+  case Type::Matrix:
+    if (areCompatMatrixTypes(LHSCan->castAs<MatrixType>(),
+                             RHSCan->castAs<MatrixType>()))
+      return LHS;
+    return {};
   case Type::ObjCObject: {
     // Check if the types are assignment compatible.
     // FIXME: This should be type compatibility, e.g. whether
Index: clang/include/clang/Serialization/TypeBitCodes.def
===================================================================
--- clang/include/clang/Serialization/TypeBitCodes.def
+++ clang/include/clang/Serialization/TypeBitCodes.def
@@ -58,5 +58,7 @@
 TYPE_BIT_CODE(DependentAddressSpace, DEPENDENT_ADDRESS_SPACE, 47)
 TYPE_BIT_CODE(DependentVector, DEPENDENT_SIZED_VECTOR, 48)
 TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
+TYPE_BIT_CODE(Matrix, MATRIX, 50)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 51)
 
 #undef TYPE_BIT_CODE
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1625,6 +1625,9 @@
   QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
   QualType BuildExtVectorType(QualType T, Expr *ArraySize,
                               SourceLocation AttrLoc);
+  /// Build a matrix type of element type \p T from row/column count exprs.
+  QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+                           SourceLocation AttrLoc);
   QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
                                  SourceLocation AttrLoc);
 
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -1982,6 +1982,10 @@
 def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
   Flags<[CC1Option]>;
 
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
+    Flags<[CC1Option]>,
+    HelpText<"Enable matrix data type and related builtin functions">;
+
 def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
   Group<f_Group>, Flags<[CC1Option]>,
   HelpText<"Treat editor placeholders as valid source code">;
Index: clang/include/clang/Basic/TypeNodes.td
===================================================================
--- clang/include/clang/Basic/TypeNodes.td
+++ clang/include/clang/Basic/TypeNodes.td
@@ -65,10 +65,12 @@
 def VariableArrayType : TypeNode<ArrayType>;
 def DependentSizedArrayType : TypeNode<ArrayType>, AlwaysDependent;
 def DependentSizedExtVectorType : TypeNode<Type>, AlwaysDependent;
+def DependentSizedMatrixType : TypeNode<Type>, AlwaysDependent;
 def DependentAddressSpaceType : TypeNode<Type>, AlwaysDependent;
 def VectorType : TypeNode<Type>;
 def DependentVectorType : TypeNode<Type>, AlwaysDependent;
 def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type>;
 def FunctionType : TypeNode<Type, 1>;
 def FunctionProtoType : TypeNode<FunctionType>;
 def FunctionNoProtoType : TypeNode<FunctionType>;
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -351,6 +351,8 @@
 
 LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
 
+LANGOPT(EnableMatrix, 1, 0, "Enable or disable the builtin matrix type")
+
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
 #undef LANGOPT
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2764,6 +2764,7 @@
 def err_attribute_too_few_arguments : Error<
   "%0 attribute takes at least %1 argument%s1">;
 def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
 def err_attribute_bad_neon_vector_size : Error<
   "Neon vector size must be 64 or 128 bits">;
 def err_attribute_requires_positive_integer : Error<
@@ -10629,6 +10630,9 @@
   "%select{non-pointer|function pointer|void pointer}0 argument to "
   "'__builtin_launder' is not allowed">;
 
+def err_builtin_matrix_disabled : Error<
+  "builtin matrix support is disabled; pass -fenable-matrix to enable it">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2464,6 +2464,15 @@
   let Documentation = [Undocumented];
 }
 
+def MatrixType : TypeAttr {
+  let Spellings = [Clang<"matrix_type">];
+  let Subjects = SubjectList<[TypedefName], ErrorDiag>; // Typedefs only.
+  let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+  let Documentation = [Undocumented]; // FIXME: document this attribute.
+  let ASTNode = 0;
+  let PragmaAttributeSupport = 0;
+}
+
 def Visibility : InheritableAttr {
   let Clone = 0;
   let Spellings = [GCC<"visibility">];
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@
   }]>;
 }
 
+let Class = MatrixType in {
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+  def : Property<"numRows", UInt32> {
+    let Read = [{ node->getNumRows() }];
+  }
+  def : Property<"numColumns", UInt32> {
+    let Read = [{ node->getNumColumns() }];
+  }
+  // Recreate the type from its element type and dimensions.
+  def : Creator<[{
+    return ctx.getMatrixType(elementType, numRows, numColumns);
+  }]>;
+}
+
+let Class = DependentSizedMatrixType in {
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+  def : Property<"rows", ExprRef> {
+    let Read = [{ node->getRowExpr() }];
+  }
+  def : Property<"columns", ExprRef> {
+    let Read = [{ node->getColumnExpr() }];
+  }
+  def : Property<"attributeLoc", SourceLocation> {
+    let Read = [{ node->getAttributeLoc() }];
+  }
+  def : Creator<[{
+    return ctx.getDependentSizedMatrixType(elementType, rows, columns,
+                                           attributeLoc);
+  }]>;
+}
+
 let Class = FunctionType in {
   def : Property<"returnType", QualType> {
     let Read = [{ node->getReturnType() }];
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -1774,6 +1774,18 @@
                                      DependentSizedExtVectorType> {
 };
 
+// FIXME: attribute locations are not tracked yet (same as VectorTypeLoc).
+class MatrixTypeLoc
+    : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, MatrixTypeLoc,
+                                       MatrixType> {};
+
+// FIXME: attribute locations are not tracked yet (same as VectorTypeLoc).
+// Also consider making this a subclass of MatrixTypeLoc.
+class DependentSizedMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
+                                       DependentSizedMatrixTypeLoc,
+                                       DependentSizedMatrixType> {};
+
 // FIXME: location of the '_Complex' keyword.
 class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
                                                         ComplexTypeLoc,
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1657,6 +1657,19 @@
     enum { MaxNumElements = (1 << (29 - NumTypeBits)) - 1 };
   };
 
+  class MatrixTypeBitfields {
+    friend class MatrixType;
+
+    unsigned : NumTypeBits; // Skip over the NumTypeBits bits shared with Type.
+
+    // Number of rows and columns (29 - NumTypeBits bits each).
+    unsigned NumRows : 29 - NumTypeBits;
+    unsigned NumColumns : 29 - NumTypeBits;
+    // Largest encodable dimensions, checked by MatrixType::tooBig().
+    enum { MaxNumRows = (1 << (29 - NumTypeBits)) - 1 };
+    enum { MaxNumColumns = (1 << (29 - NumTypeBits)) - 1 };
+  };
+
   class AttributedTypeBitfields {
     friend class AttributedType;
 
@@ -1766,6 +1779,7 @@
     TypeWithKeywordBitfields TypeWithKeywordBits;
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
+    MatrixTypeBitfields MatrixTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
@@ -2024,6 +2038,7 @@
   bool isComplexIntegerType() const;            // GCC _Complex integer type.
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
+  bool isMatrixType() const;                    // Clang matrix type.
   bool isDependentAddressSpaceType() const;     // value-dependent address space qualifier
   bool isObjCObjectPointerType() const;         // pointer to ObjC object
   bool isObjCRetainableType() const;            // ObjC object or block pointer
@@ -3386,6 +3401,114 @@
   }
 };
 
+/// MatrixType - Represents a matrix type created with
+/// __attribute__((matrix_type(rows, columns))), where "rows" is the
+/// number of rows and "columns" is the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+  friend class ASTContext;
+
+  QualType ElementType;
+
+  // MatrixElementType:   The type of the elements in the matrix
+  // NRows:               Number of rows
+  // NColumns:            Number of columns
+  // CanonElementType:    Canonical element type (if not canonical)
+  MatrixType(QualType MatrixElementType, unsigned NRows, unsigned NColumns,
+             QualType CanonElementType);
+
+  // typeClass:           The type class (see TypeNodes.td)
+  // MatrixElementType:   The type of the elements in the matrix
+  // NRows:               Number of rows
+  // NColumns:            Number of columns
+  // CanonElementType:    Canonical element type (if not canonical)
+  MatrixType(TypeClass typeClass, QualType MatrixElementType, unsigned NRows,
+             unsigned NColumns, QualType CanonElementType);
+
+public:
+  // The type of the elements being stored in the matrix.
+  QualType getElementType() const { return ElementType; }
+
+  // The number of rows in the matrix.
+  unsigned getNumRows() const { return MatrixTypeBits.NumRows; }
+
+  // The number of columns in the matrix.
+  unsigned getNumColumns() const { return MatrixTypeBits.NumColumns; }
+
+  // The total number of elements, i.e. rows * columns.
+  unsigned getNumElementsFlattened() const {
+    return MatrixTypeBits.NumRows * MatrixTypeBits.NumColumns;
+  }
+
+  // Returns true if the dimensions do not fit in MatrixTypeBitfields.
+  static bool tooBig(unsigned NumRows, unsigned NumColumns) {
+    return NumRows > MatrixTypeBitfields::MaxNumRows ||
+           NumColumns > MatrixTypeBitfields::MaxNumColumns;
+  }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+            getTypeClass());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+                      unsigned NumRows, unsigned NumColumns,
+                      TypeClass TypeClass) {
+    ID.AddPointer(ElementType.getAsOpaquePtr());
+    ID.AddInteger(NumRows);
+    ID.AddInteger(NumColumns);
+    ID.AddInteger(TypeClass);
+  }
+
+  // Only Matrix: DependentSizedMatrixType derives from Type, not MatrixType.
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == Matrix;
+  }
+};
+
+/// DependentSizedMatrixType - Represents a matrix type whose element type
+/// and dimensions may depend on template parameters. Note that this class
+/// derives directly from Type, not from MatrixType.
+class DependentSizedMatrixType : public Type, public llvm::FoldingSetNode {
+  friend class ASTContext;
+
+  const ASTContext &Context; // Needed by Profile().
+  Expr *RowExpr;             // Expression for the number of rows.
+  Expr *ColumnExpr;          // Expression for the number of columns.
+
+  /// The element type of the matrix.
+  QualType ElementType;
+
+  SourceLocation loc; // Location of the attribute, see getAttributeLoc().
+
+  DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+                           QualType CanonicalType, Expr *RowExpr,
+                           Expr *ColumnExpr, SourceLocation loc);
+
+public:
+  QualType getElementType() const { return ElementType; }
+  Expr *getRowExpr() const { return RowExpr; }
+  Expr *getColumnExpr() const { return ColumnExpr; }
+  SourceLocation getAttributeLoc() const { return loc; }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == DependentSizedMatrix;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+                      QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
 /// FunctionType - C99 6.7.5.3 - Function Declarators.  This is the common base
 /// class of FunctionNoProtoType and FunctionProtoType.
 class FunctionType : public Type {
@@ -6543,6 +6666,10 @@
   return isa<ExtVectorType>(CanonicalType);
 }
 
+inline bool Type::isMatrixType() const {
+  return isa<MatrixType>(CanonicalType); // Looks through type sugar.
+}
+
 inline bool Type::isDependentAddressSpaceType() const {
   return isa<DependentAddressSpaceType>(CanonicalType);
 }
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1006,6 +1006,16 @@
 
 DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
 
+DEF_TRAVERSE_TYPE(MatrixType, { TRY_TO(TraverseType(T->getElementType())); })
+// Visit the row/column expressions (if present) before the element type.
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
+  if (T->getRowExpr())
+    TRY_TO(TraverseStmt(T->getRowExpr()));
+  if (T->getColumnExpr())
+    TRY_TO(TraverseStmt(T->getColumnExpr()));
+  TRY_TO(TraverseType(T->getElementType()));
+})
+
 DEF_TRAVERSE_TYPE(FunctionNoProtoType,
                   { TRY_TO(TraverseType(T->getReturnType())); })
 
@@ -1254,6 +1264,21 @@
   TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
 })
 
+// FIXME: MatrixTypeLoc is unfinished (same as VectorTypeLoc).
+DEF_TRAVERSE_TYPELOC(MatrixType, {
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+// FIXME: DependentSizedMatrixTypeLoc is unfinished as well. Visit the row
+// and column expressions (when present) before the element type.
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
+  if (TL.getTypePtr()->getRowExpr())
+    TRY_TO(TraverseStmt(TL.getTypePtr()->getRowExpr()));
+  if (TL.getTypePtr()->getColumnExpr())
+    TRY_TO(TraverseStmt(TL.getTypePtr()->getColumnExpr()));
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
 DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
                      { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
 
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -193,6 +193,8 @@
       DependentAddressSpaceTypes;
   mutable llvm::FoldingSet<VectorType> VectorTypes;
   mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+  mutable llvm::FoldingSet<MatrixType> MatrixTypes;
+  mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
   mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
   mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
     FunctionProtoTypes;
@@ -1309,6 +1311,21 @@
                                           Expr *SizeExpr,
                                           SourceLocation AttrLoc) const;
 
+  /// Return the unique reference to the matrix type of the specified element
+  /// type and dimensions.
+  ///
+  /// \pre \p ElementType must be a built-in type.
+  QualType getMatrixType(QualType ElementType, unsigned NumRows,
+                         unsigned NumColumns) const;
+
+  /// Return a matrix type of the specified element type whose dimensions are
+  /// given by the (possibly dependent) \p RowExpr and \p ColumnExpr.
+  ///
+  /// \pre \p MatrixElementType must be a built-in type.
+  QualType getDependentSizedMatrixType(QualType MatrixElementType,
+                                       Expr *RowExpr, Expr *ColumnExpr,
+                                       SourceLocation AttrLoc) const;
+
   QualType getDependentAddressSpaceType(QualType PointeeType,
                                         Expr *AddrSpaceExpr,
                                         SourceLocation AttrLoc) const;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to