fhahn updated this revision to Diff 266842.
fhahn marked 4 inline comments as done.
fhahn added a comment.

Use an initialization step for all conversions (including for arithmetic types);
add and call a separate addMatrixBinaryArithmeticOverloads


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D76793/new/

https://reviews.llvm.org/D76793

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/Sema/matrix-type-operators.c
  clang/test/SemaCXX/matrix-type-operators.cpp
  llvm/include/llvm/IR/MatrixBuilder.h

Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -127,6 +127,16 @@
   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
   /// matrixes.
   Value *CreateAdd(Value *LHS, Value *RHS) {
+    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
@@ -137,6 +147,16 @@
   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
   /// point matrixes.
   Value *CreateSub(Value *LHS, Value *RHS) {
+    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
Index: clang/test/SemaCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-operators.cpp
+++ clang/test/SemaCXX/matrix-type-operators.cpp
@@ -114,3 +114,96 @@
   a[2] = f;
   // expected-error@-1 {{single subscript expressions are not allowed for matrix values}}
 }
+
+template <typename EltTy, unsigned Rows, unsigned Columns>
+struct MyMatrix { // Wrapper holding a Rows x Columns matrix_type value of element type EltTy.
+  using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns)));
+
+  matrix_t value;
+};
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value + B.value; // Never valid: a matrix sum cannot initialize a char *.
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+
+  return A.value + B.value; // Rejected whenever operand shapes or element types differ.
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+}
+
+void test_add_template(unsigned *Ptr1, float *Ptr2) { // Instantiates add<> with matching and mismatched operands; Ptr2 is unused.
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = add<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = add<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = add<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value - B.value; // Never valid: a matrix difference cannot initialize a char *.
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+
+  return A.value - B.value; // Rejected whenever operand shapes or element types differ.
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+}
+
+void test_subtract_template(unsigned *Ptr1, float *Ptr2) { // Instantiates subtract<> with matching and mismatched operands; Ptr2 is unused.
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = subtract<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = subtract<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+struct UserT {}; // Empty class type with no conversion to any arithmetic type.
+
+struct StructWithC { // Converts only to UserT, so it cannot become a double matrix element.
+  operator UserT() {
+    // expected-note@-1 {{candidate function}}
+    // expected-note@-2 {{candidate function}}
+    // expected-note@-3 {{candidate function}}
+    // expected-note@-4 {{candidate function}}
+    return {};
+  }
+};
+
+void test_DoubleWrapper(MyMatrix<double, 10, 9> &m, StructWithC &c) { // Scalar operands not convertible to double are rejected on either side.
+  m.value = m.value + c;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}}
+
+  m.value = c + m.value;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}}
+
+  m.value = m.value - c;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}}
+
+  m.value = c - m.value;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}}
+}
Index: clang/test/Sema/matrix-type-operators.c
===================================================================
--- clang/test/Sema/matrix-type-operators.c
+++ clang/test/Sema/matrix-type-operators.c
@@ -102,3 +102,34 @@
   return &(*a)[0][1];
   // expected-error@-1 {{address of matrix element requested}}
 }
