fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72785

Files:
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp

Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1619,6 +1619,7 @@
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
   case Builtin::BI__builtin_matrix_transpose:
+  case Builtin::BI__builtin_matrix_scalar_multiply:
   case Builtin::BI__builtin_matrix_column_load:
   case Builtin::BI__builtin_matrix_column_store:
     if (!getLangOpts().EnableMatrix) {
@@ -1638,6 +1639,8 @@
       return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_transpose:
       return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_scalar_multiply:
+      return SemaBuiltinMatrixScalarOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_column_load:
       return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_column_store:
@@ -15537,6 +15540,93 @@
   return CallResult;
 }
 
+ExprResult Sema::SemaBuiltinMatrixScalarOverload(CallExpr *TheCall,
+                                                 ExprResult CallResult) {
+  // Sema for __builtin_matrix_scalar_multiply(Matrix, Scalar). The first
+  // operand must be a matrix; the second must be an integer or real
+  // floating-point value and is implicitly converted to the matrix element
+  // type. The call's result type is the (unqualified) matrix type.
+  if (checkArgCount(*this, TheCall, 2))
+    return ExprError();
+
+  if (!TheCall->getArg(0)->getType()->isMatrixType()) {
+    Diag(TheCall->getArg(0)->getBeginLoc(),
+         diag::err_builtin_matrix_scalar_type_error)
+        << 0;
+    return ExprError();
+  }
+
+  // Both operands are used as values: apply lvalue-to-rvalue conversion.
+  // This also drops cv-qualifiers, so e.g. a 'const float' scalar is not
+  // treated as mismatching a 'float' element type below.
+  ExprResult MatrixRes = DefaultLvalueConversion(TheCall->getArg(0));
+  if (MatrixRes.isInvalid())
+    return ExprError();
+  Expr *MatrixArg = MatrixRes.get();
+
+  ExprResult ScalarRes = DefaultLvalueConversion(TheCall->getArg(1));
+  if (ScalarRes.isInvalid())
+    return ExprError();
+  Expr *ScalarArg = ScalarRes.get();
+
+  const MatrixType *MType =
+      cast<MatrixType>(MatrixArg->getType().getCanonicalType());
+  QualType ElementTy = MType->getElementType();
+
+  // Convert the scalar to the element type unless it already matches.
+  // Canonical, unqualified comparison accepts typedefs of the element type.
+  if (!Context.hasSameUnqualifiedType(ScalarArg->getType(), ElementTy)) {
+    if (!ScalarArg->getType()->isIntegralType(Context) &&
+        !ScalarArg->getType()->isRealFloatingType()) {
+      Diag(ScalarArg->getBeginLoc(), diag::err_builtin_matrix_scalar_type_error)
+          << 1;
+      return ExprError();
+    }
+
+    // PrepareScalarCast selects the cast kind matching the source and
+    // destination types (integral, floating, integral<->floating, ...).
+    // Hard-coding CK_IntegralToFloating would build an invalid AST for, e.g.,
+    // a 'double' scalar with a 'float' matrix or an 'int' scalar with an
+    // integer matrix of a different width.
+    CastKind Kind = PrepareScalarCast(ScalarRes, ElementTy);
+    if (ScalarRes.isInvalid()) {
+      Diag(ScalarArg->getBeginLoc(),
+           diag::err_builtin_matrix_implicit_cast_error)
+          << ScalarArg->getType() << ElementTy;
+      return ExprError();
+    }
+    ScalarRes = ImpCastExprToType(ScalarRes.get(), ElementTy, Kind);
+    if (ScalarRes.isInvalid())
+      return ExprError();
+    ScalarArg = ScalarRes.get();
+  }
+
+  TheCall->setArg(0, MatrixArg);
+  TheCall->setArg(1, ScalarArg);
+
+  // Rebuild the callee as a builtin-function DeclRefExpr promoted to a
+  // function pointer of the builtin's declared type.
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // Scaling a matrix by a scalar yields a value of the same matrix type.
+  TheCall->setType(MatrixArg->getType());
+
+  return CallResult;
+}
+
 ExprResult Sema::SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
                                                      ExprResult CallResult) {
   // Must have exactly four operands
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2451,6 +2451,14 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_scalar_multiply: {
+    // Emit the matrix operand, then the scalar, and multiply them via MB.
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *MatrixOp = EmitScalarExpr(E->getArg(0));
+    return RValue::get(
+        MB.CreateScalarMultiply(MatrixOp, EmitScalarExpr(E->getArg(1))));
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11618,6 +11618,8 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
                                               ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixScalarOverload(CallExpr *TheCall,
+                                             ExprResult CallResult);
   ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
                                                ExprResult CallResult);
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10308,6 +10308,11 @@
 def err_builtin_matrix_implicit_cast_error: Error<
   "Implicit cast to from %0 to %1 failed">;
 
+// %0 selects the offending operand: 0 = matrix operand, 1 = scalar operand.
+def err_builtin_matrix_scalar_type_error: Error<
+  "%select{first|scalar}0 argument must be a "
+  "%select{matrix|float or integer}0">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -579,6 +579,7 @@
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_scalar_multiply, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_store, "v.", "nFt")
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D72785: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to