fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72773

Files:
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/builtin-matrix.c
  clang/test/CodeGenCXX/builtin-matrix.cpp
  clang/test/Sema/builtin-matrix.c
  clang/test/SemaCXX/builtin-matrix.cpp

Index: clang/test/SemaCXX/builtin-matrix.cpp
===================================================================
--- clang/test/SemaCXX/builtin-matrix.cpp
+++ clang/test/SemaCXX/builtin-matrix.cpp
@@ -42,3 +42,60 @@
   Mat1.value = *((decltype(Mat1)::matrix_t*) Ptr1);
   unsigned v1 = extract(Mat1); // expected-note {{in instantiation of function template specialization 'extract<unsigned int, 2, 2>' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = __builtin_matrix_add(A.value, B.value);
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2))) ')}}
+  // expected-error@-2 {{Matrix types must match}}
+  // expected-error@-3 {{Matrix types must match}}
+  // Both mismatched instantiations are diagnosed at this call and the next.
+  return __builtin_matrix_add(A.value, B.value);
+  // expected-error@-1 {{Matrix types must match}}
+  // expected-error@-2 {{Matrix types must match}}
+}
+
+void test_add_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t*) Ptr1);
+  unsigned v1 = add<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2))) ')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // Operands with different dimensions (2x2 and 3x3).
+  Mat1.value = add<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  // Operands with different dimensions and element types (unsigned, float).
+  Mat1.value = add<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = __builtin_matrix_subtract(A.value, B.value);
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2))) ')}}
+  // expected-error@-2 {{Matrix types must match}}
+  // expected-error@-3 {{Matrix types must match}}
+  // Both mismatched instantiations are diagnosed at this call and the next.
+  return __builtin_matrix_subtract(A.value, B.value);
+  // expected-error@-1 {{Matrix types must match}}
+  // expected-error@-2 {{Matrix types must match}}
+}
+
+void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t*) Ptr1);
+  unsigned v1 = subtract<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2))) ')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // Operands with different dimensions (2x2 and 3x3).
+  Mat1.value = subtract<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  // Operands with different dimensions and element types (unsigned, float).
+  Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
Index: clang/test/Sema/builtin-matrix.c
===================================================================
--- clang/test/Sema/builtin-matrix.c
+++ clang/test/Sema/builtin-matrix.c
@@ -40,3 +40,43 @@
    float v4 = __builtin_matrix_extract(
       *a, 1, 1, 1); // expected-error {{too many arguments to function call, expected 3, have 4}}
 }
+
+// Float matrix types with transposed shapes (10x5 and 5x10).
+typedef float sx10x5_t __attribute__((matrix_type(10, 5)));
+typedef float sx5x10_t __attribute__((matrix_type(5, 10)));
+// __builtin_matrix_add/subtract require two matrix operands of the same type.
+void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+    a = __builtin_matrix_add(
+        b, c); // expected-error {{Matrix types must match}}
+
+    a = __builtin_matrix_add( // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10))) ') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10))) ')}}
+        b, b);
+
+    a = __builtin_matrix_add(
+        10, b); // expected-error {{First argument must be a matrix}}
+
+    a = __builtin_matrix_add(
+        b, &c); // expected-error {{Second argument must be a matrix}}
+
+    a = __builtin_matrix_add(
+        &a,  // expected-error {{First argument must be a matrix}}
+        &c); // expected-error {{Second argument must be a matrix}}
+}
+
+void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+    a = __builtin_matrix_subtract(
+        b, c); // expected-error {{Matrix types must match}}
+
+    a = __builtin_matrix_subtract( // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10))) ') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10))) ')}}
+        b, b);
+
+    a = __builtin_matrix_subtract(
+        10, b); // expected-error {{First argument must be a matrix}}
+
+    a = __builtin_matrix_subtract(
+        b, &c); // expected-error {{Second argument must be a matrix}}
+
+    a = __builtin_matrix_subtract(
+        &a,  // expected-error {{First argument must be a matrix}}
+        &c); // expected-error {{Second argument must be a matrix}}
+}
Index: clang/test/CodeGenCXX/builtin-matrix.cpp
===================================================================
--- clang/test/CodeGenCXX/builtin-matrix.cpp
+++ clang/test/CodeGenCXX/builtin-matrix.cpp
@@ -188,7 +188,6 @@
 }
 
 void test_extract_template(unsigned *Ptr1, float *Ptr2) {
-
   // CHECK-LABEL: define void @_Z21test_extract_templatePjPf(i32* %Ptr1, float* %Ptr2)
   // CHECK-NEXT:  entry:
   // CHECK-NEXT:    %Ptr1.addr = alloca i32*, align 8
@@ -223,3 +222,79 @@
   Mat1.value = *((decltype(Mat1)::matrix_t*) Ptr1);
   unsigned v1 = extract(Mat1);
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return __builtin_matrix_add(A.value, B.value);
+}
+
+void test_add_template() {
+  // CHECK-LABEL: define void @_Z17test_add_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fadd <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+  // Instantiates add<float, 2, 5>; its body lowers to a <10 x float> fadd.
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = add(Mat1, Mat2);
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return __builtin_matrix_subtract(A.value, B.value);
+}
+
+void test_subtract_template() {
+  // CHECK-LABEL: define void @_Z22test_subtract_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fsub <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+  // Instantiates subtract<float, 2, 5>; its body lowers to a <10 x float> fsub.
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = subtract(Mat1, Mat2);
+}
Index: clang/test/CodeGen/builtin-matrix.c
===================================================================
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -155,3 +155,73 @@
   // CHECK-NEXT:    store i32 %8, i32* %v3, align 4
   // CHECK-NEXT:    ret void
 }