+
+typedef float sx10x5_t __attribute__((matrix_type(10, 5))); // 10x5 float matrix.
+typedef float sx10x10_t __attribute__((matrix_type(10, 10))); // 10x10 float matrix.
+
+void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) { // Shape mismatches and pointer operands in matrix addition are diagnosed.
+  a = b + c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = 10 + b;
+  // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = b + &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+  // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
+}
+
+void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) { // Shape mismatches and pointer operands in matrix subtraction are diagnosed.
+  a = b - c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = 10 - b;
+  // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = b - &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+  // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
+}
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -258,3 +258,362 @@
 
   return m[0][1];
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value + B.value; // Element-wise add; both operands share element type and shape.
+}
+
+void test_add_template() { // Adding two 2x5 float matrices in a template lowers to an fadd on <10 x float>.
+  // CHECK-LABEL:    define void @_Z17test_add_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fadd <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = add(Mat1, Mat2);
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value - B.value; // Element-wise subtract; both operands share element type and shape.
+}
+
+void test_subtract_template() { // Subtracting two 2x5 float matrices in a template lowers to an fsub on <10 x float>.
+  // CHECK-LABEL: define void @_Z22test_subtract_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fsub <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = subtract(Mat1, Mat2);
+}
+
+struct DoubleWrapper1 { // Class type whose conversion operator yields double.
+  int x;
+  operator double() {
+    return x;
+  }
+};
+
+struct DoubleWrapper2 { // Distinct class type with the same double conversion as DoubleWrapper1.
+  int x;
+  operator double() {
+    return x;
+  }
+};
+
+struct IntWrapper { // Converts to int; used with both int and double matrices.
+  char x;
+  operator int() {
+    return x;
+  }
+};
+
+void test_DoubleWrapper(MyMatrix<double, 10, 9> &m, MyMatrix<int, 3, 4> &m2) { // Scalars from conversion operators are splatted to the matrix's vector type.
+  // CHECK-LABEL:  define void @_Z18test_DoubleWrapperR8MyMatrixIdLj10ELj9EERS_IiLj3ELj4EE(%struct.MyMatrix.2* nonnull align 8 dereferenceable(720) %m, %struct.MyMatrix.3* nonnull align 4 dereferenceable(48) %m2)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca %struct.MyMatrix.2*, align 8
+  // CHECK-NEXT:    %m2.addr = alloca %struct.MyMatrix.3*, align 8
+  // CHECK-NEXT:    %w1 = alloca %struct.DoubleWrapper1, align 4
+  // CHECK-NEXT:    %w2 = alloca %struct.DoubleWrapper2, align 4
+  // CHECK-NEXT:    %w3 = alloca %struct.IntWrapper, align 1
+  // CHECK-NEXT:    store %struct.MyMatrix.2* %m, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.3* %m2, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %x = getelementptr inbounds %struct.DoubleWrapper1, %struct.DoubleWrapper1* %w1, i32 0, i32 0
+  // CHECK-NEXT:    store i32 10, i32* %x, align 4
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [90 x double]* %value to <90 x double>*
+  // CHECK-NEXT:    %2 = load <90 x double>, <90 x double>* %1, align 8
+  // CHECK-NEXT:    %call = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <90 x double> undef, double %call, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <90 x double> %scalar.splat.splatinsert, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %3 = fadd <90 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    %4 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %4, i32 0, i32 0
+  // CHECK-NEXT:    %5 = bitcast [90 x double]* %value1 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %3, <90 x double>* %5, align 8
+  // CHECK-NEXT:    %call2 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %6 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value3 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %6, i32 0, i32 0
+  // CHECK-NEXT:    %7 = bitcast [90 x double]* %value3 to <90 x double>*
+  // CHECK-NEXT:    %8 = load <90 x double>, <90 x double>* %7, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert4 = insertelement <90 x double> undef, double %call2, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat5 = shufflevector <90 x double> %scalar.splat.splatinsert4, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %9 = fadd <90 x double> %scalar.splat.splat5, %8
+  // CHECK-NEXT:    %10 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value6 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %10, i32 0, i32 0
+  // CHECK-NEXT:    %11 = bitcast [90 x double]* %value6 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %9, <90 x double>* %11, align 8
+  // CHECK-NEXT:    %call7 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %12 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value8 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %12, i32 0, i32 0
+  // CHECK-NEXT:    %13 = bitcast [90 x double]* %value8 to <90 x double>*
+  // CHECK-NEXT:    %14 = load <90 x double>, <90 x double>* %13, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert9 = insertelement <90 x double> undef, double %call7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat10 = shufflevector <90 x double> %scalar.splat.splatinsert9, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %15 = fsub <90 x double> %scalar.splat.splat10, %14
+  // CHECK-NEXT:    %16 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value11 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %16, i32 0, i32 0
+  // CHECK-NEXT:    %17 = bitcast [90 x double]* %value11 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %15, <90 x double>* %17, align 8
+  // CHECK-NEXT:    %x12 = getelementptr inbounds %struct.DoubleWrapper2, %struct.DoubleWrapper2* %w2, i32 0, i32 0
+  // CHECK-NEXT:    store i32 20, i32* %x12, align 4
+
+  DoubleWrapper1 w1;
+  w1.x = 10;
+  m.value = m.value + w1;
+  m.value = w1 + m.value;
+  m.value = w1 - m.value;
+
+  // CHECK-NEXT:    %18 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value13 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %18, i32 0, i32 0
+  // CHECK-NEXT:    %19 = bitcast [90 x double]* %value13 to <90 x double>*
+  // CHECK-NEXT:    %20 = load <90 x double>, <90 x double>* %19, align 8
+  // CHECK-NEXT:    %call14 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2)
+  // CHECK-NEXT:    %scalar.splat.splatinsert15 = insertelement <90 x double> undef, double %call14, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat16 = shufflevector <90 x double> %scalar.splat.splatinsert15, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %21 = fadd <90 x double> %20, %scalar.splat.splat16
+  // CHECK-NEXT:    %22 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value17 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %22, i32 0, i32 0
+  // CHECK-NEXT:    %23 = bitcast [90 x double]* %value17 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %21, <90 x double>* %23, align 8
+  // CHECK-NEXT:    %call18 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2)
+  // CHECK-NEXT:    %24 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value19 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %24, i32 0, i32 0
+  // CHECK-NEXT:    %25 = bitcast [90 x double]* %value19 to <90 x double>*
+  // CHECK-NEXT:    %26 = load <90 x double>, <90 x double>* %25, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert20 = insertelement <90 x double> undef, double %call18, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat21 = shufflevector <90 x double> %scalar.splat.splatinsert20, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %27 = fadd <90 x double> %scalar.splat.splat21, %26
+  // CHECK-NEXT:    %28 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value22 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %28, i32 0, i32 0
+  // CHECK-NEXT:    %29 = bitcast [90 x double]* %value22 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %27, <90 x double>* %29, align 8
+  DoubleWrapper2 w2;
+  w2.x = 20;
+  m.value = m.value + w2;
+  m.value = w2 + m.value;
+
+  // CHECK-NEXT:    %x23 = getelementptr inbounds %struct.IntWrapper, %struct.IntWrapper* %w3, i32 0, i32 0
+  // CHECK-NEXT:    store i8 99, i8* %x23, align 1
+  // CHECK-NEXT:    %30 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value24 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %30, i32 0, i32 0
+  // CHECK-NEXT:    %31 = bitcast [12 x i32]* %value24 to <12 x i32>*
+  // CHECK-NEXT:    %32 = load <12 x i32>, <12 x i32>* %31, align 4
+  // CHECK-NEXT:    %call25 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %scalar.splat.splatinsert26 = insertelement <12 x i32> undef, i32 %call25, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat27 = shufflevector <12 x i32> %scalar.splat.splatinsert26, <12 x i32> undef, <12 x i32> zeroinitializer
+  // CHECK-NEXT:    %33 = add <12 x i32> %32, %scalar.splat.splat27
+  // CHECK-NEXT:    %34 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value28 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %34, i32 0, i32 0
+  // CHECK-NEXT:    %35 = bitcast [12 x i32]* %value28 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %33, <12 x i32>* %35, align 4
+  // CHECK-NEXT:    %call29 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %36 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value30 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %36, i32 0, i32 0
+  // CHECK-NEXT:    %37 = bitcast [12 x i32]* %value30 to <12 x i32>*
+  // CHECK-NEXT:    %38 = load <12 x i32>, <12 x i32>* %37, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert31 = insertelement <12 x i32> undef, i32 %call29, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat32 = shufflevector <12 x i32> %scalar.splat.splatinsert31, <12 x i32> undef, <12 x i32> zeroinitializer
+  // CHECK-NEXT:    %39 = add <12 x i32> %scalar.splat.splat32, %38
+  // CHECK-NEXT:    %40 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value33 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %40, i32 0, i32 0
+  // CHECK-NEXT:    %41 = bitcast [12 x i32]* %value33 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %39, <12 x i32>* %41, align 4
+
+  IntWrapper w3;
+  w3.x = 'c';
+  m2.value = m2.value + w3;
+  m2.value = w3 + m2.value;
+
+  // IntWrapper's int conversion, then an implicit int-to-double conversion (sitofp) to the matrix element type.
+  // CHECK-NEXT:    %42 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value34 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %42, i32 0, i32 0
+  // CHECK-NEXT:    %43 = bitcast [90 x double]* %value34 to <90 x double>*
+  // CHECK-NEXT:    %44 = load <90 x double>, <90 x double>* %43, align 8
+  // CHECK-NEXT:    %call35 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %conv = sitofp i32 %call35 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert36 = insertelement <90 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat37 = shufflevector <90 x double> %scalar.splat.splatinsert36, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %45 = fsub <90 x double> %44, %scalar.splat.splat37
+  // CHECK-NEXT:    %46 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value38 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %46, i32 0, i32 0
+  // CHECK-NEXT:    %47 = bitcast [90 x double]* %value38 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %45, <90 x double>* %47, align 8
+  // CHECK-NEXT:    %call39 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %conv40 = sitofp i32 %call39 to double
+  // CHECK-NEXT:    %48 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value41 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %48, i32 0, i32 0
+  // CHECK-NEXT:    %49 = bitcast [90 x double]* %value41 to <90 x double>*
+  // CHECK-NEXT:    %50 = load <90 x double>, <90 x double>* %49, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert42 = insertelement <90 x double> undef, double %conv40, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat43 = shufflevector <90 x double> %scalar.splat.splatinsert42, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %51 = fsub <90 x double> %scalar.splat.splat43, %50
+  // CHECK-NEXT:    %52 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value44 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %52, i32 0, i32 0
+  // CHECK-NEXT:    %53 = bitcast [90 x double]* %value44 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %51, <90 x double>* %53, align 8
+  // CHECK-NEXT:    ret void
+  // CHECK-NEXT:  }
+
+  m.value = m.value - w3;
+  m.value = w3 - m.value;
+}
+
+template <class T, unsigned R, unsigned C>
+using matrix_type = T __attribute__((matrix_type(R, C))); // R x C matrix with element type T.
+struct identmatrix_t { // Converts, via a templated operator, to an N x N matrix of any element type T.
+  template <class T, unsigned N>
+  operator matrix_type<T, N, N>() const {
+    matrix_type<T, N, N> result; // NOTE(review): only the diagonal is written below; other elements stay indeterminate.
+    for (unsigned i = 0; i != N; ++i)
+      result[i][i] = 1;
+    return result;
+  }
+};
+
+constexpr identmatrix_t identmatrix; // constexpr object used as a matrix operand in test_constexpr.
+void test_constexpr(matrix_type<float, 4, 4> &m, matrix_type<int, 5, 5> &m2) { // Operands from identmatrix's conversion operator; a scalar 1 becomes a constant splat.
+  // CHECK-LABEL: define void @_Z14test_constexprRU11matrix_typeLm4ELm4EfRU11matrix_typeLm5ELm5Ei([16 x float]* nonnull align 4 dereferenceable(64) %m, [25 x i32]* nonnull align 4 dereferenceable(100) %m2)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [16 x float]*, align 8
+  // CHECK-NEXT:    %m2.addr = alloca [25 x i32]*, align 8
+  // CHECK-NEXT:    store [16 x float]* %m, [16 x float]** %m.addr, align 8
+  // CHECK-NEXT:    store [25 x i32]* %m2, [25 x i32]** %m2.addr, align 8
+  // CHECK-NEXT:    %0 = load [16 x float]*, [16 x float]** %m.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [16 x float]* %0 to <16 x float>*
+  // CHECK-NEXT:    %2 = load <16 x float>, <16 x float>* %1, align 4
+  // CHECK-NEXT:    %call = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
+  // CHECK-NEXT:    %3 = fadd <16 x float> %2, %call
+  // CHECK-NEXT:    %4 = load [16 x float]*, [16 x float]** %m.addr, align 8
+  // CHECK-NEXT:    %5 = bitcast [16 x float]* %4 to <16 x float>*
+  // CHECK-NEXT:    store <16 x float> %3, <16 x float>* %5, align 4
+  m = m + identmatrix;
+
+  // CHECK-NEXT:    %call1 = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
+  // CHECK-NEXT:    %6 = load [25 x i32]*, [25 x i32]** %m2.addr, align 8
+  // CHECK-NEXT:    %7 = bitcast [25 x i32]* %6 to <25 x i32>*
+  // CHECK-NEXT:    %8 = load <25 x i32>, <25 x i32>* %7, align 4
+  // CHECK-NEXT:    %9 = add <25 x i32> %call1, %8
+  // CHECK-NEXT:    %10 = add <25 x i32> %9, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+  // CHECK-NEXT:    %11 = load [25 x i32]*, [25 x i32]** %m2.addr, align 8
+  // CHECK-NEXT:    %12 = bitcast [25 x i32]* %11 to <25 x i32>*
+  // CHECK-NEXT:    store <25 x i32> %10, <25 x i32>* %12, align 4
+  // CHECK-NEXT:    ret void
+  m2 = identmatrix + m2 + 1;
+}
+
+// CHECK-LABEL:  define linkonce_odr <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* %this)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %this.addr = alloca %struct.identmatrix_t*, align 8
+// CHECK-NEXT:    %result = alloca [16 x float], align 4
+// CHECK-NEXT:    %i = alloca i32, align 4
+// CHECK-NEXT:    store %struct.identmatrix_t* %this, %struct.identmatrix_t** %this.addr, align 8
+// CHECK-NEXT:    %this1 = load %struct.identmatrix_t*, %struct.identmatrix_t** %this.addr, align 8
+// CHECK-NEXT:    store i32 0, i32* %i, align 4
+// CHECK-NEXT:    br label %for.cond
+// CHECK-LABEL: for.cond:
+// CHECK-NEXT:    %0 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %cmp = icmp ne i32 %0, 4
+// CHECK-NEXT:    br i1 %cmp, label %for.body, label %for.end
+// CHECK-LABEL: for.body:
+// CHECK-NEXT:    %1 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %2 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %3 = mul i32 %2, 4
+// CHECK-NEXT:    %4 = add i32 %3, %1
+// CHECK-NEXT:    %5 = bitcast [16 x float]* %result to <16 x float>*
+// CHECK-NEXT:    %6 = load <16 x float>, <16 x float>* %5, align 4
+// CHECK-NEXT:    %matins = insertelement <16 x float> %6, float 1.000000e+00, i32 %4
+// CHECK-NEXT:    store <16 x float> %matins, <16 x float>* %5, align 4
+// CHECK-NEXT:    br label %for.inc
+// CHECK-LABEL: for.inc:
+// CHECK-NEXT:    %7 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %inc = add i32 %7, 1
+// CHECK-NEXT:    store i32 %inc, i32* %i, align 4
+// CHECK-NEXT:    br label %for.cond
+// CHECK-LABEL:  for.end:
+// CHECK-NEXT:    %8 = bitcast [16 x float]* %result to <16 x float>*
+// CHECK-NEXT:    %9 = load <16 x float>, <16 x float>* %8, align 4
+// CHECK-NEXT:    ret <16 x float> %9
+
+// CHECK-LABEL:  define linkonce_odr <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* %this)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %this.addr = alloca %struct.identmatrix_t*, align 8
+// CHECK-NEXT:    %result = alloca [25 x i32], align 4
+// CHECK-NEXT:    %i = alloca i32, align 4
+// CHECK-NEXT:    store %struct.identmatrix_t* %this, %struct.identmatrix_t** %this.addr, align 8
+// CHECK-NEXT:    %this1 = load %struct.identmatrix_t*, %struct.identmatrix_t** %this.addr, align 8
+// CHECK-NEXT:    store i32 0, i32* %i, align 4
+// CHECK-NEXT:    br label %for.cond
+// CHECK-LABEL:  for.cond:
+// CHECK-NEXT:    %0 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %cmp = icmp ne i32 %0, 5
+// CHECK-NEXT:    br i1 %cmp, label %for.body, label %for.end
+// CHECK-LABEL: for.body:
+// CHECK-NEXT:    %1 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %2 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %3 = mul i32 %2, 5
+// CHECK-NEXT:    %4 = add i32 %3, %1
+// CHECK-NEXT:    %5 = bitcast [25 x i32]* %result to <25 x i32>*
+// CHECK-NEXT:    %6 = load <25 x i32>, <25 x i32>* %5, align 4
+// CHECK-NEXT:    %matins = insertelement <25 x i32> %6, i32 1, i32 %4
+// CHECK-NEXT:    store <25 x i32> %matins, <25 x i32>* %5, align 4
+// CHECK-NEXT:    br label %for.inc
+// CHECK-LABEL: for.inc:
+// CHECK-NEXT:    %7 = load i32, i32* %i, align 4
+// CHECK-NEXT:    %inc = add i32 %7, 1
+// CHECK-NEXT:    store i32 %inc, i32* %i, align 4
+// CHECK-NEXT:    br label %for.cond
+// CHECK-LABEL: for.end:
+// CHECK-NEXT:    %8 = bitcast [25 x i32]* %result to <25 x i32>*
+// CHECK-NEXT:    %9 = load <25 x i32>, <25 x i32>* %8, align 4
+// CHECK-NEXT:    ret <25 x i32> %9
+// CHECK-NEXT:  }
Index: clang/test/CodeGen/matrix-type-operators.c
===================================================================
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -312,3 +312,311 @@
   // CHECK-NEXT:    ret void
   b[2][j] = b[0][k];
 }
