fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72774

Files:
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/builtin-matrix.c

Index: clang/test/CodeGen/builtin-matrix.c
===================================================================
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -225,3 +225,30 @@
   // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
   // CHECK-NEXT:    ret void
 }
+
+void multiply1(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) {
+  *a = __builtin_matrix_multiply(*b, *c);
+
+  // CHECK-LABEL: @multiply1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %c, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    %5 = load <25 x double>, <25 x double>* %4, align 8
+  // CHECK-NEXT:    %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5)
+  // CHECK-NEXT:    %7 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %8 = bitcast [25 x double]* %7 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %6, <25 x double>* %8, align 8
+  // CHECK-NEXT:    ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]+]]
+
+// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1617,6 +1617,7 @@
   case Builtin::BI__builtin_matrix_extract:
   case Builtin::BI__builtin_matrix_add:
   case Builtin::BI__builtin_matrix_subtract:
+  case Builtin::BI__builtin_matrix_multiply:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1630,6 +1631,8 @@
     case Builtin::BI__builtin_matrix_add:
     case Builtin::BI__builtin_matrix_subtract:
       return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_multiply:
+      return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -15372,3 +15375,98 @@
 
   return CallResult;
 }
+
+/// Semantic checking for __builtin_matrix_multiply(A, B).
+///
+/// Both arguments must be matrices, A must be M x N and B must be N x R, and
+/// their element types must be identical. On success the arguments are
+/// converted to prvalues, the callee is rewritten to a pointer to the builtin
+/// and the call's type is set to the M x R result matrix type.
+///
+/// \returns \p CallResult on success, ExprError() otherwise.
+ExprResult Sema::SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
+                                                   ExprResult CallResult) {
+  // Exactly two operands, A and B, are expected.
+  if (checkArgCount(*this, TheCall, 2))
+    return ExprError();
+
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  Expr *AArg = TheCall->getArg(0);
+  Expr *BArg = TheCall->getArg(1);
+
+  // Both parameters must be matrices; diagnose every offending argument
+  // before bailing out.
+  bool ArgError = false;
+  if (!AArg->getType()->isMatrixType()) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  if (!BArg->getType()->isMatrixType()) {
+    Diag(BArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 1;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  // Both types were verified to be matrices above, so these casts succeed.
+  MatrixType const *AMType =
+      cast<MatrixType const>(AArg->getType().getCanonicalType());
+  MatrixType const *BMType =
+      cast<MatrixType const>(BArg->getType().getCanonicalType());
+
+  // Requirements: A (M x N) * B (N x R) = AB (M x R), i.e. the number of
+  // columns of A must match the number of rows of B.
+  unsigned M = AMType->getNumRows();
+  unsigned N = AMType->getNumColumns();
+  unsigned R = BMType->getNumColumns();
+  if (BMType->getNumRows() != N) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_dimension_error);
+    return ExprError();
+  }
+
+  // Element types of both matrices must match.
+  if (AMType->getElementType() != BMType->getElementType()) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_element_type)
+        << AMType->getElementType() << BMType->getElementType();
+    return ExprError();
+  }
+
+  // Load glvalue operands. DefaultLvalueConversion returns prvalues as-is and
+  // gives the loaded value the cv-unqualified type, so a 'const' matrix
+  // operand does not yield a qualified prvalue. Propagate any failure.
+  ExprResult ALoaded = DefaultLvalueConversion(AArg);
+  if (ALoaded.isInvalid())
+    return ExprError();
+  TheCall->setArg(0, ALoaded.get());
+
+  ExprResult BLoaded = DefaultLvalueConversion(BArg);
+  if (BLoaded.isInvalid())
+    return ExprError();
+  TheCall->setArg(1, BLoaded.get());
+
+  // The product is an M x R matrix with the shared element type of A and B.
+  QualType ResultType =
+      Context.getMatrixType(AMType->getElementType(), M, R);
+
+  // Create a new DeclRefExpr to refer to the builtin decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall =
+      ImpCastExprToType(NewDRE, CalleePtrTy, CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // The builtin is declared with a 'void' return type ("v." in Builtins.def);
+  // give the call the type of the computed product instead.
+  TheCall->setType(ResultType);
+
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2380,6 +2380,20 @@
     Value *Result = MB.CreateSub(Matrix1, Matrix2);
     return RValue::get(Result);
   }
+
+  case Builtin::BI__builtin_matrix_multiply: {
+    // Lower A (Rows x Inner) * B (Inner x Cols) to a call of the
+    // llvm.matrix.multiply intrinsic; Sema has already checked the shapes.
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *LHS = EmitScalarExpr(E->getArg(0));
+    Value *RHS = EmitScalarExpr(E->getArg(1));
+    const MatrixType *LHSTy = getMatrixTy(E->getArg(0)->getType());
+    const MatrixType *RHSTy = getMatrixTy(E->getArg(1)->getType());
+    return RValue::get(MB.CreateMatrixMultiply(
+        LHS, RHS, LHSTy->getNumRows(), LHSTy->getNumColumns(),
+        RHSTy->getNumColumns()));
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11618,6 +11618,8 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
                                               ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
+                                               ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10287,18 +10287,27 @@
 def err_builtin_matrix_disabled: Error<
   "Builtin matrix support is disabled. Pass -fenable-matrix to enable it.">;
 
+def err_builtin_matrix_element_type: Error< // both matrix operands must share one element type
+  "Element types of input matrices do not match (%0 != %1)">;
+
 def err_builtin_matrix_arg: Error<
   "%select{First|Second}0 argument must be a matrix">;
 
+def err_builtin_matrix_pointer_arg: Error< // NOTE(review): not referenced by this patch; drop it or add the user
+  "%select{First|Second}0 argument must be a %select{pointer|pointer to integers or floats}1">;
+
 def err_builtin_matrix_scalar_int_arg: Error<
   "%select{Row|Column|Offset|Stride}0 argument must be %select{an unsigned integer|a constant unsigned integer expression}1">;
 
-def err_builtin_matrix_implicit_cast_error: Error<
-  "Implicit cast to from %0 to %1 failed">;
-
 def err_builtin_matrix_type_match: Error<
   "Matrix types must match">;
 
+def err_builtin_matrix_dimension_error: Error< // e.g. multiply where columns(A) != rows(B)
+  "Matrix dimensions do not match operation">;
+
+def err_builtin_matrix_implicit_cast_error: Error< // %0 = source type, %1 = destination type
+  "Implicit cast from %0 to %1 failed">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -577,6 +577,7 @@
 BUILTIN(__builtin_matrix_extract, "v.", "nt")
 BUILTIN(__builtin_matrix_subtract, "v.", "nt")
 BUILTIN(__builtin_matrix_add, "v.", "nt")
+BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D72774: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to