+
+void add1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = __builtin_matrix_add(b, c);
+  ai = __builtin_matrix_add(bi, ci);
+  // Double matrices are added with fadd, integer matrices with add.
+  // CHECK-LABEL: @add1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fadd <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = add <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = __builtin_matrix_subtract(b, c);
+  ai = __builtin_matrix_subtract(bi, ci);
+  // Double matrices are subtracted with fsub, integer matrices with sub.
+  // CHECK-LABEL: @sub1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fsub <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = sub <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1615,6 +1615,8 @@
 
   case Builtin::BI__builtin_matrix_insert:
   case Builtin::BI__builtin_matrix_extract:
+  case Builtin::BI__builtin_matrix_add:
+  case Builtin::BI__builtin_matrix_subtract:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1625,6 +1627,9 @@
       return SemaBuiltinMatrixInsertOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_extract:
       return SemaBuiltinExtractMatrixOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_add:
+    case Builtin::BI__builtin_matrix_subtract:
+      return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -15278,3 +15283,92 @@
 
   return CallResult;
 }
+
+ExprResult Sema::SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
+                                                  ExprResult CallResult) {
+  // Semantic checking for the element-wise binary matrix builtins
+  // (__builtin_matrix_add and __builtin_matrix_subtract).
+  //
+  // Both builtins take two operands, A and B, which must be matrices of the
+  // same type, i.e. with the same element type and the same number of rows
+  // and columns.
+  //
+  // Returns CallResult, with both arguments converted to r-values and the
+  // call's type set to the operands' (unqualified) matrix type, or ExprError()
+  // after emitting a diagnostic if the call is invalid.
+
+  // Exactly two arguments are accepted; checkArgCount diagnoses other counts.
+  if (checkArgCount(*this, TheCall, 2))
+    return ExprError();
+
+  // Look through parentheses and casts on the callee to find the builtin.
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // A is the left and B the right operand (for subtraction: A - B).
+  Expr *AArg = TheCall->getArg(0);
+  Expr *BArg = TheCall->getArg(1);
+
+  // Both operands must be matrices. Check both before bailing out, so that
+  // problems with either argument are reported at once.
+  bool ArgError = false;
+  if (!AArg->getType()->isMatrixType()) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  if (!BArg->getType()->isMatrixType()) {
+    Diag(BArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 1;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  // The matrices must have identical types (element type, number of rows and
+  // number of columns). Qualifiers are ignored: both operands are converted
+  // to unqualified r-values below.
+  // TODO: If only the element types differ, consider converting integer
+  // elements to floating point and returning a floating point matrix.
+  if (!Context.hasSameUnqualifiedType(AArg->getType(), BArg->getType())) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_type_match);
+    return ExprError();
+  }
+
+  // Convert l-value operands to r-values. DefaultLvalueConversion gives the
+  // r-value the unqualified operand type (C11 6.3.2.1p2, C++ [conv.lval]p1)
+  // and performs the checks that go with the conversion. Operands that are
+  // already r-values are returned unchanged.
+  ExprResult ARes = DefaultLvalueConversion(AArg);
+  if (ARes.isInvalid())
+    return ExprError();
+  AArg = ARes.get();
+  TheCall->setArg(0, AArg);
+
+  ExprResult BRes = DefaultLvalueConversion(BArg);
+  if (BRes.isInvalid())
+    return ExprError();
+  BArg = BRes.get();
+  TheCall->setArg(1, BArg);
+
+  // Create a new DeclRefExpr to refer to the builtin.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // The builtin is declared as returning void; change the result type of the
+  // call to the operand type. This is arbitrary, but CodeGen for these
+  // builtins is designed to handle it. Use the type of the converted operand,
+  // which is unqualified, since the call produces an r-value.
+  TheCall->setType(AArg->getType());
+
+  // TheCall was updated in place, so the original result is returned.
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2366,6 +2366,20 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_add: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Matrix1 = EmitScalarExpr(E->getArg(0));
+    Value *Matrix2 = EmitScalarExpr(E->getArg(1));
+    Value *Result = MB.CreateAdd(Matrix1, Matrix2);
+    return RValue::get(Result);
+  }
+  case Builtin::BI__builtin_matrix_subtract: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Matrix1 = EmitScalarExpr(E->getArg(0));
+    Value *Matrix2 = EmitScalarExpr(E->getArg(1));
+    Value *Result = MB.CreateSub(Matrix1, Matrix2);
+    return RValue::get(Result);
+  }
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11616,6 +11616,8 @@
                                              ExprResult CallResult);
   ExprResult SemaBuiltinExtractMatrixOverload(CallExpr *TheCall,
                                               ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
+                                              ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10296,6 +10296,9 @@
 def err_builtin_matrix_implicit_cast_error: Error<
   "Implicit cast to from %0 to %1 failed">;
 
+def err_builtin_matrix_type_match: Error<
+  "Matrix types must match">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -575,6 +575,8 @@
 
 BUILTIN(__builtin_matrix_insert,  "v.", "nt")
 BUILTIN(__builtin_matrix_extract, "v.", "nt")
+BUILTIN(__builtin_matrix_subtract, "v.", "nt")
+BUILTIN(__builtin_matrix_add, "v.", "nt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D72773: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to