+
+void add_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) { // Same-type matrix + matrix.
+  a = b + c;    // double elements: elementwise fadd.
+  ai = bi + ci; // int elements: elementwise add.
+
+  // CHECK-LABEL: @add_matrix_matrix(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fadd <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = add <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void add_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) { // matrix + floating-point scalar.
+  a = a + vf; // float scalar is fpext'ed to double, then splatted.
+  a = a + vd; // double scalar is splatted directly.
+
+  // CHECK-LABEL: define void @add_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [6 x float], align 4
+  // CHECK-NEXT:    %vf.addr = alloca float, align 4
+  // CHECK-NEXT:    %vd.addr = alloca double, align 8
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [6 x float]* %b.addr to <6 x float>*
+  // CHECK-NEXT:    store <6 x float> %b, <6 x float>* %1, align 4
+  // CHECK-NEXT:    store float %vf, float* %vf.addr, align 4
+  // CHECK-NEXT:    store double %vd, double* %vd.addr, align 8
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %3 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %conv = fpext float %3 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = fadd <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <25 x double> %4, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %5 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %6 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = fadd <25 x double> %5, %scalar.splat.splat2
+  // CHECK-NEXT:    store <25 x double> %7, <25 x double>* %0, align 8
+
+  b = b + vf; // float scalar is splatted into the float matrix.
+  b = b + vd; // double scalar is fptrunc'ed to float, then splatted.
+
+  // CHECK-NEXT:    %8 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %9 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = fadd <6 x float> %8, %scalar.splat.splat4
+  // CHECK-NEXT:    store <6 x float> %10, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %11 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %12 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %conv5 = fptrunc double %12 to float
+  // CHECK-NEXT:    %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = fadd <6 x float> %11, %scalar.splat.splat7
+  // CHECK-NEXT:    store <6 x float> %13, <6 x float>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+typedef int llix9x3_t __attribute__((matrix_type(9, 3))); // NOTE(review): despite the "ll" prefix the element type is int (b lowers to <27 x i32> below); consider renaming, or switching to long long with regenerated CHECKs.
+
+void add_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) { // int matrix + integer scalar.
+  a = a + vs;    // short is sign-extended to i32, then splatted.
+  a = a + vli;   // long is truncated to i32, then splatted.
+  a = a + vulli; // unsigned long long is truncated to i32, then splatted.
+
+  // CHECK-LABEL: define void @add_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %b.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %vs.addr = alloca i16, align 2
+  // CHECK-NEXT:    %vli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %vulli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %1 = bitcast [27 x i32]* %b.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %b, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    store i16 %vs, i16* %vs.addr, align 2
+  // CHECK-NEXT:    store i64 %vli, i64* %vli.addr, align 8
+  // CHECK-NEXT:    store i64 %vulli, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %2 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %3 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv = sext i16 %3 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = add <27 x i32> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %6 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv1 = trunc i64 %6 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = add <27 x i32> %5, %scalar.splat.splat3
+  // CHECK-NEXT:    store <27 x i32> %7, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %8 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %9 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv4 = trunc i64 %9 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = add <27 x i32> %8, %scalar.splat.splat6
+  // CHECK-NEXT:    store <27 x i32> %10, <27 x i32>* %0, align 4
+
+  b = vs + b;    // Scalar on the left: the splat is the first add operand.
+  b = vli + b;
+  b = vulli + b;
+
+  // CHECK-NEXT:    %11 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv7 = sext i16 %11 to i32
+  // CHECK-NEXT:    %12 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = add <27 x i32> %scalar.splat.splat9, %12
+  // CHECK-NEXT:    store <27 x i32> %13, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %14 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv10 = trunc i64 %14 to i32
+  // CHECK-NEXT:    %15 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %16 = add <27 x i32> %scalar.splat.splat12, %15
+  // CHECK-NEXT:    store <27 x i32> %16, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %17 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv13 = trunc i64 %17 to i32
+  // CHECK-NEXT:    %18 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %19 = add <27 x i32> %scalar.splat.splat15, %18
+  // CHECK-NEXT:    store <27 x i32> %19, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) { // Same-type matrix - matrix.
+  a = b - c;    // double elements: elementwise fsub.
+  ai = bi - ci; // int elements: elementwise sub.
+
+  // CHECK-LABEL: @sub_matrix_matrix(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fsub <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = sub <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) { // matrix - floating-point scalar.
+  a = a - vf; // float scalar is fpext'ed to double, then splatted.
+  a = a - vd; // double scalar is splatted directly.
+
+  // CHECK-LABEL: define void @sub_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [6 x float], align 4
+  // CHECK-NEXT:    %vf.addr = alloca float, align 4
+  // CHECK-NEXT:    %vd.addr = alloca double, align 8
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [6 x float]* %b.addr to <6 x float>*
+  // CHECK-NEXT:    store <6 x float> %b, <6 x float>* %1, align 4
+  // CHECK-NEXT:    store float %vf, float* %vf.addr, align 4
+  // CHECK-NEXT:    store double %vd, double* %vd.addr, align 8
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %3 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %conv = fpext float %3 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = fsub <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <25 x double> %4, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %5 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %6 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = fsub <25 x double> %5, %scalar.splat.splat2
+  // CHECK-NEXT:    store <25 x double> %7, <25 x double>* %0, align 8
+
+  b = b - vf; // float scalar is splatted into the float matrix.
+  b = b - vd; // double scalar is fptrunc'ed to float, then splatted.
+
+  // CHECK-NEXT:    %8 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %9 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = fsub <6 x float> %8, %scalar.splat.splat4
+  // CHECK-NEXT:    store <6 x float> %10, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %11 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %12 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %conv5 = fptrunc double %12 to float
+  // CHECK-NEXT:    %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = fsub <6 x float> %11, %scalar.splat.splat7
+  // CHECK-NEXT:    store <6 x float> %13, <6 x float>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) { // int matrix - integer scalar.
+  a = a - vs;    // short is sign-extended to i32, then splatted.
+  a = a - vli;   // long is truncated to i32, then splatted.
+  a = a - vulli; // unsigned long long is truncated to i32, then splatted.
+
+  // CHECK-LABEL: define void @sub_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %b.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %vs.addr = alloca i16, align 2
+  // CHECK-NEXT:    %vli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %vulli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %1 = bitcast [27 x i32]* %b.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %b, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    store i16 %vs, i16* %vs.addr, align 2
+  // CHECK-NEXT:    store i64 %vli, i64* %vli.addr, align 8
+  // CHECK-NEXT:    store i64 %vulli, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %2 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %3 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv = sext i16 %3 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = sub <27 x i32> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %6 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv1 = trunc i64 %6 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = sub <27 x i32> %5, %scalar.splat.splat3
+  // CHECK-NEXT:    store <27 x i32> %7, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %8 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %9 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv4 = trunc i64 %9 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = sub <27 x i32> %8, %scalar.splat.splat6
+  // CHECK-NEXT:    store <27 x i32> %10, <27 x i32>* %0, align 4
+
+  b = vs - b;    // Scalar on the left: the splat is the minuend of the sub.
+  b = vli - b;
+  b = vulli - b;
+
+  // CHECK-NEXT:    %11 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv7 = sext i16 %11 to i32
+  // CHECK-NEXT:    %12 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = sub <27 x i32> %scalar.splat.splat9, %12
+  // CHECK-NEXT:    store <27 x i32> %13, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %14 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv10 = trunc i64 %14 to i32
+  // CHECK-NEXT:    %15 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %16 = sub <27 x i32> %scalar.splat.splat12, %15
+  // CHECK-NEXT:    store <27 x i32> %16, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %17 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv13 = trunc i64 %17 to i32
+  // CHECK-NEXT:    %18 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %19 = sub <27 x i32> %scalar.splat.splat15, %18
+  // CHECK-NEXT:    store <27 x i32> %19, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -7687,6 +7687,10 @@
   /// candidates.
   TypeSet VectorTypes;
 
+  /// The set of matrix types that will be used in the built-in
+  /// candidates.
+  TypeSet MatrixTypes;
+
   /// A flag indicating non-record types are viable candidates
   bool HasNonRecordTypes;
 
@@ -7747,6 +7751,11 @@
   iterator vector_begin() { return VectorTypes.begin(); }
   iterator vector_end() { return VectorTypes.end(); }
 
+  llvm::iterator_range<iterator> matrix_types() { return MatrixTypes; }
+  iterator matrix_begin() { return MatrixTypes.begin(); }
+  iterator matrix_end() { return MatrixTypes.end(); }
+  /// True if \p Ty is one of the matrix types recorded in this set.
+  bool containsMatrixType(QualType Ty) const { return MatrixTypes.count(Ty); }
   bool hasNonRecordTypes() { return HasNonRecordTypes; }
   bool hasArithmeticOrEnumeralTypes() { return HasArithmeticOrEnumeralTypes; }
   bool hasNullPtrType() const { return HasNullPtrType; }
@@ -7921,6 +7930,11 @@
     // extension.
     HasArithmeticOrEnumeralTypes = true;
     VectorTypes.insert(Ty);
+  } else if (Ty->isMatrixType()) {
+    // Similar to vector types, we treat matrix types as arithmetic types in
+    // many contexts as an extension.
+    HasArithmeticOrEnumeralTypes = true;
+    MatrixTypes.insert(Ty);
   } else if (Ty->isNullPtrType()) {
     HasNullPtrType = true;
   } else if (AllowUserConversions && TyRec) {
@@ -8149,6 +8163,13 @@
 
   }
 
+  /// Register a built-in candidate for a binary operator whose two parameter
+  /// types are \p L and \p R.
+  void AddCandidate(QualType L, QualType R) {
+    QualType ParamTys[] = {L, R};
+    S.AddBuiltinCandidate(ParamTys, Args, CandidateSet);
+  }
+
 public:
   BuiltinOperatorOverloadBuilder(
     Sema &S, ArrayRef<Expr *> Args,
@@ -8567,6 +8588,27 @@
     }
   }
 
+  /// Add (M, elt(M)) and (M, M) candidates for every matrix type M in
+  /// CandidateTypes[0], then (elt(M), M) and -- when M is not also in
+  /// CandidateTypes[0] -- (M, M) for every matrix type M in CandidateTypes[1].
+  void addMatrixBinaryArithmeticOverloads() {
+    if (!HasArithmeticOrEnumeralCandidateType)
+      return;
+
+    for (QualType LHSMat : CandidateTypes[0].matrix_types()) {
+      QualType Elt = cast<MatrixType>(LHSMat)->getElementType();
+      AddCandidate(LHSMat, Elt);
+      AddCandidate(LHSMat, LHSMat);
+    }
+
+    for (QualType RHSMat : CandidateTypes[1].matrix_types()) {
+      QualType Elt = cast<MatrixType>(RHSMat)->getElementType();
+      AddCandidate(Elt, RHSMat);
+      if (!CandidateTypes[0].containsMatrixType(RHSMat))
+        AddCandidate(RHSMat, RHSMat);
+    }
+  }
+
   // C++2a [over.built]p14:
   //
   //   For every integral type T there exists a candidate operator function
@@ -9140,6 +9182,7 @@
     } else {
       OpBuilder.addBinaryPlusOrMinusPointerOverloads(Op);
       OpBuilder.addGenericBinaryArithmeticOverloads();
+      OpBuilder.addMatrixBinaryArithmeticOverloads();
     }
     break;
 
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -10357,6 +10357,16 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isConstantMatrixType() ||
+      RHS.get()->getType()->isConstantMatrixType()) {
+    // Compound assignments also need the computation LHS type reported back.
+    QualType compType =
+        CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy);
+    if (CompLHSTy)
+      *CompLHSTy = compType;
+    return compType;
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -10452,6 +10462,16 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isConstantMatrixType() ||
+      RHS.get()->getType()->isConstantMatrixType()) {
+    // Compound assignments also need the computation LHS type reported back.
+    QualType compType =
+        CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy);
+    if (CompLHSTy)
+      *CompLHSTy = compType;
+    return compType;
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -12047,6 +12067,63 @@
   return GetSignedVectorType(LHS.get()->getType());
 }
 
+/// Copy-initialize a temporary of \p ElementType from \p Scalar, in place.
+static bool tryConvertScalarToMatrixElementTy(Sema &S, QualType ElementType,
+                                              ExprResult *Scalar) {
+  Expr *E = Scalar->get();
+  auto Entity = InitializedEntity::InitializeTemporary(ElementType);
+  auto Kind =
+      InitializationKind::CreateCopy(E->getBeginLoc(), SourceLocation());
+  InitializationSequence Seq(S, Entity, Kind, E);
+  *Scalar = Seq.Perform(S, Entity, Kind, E); // may come back invalid
+  return !Scalar->isInvalid();
+}
+
+/// Type-check an elementwise binary operator where at least one operand is a
+/// matrix. A scalar operand is converted to the matrix element type; two
+/// different matrix types are rejected via InvalidOperands.
+QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                              SourceLocation Loc,
+                                              bool IsCompAssign) {
+  // For a compound assignment the LHS is left unconverted.
+  if (!IsCompAssign) {
+    LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+    if (LHS.isInvalid())
+      return QualType();
+  }
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+  if (RHS.isInvalid())
+    return QualType();
+
+  // Qualifiers play no part in the conversion: "const float" and "float" are
+  // treated alike.
+  QualType LHSType = LHS.get()->getType().getUnqualifiedType();
+  QualType RHSType = RHS.get()->getType().getUnqualifiedType();
+  const auto *LHSMatTy = LHSType->getAs<MatrixType>();
+  const auto *RHSMatTy = RHSType->getAs<MatrixType>();
+  assert((LHSMatTy || RHSMatTy) && "At least one operand must be a matrix");
+
+  // Identical types need no conversion.
+  if (Context.hasSameType(LHSType, RHSType))
+    return LHSType;
+
+  // Only a matrix/scalar mix can be repaired by converting the scalar; two
+  // distinct matrix types (or no matrix at all) are rejected as-is.
+  if ((LHSMatTy != nullptr) == (RHSMatTy != nullptr))
+    return InvalidOperands(Loc, LHS, RHS);
+
+  // The conversion may overwrite the scalar operand, so keep copies of the
+  // original operands for the diagnostic in case it fails.
+  ExprResult OriginalLHS = LHS;
+  ExprResult OriginalRHS = RHS;
+  const MatrixType *MatTy = LHSMatTy ? LHSMatTy : RHSMatTy;
+  ExprResult &ScalarOperand = LHSMatTy ? RHS : LHS;
+  if (tryConvertScalarToMatrixElementTy(*this, MatTy->getElementType(),
+                                        &ScalarOperand))
+    return LHSMatTy ? LHSType : RHSType;
+  return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
Index: clang/lib/CodeGen/CGExprScalar.cpp
===================================================================
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -3554,6 +3554,11 @@
     }
   }
 
+  if (op.Ty->isConstantMatrixType()) {
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    return MB.CreateAdd(op.LHS, op.RHS);
+  }
+
   if (op.Ty->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), op))
@@ -3738,6 +3743,11 @@
       }
     }
 
+    if (op.Ty->isConstantMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      return MB.CreateSub(op.LHS, op.RHS);
+    }
+
     if (op.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), op))
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11215,6 +11215,11 @@
   QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
                                       SourceLocation Loc);
 
+  /// Type checking for elementwise binary operators with a matrix operand.
+  QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                          SourceLocation Loc,
+                                          bool IsCompAssign);
+
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to