yubing updated this revision to Diff 319797.
yubing added a comment.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Fix some bugs in lowerTileDPBSSD, lowerTileStore, lowerTileLoad

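For reference, a rough usage sketch of the renamed __tile_dpbssd intrinsic (not
part of this diff). The __tile_loadd/__tile_stored helpers, the tile shapes and
the -mamx-int8 -mamx-tile build flags are assumptions based on the existing
amxintrin.h API; at -O0 the new pass scalarizes the resulting internal
intrinsics instead of leaving them for the AMX type lowering.

  #include <immintrin.h>
  #include <stddef.h>

  void dot_i8(const void *a, const void *b, void *c, size_t stride) {
    __tile1024i ta = {16, 64}; /* assumed shape: rows, bytes per row */
    __tile1024i tb = {16, 64};
    __tile1024i tc = {16, 64};
    __tile_loadd(&ta, a, stride);
    __tile_loadd(&tb, b, stride);
    __tile_loadd(&tc, c, stride);
    __tile_dpbssd(&tc, ta, tb); /* renamed from __tile_dpbsud in this patch */
    __tile_stored(c, stride, tc);
  }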

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D93594/new/

https://reviews.llvm.org/D93594

Files:
  clang/lib/Headers/amxintrin.h
  llvm/include/llvm/CodeGen/Passes.h
  llvm/lib/Target/X86/CMakeLists.txt
  llvm/lib/Target/X86/X86.h
  llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
  llvm/lib/Target/X86/X86TargetMachine.cpp
  llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
  llvm/test/CodeGen/X86/O0-pipeline.ll

Index: llvm/test/CodeGen/X86/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/O0-pipeline.ll
+++ llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -18,7 +18,9 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
-; CHECK-NEXT:       Lower AMX type for load/store
+; CHECK-NEXT:       Dominator Tree Construction
+; CHECK-NEXT:       Natural Loop Information
+; CHECK-NEXT:       Lower AMX intrinsics
 ; CHECK-NEXT:       Module Verifier
 ; CHECK-NEXT:       Lower Garbage Collection Instructions
 ; CHECK-NEXT:       Shadow Stack GC Lowering
Index: llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
@@ -0,0 +1,198 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_amx_load_non_O0(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) {
+; CHECK-LABEL: @test_amx_load_non_O0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[ROW:%.*]], i16 [[COL:%.*]], i8* [[PTR:%.*]], i64 [[STRIDE:%.*]])
+; CHECK-NEXT:    [[VEC:%.*]] = bitcast x86_amx [[AMX]] to <256 x i32>
+; CHECK-NEXT:    store <256 x i32> [[VEC]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_load(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = udiv i16 [[COL:%.*]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i64 [[STRIDE:%.*]], 4
+; CHECK-NEXT:    br label [[TILELOAD_UNROLL_ROWS_HEADER:%.*]]
+; CHECK:       tileload.unroll.rows.header:
+; CHECK-NEXT:    [[TILELOAD_UNROLL_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_UNROLL_ROWS_STEP:%.*]], [[TILELOAD_UNROLL_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP11:%.*]], [[TILELOAD_UNROLL_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_UNROLL_ROWS_BODY:%.*]]
+; CHECK:       tileload.unroll.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_UNROLL_COLS_HEADER:%.*]]
+; CHECK:       tileload.unroll.cols.header:
+; CHECK-NEXT:    [[TILELOAD_UNROLL_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_UNROLL_ROWS_BODY]] ], [ [[TILELOAD_UNROLL_COLS_STEP:%.*]], [[TILELOAD_UNROLL_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_UNROLL_ROWS_BODY]] ], [ [[TMP11]], [[TILELOAD_UNROLL_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_UNROLL_COLS_BODY:%.*]]
+; CHECK:       tileload.unroll.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_UNROLL_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILELOAD_UNROLL_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load i32, i32* [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i16 [[TILELOAD_UNROLL_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP10:%.*]] = add i16 [[TMP9]], [[TILELOAD_UNROLL_COLS_IV]]
+; CHECK-NEXT:    [[TMP11]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP8]], i16 [[TMP10]]
+; CHECK-NEXT:    br label [[TILELOAD_UNROLL_COLS_LATCH]]
+; CHECK:       tileload.unroll.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_UNROLL_COLS_STEP]] = add i16 [[TILELOAD_UNROLL_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_UNROLL_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_UNROLL_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILELOAD_UNROLL_COLS_COND]], label [[TILELOAD_UNROLL_COLS_HEADER]], label [[TILELOAD_UNROLL_ROWS_LATCH]]
+; CHECK:       tileload.unroll.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_UNROLL_ROWS_STEP]] = add i16 [[TILELOAD_UNROLL_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_UNROLL_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_UNROLL_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILELOAD_UNROLL_ROWS_COND]], label [[TILELOAD_UNROLL_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    store <256 x i32> [[TMP11]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dp(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT:    [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT:    [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = udiv i16 [[COL:%.*]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_ROWS_HEADER:%.*]]
+; CHECK:       tiledpbssd.unroll.rows.header:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILEDPBSSD_UNROLL_ROWS_STEP:%.*]], [[TILEDPBSSD_UNROLL_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP18:%.*]], [[TILEDPBSSD_UNROLL_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_ROWS_BODY:%.*]]
+; CHECK:       tiledpbssd.unroll.rows.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_COLS_HEADER:%.*]]
+; CHECK:       tiledpbssd.unroll.cols.header:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_COLS_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_UNROLL_ROWS_BODY]] ], [ [[TILEDPBSSD_UNROLL_COLS_STEP:%.*]], [[TILEDPBSSD_UNROLL_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILEDPBSSD_UNROLL_ROWS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_UNROLL_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_COLS_BODY:%.*]]
+; CHECK:       tiledpbssd.unroll.cols.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_INNER_HEADER:%.*]]
+; CHECK:       tiledpbssd.unroll.inner.header:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_INNER_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_UNROLL_COLS_BODY]] ], [ [[TILEDPBSSD_UNROLL_INNER_STEP:%.*]], [[TILEDPBSSD_UNROLL_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_COL]], [[TILEDPBSSD_UNROLL_COLS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_UNROLL_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_INNER_BODY:%.*]]
+; CHECK:       tiledpbssd.unroll.inner.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i16 [[TILEDPBSSD_UNROLL_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = add i16 [[TMP2]], [[TILEDPBSSD_UNROLL_COLS_IV]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i16 [[TILEDPBSSD_UNROLL_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = add i16 [[TMP4]], [[TILEDPBSSD_UNROLL_INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i16 [[TILEDPBSSD_UNROLL_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = add i16 [[TMP6]], [[TILEDPBSSD_UNROLL_COLS_IV]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_PHI]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i32 [[TMP9]] to <4 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i32 [[TMP11]] to <4 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i8> [[TMP12]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i8> [[TMP10]] to <4 x i32>
+; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP8]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP17]], i16 [[TMP3]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_UNROLL_INNER_LATCH]]
+; CHECK:       tiledpbssd.unroll.inner.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_INNER_STEP]] = add i16 [[TILEDPBSSD_UNROLL_INNER_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_INNER_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_UNROLL_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_UNROLL_INNER_COND]], label [[TILEDPBSSD_UNROLL_INNER_HEADER]], label [[TILEDPBSSD_UNROLL_COLS_LATCH]]
+; CHECK:       tiledpbssd.unroll.cols.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_COLS_STEP]] = add i16 [[TILEDPBSSD_UNROLL_COLS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_COLS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_UNROLL_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_UNROLL_COLS_COND]], label [[TILEDPBSSD_UNROLL_COLS_HEADER]], label [[TILEDPBSSD_UNROLL_ROWS_LATCH]]
+; CHECK:       tiledpbssd.unroll.rows.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_ROWS_STEP]] = add i16 [[TILEDPBSSD_UNROLL_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_UNROLL_ROWS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_UNROLL_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_UNROLL_ROWS_COND]], label [[TILEDPBSSD_UNROLL_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    store <256 x i32> [[TMP18]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
+; CHECK-LABEL: @test_amx_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = bitcast <256 x i32> [[VEC:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = udiv i16 [[COL:%.*]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i64 [[STRIDE:%.*]], 4
+; CHECK-NEXT:    br label [[TILESTORE_UNROLL_ROWS_HEADER:%.*]]
+; CHECK:       tilestore.unroll.rows.header:
+; CHECK-NEXT:    [[TILESTORE_UNROLL_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILESTORE_UNROLL_ROWS_STEP:%.*]], [[TILESTORE_UNROLL_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_UNROLL_ROWS_BODY:%.*]]
+; CHECK:       tilestore.unroll.rows.body:
+; CHECK-NEXT:    br label [[TILESTORE_UNROLL_COLS_HEADER:%.*]]
+; CHECK:       tilestore.unroll.cols.header:
+; CHECK-NEXT:    [[TILESTORE_UNROLL_COLS_IV:%.*]] = phi i16 [ 0, [[TILESTORE_UNROLL_ROWS_BODY]] ], [ [[TILESTORE_UNROLL_COLS_STEP:%.*]], [[TILESTORE_UNROLL_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_UNROLL_COLS_BODY:%.*]]
+; CHECK:       tilestore.unroll.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILESTORE_UNROLL_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILESTORE_UNROLL_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILESTORE_UNROLL_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILESTORE_UNROLL_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <256 x i32> [[VEC]], i16 [[TMP9]]
+; CHECK-NEXT:    store i32 [[TMP10]], i32* [[TMP7]], align 4
+; CHECK-NEXT:    br label [[TILESTORE_UNROLL_COLS_LATCH]]
+; CHECK:       tilestore.unroll.cols.latch:
+; CHECK-NEXT:    [[TILESTORE_UNROLL_COLS_STEP]] = add i16 [[TILESTORE_UNROLL_COLS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_UNROLL_COLS_COND:%.*]] = icmp ne i16 [[TILESTORE_UNROLL_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILESTORE_UNROLL_COLS_COND]], label [[TILESTORE_UNROLL_COLS_HEADER]], label [[TILESTORE_UNROLL_ROWS_LATCH]]
+; CHECK:       tilestore.unroll.rows.latch:
+; CHECK-NEXT:    [[TILESTORE_UNROLL_ROWS_STEP]] = add i16 [[TILESTORE_UNROLL_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_UNROLL_ROWS_COND:%.*]] = icmp ne i16 [[TILESTORE_UNROLL_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILESTORE_UNROLL_ROWS_COND]], label [[TILESTORE_UNROLL_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = bitcast <256 x i32> %vec to x86_amx
+  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride, x86_amx %amx)
+  ret void
+}
+
+define dso_local void @test_amx_zero(i16 signext %row, i16 signext %col, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store <256 x i32> zeroinitializer, <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+attributes #0 = { noinline nounwind optnone }
Index: llvm/lib/Target/X86/X86TargetMachine.cpp
===================================================================
--- llvm/lib/Target/X86/X86TargetMachine.cpp
+++ llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -62,6 +62,7 @@
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXIntrinsicsLegacyPassPass(PR);
   initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
@@ -410,7 +411,12 @@
 
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
-  addPass(createX86LowerAMXTypePass());
+
+  if (TM->getOptLevel() == CodeGenOpt::None)
+    addPass(createX86LowerAMXIntrinsicsPass());
+  else
+    addPass(createX86LowerAMXTypePass());
 
   TargetPassConfig::addIRPasses();
 
Index: llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -0,0 +1,531 @@
+//===- X86LowerAMXIntrinsics.cpp - Lower AMX intrinsics ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file Pass to transform amx intrinsics into scalar operations.
+/// This pass is only enabled at -O0. At -O0, the shape definitions sit right
+/// next to the amx intrinsics that use them, and we are not able to find a
+/// point that post-dominates all the shape definitions and dominates all the
+/// amx intrinsics. To decouple this dependency on the shapes, we transform
+/// the amx intrinsics into scalar operations so that compilation doesn't
+/// fail. In the long term, we should improve the fast register allocator to
+/// allocate amx registers.
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-amx-intrinsics"
+
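+// Build a loop skeleton (header/body/latch) between Preheader and Exit: the
+// header carries an i16 induction variable starting at 0, and the latch
+// increments it by Step and branches back to the header until it reaches
+// Bound. Returns the body block for the caller to fill in.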
+static BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                              Value *Bound, Value *Step, StringRef Name,
+                              IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                              LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header = BasicBlock::Create(
+      Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
+                                        Header->getParent(), Exit);
+  BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
+                                         Header->getParent(), Exit);
+
+  Type *I16Ty = Type::getInt16Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, Tmp},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
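+// Scalarize llvm.x86.tileloadd64.internal into a row/column loop nest that
+// loads one i32 element per iteration and inserts it into a <256 x i32>
+// accumulator; Col and Stride are expected in dword units (the caller
+// pre-divides them by 4). Returns the final accumulator vector.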
+static Value *createTileLoadLoops(BasicBlock *Start, BasicBlock *End,
+                                  IRBuilderBase &B, DomTreeUpdater &DTU,
+                                  LoopInfo &LI, Value *Row, Value *Col,
+                                  Value *Ptr, Value *Stride) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
+                                   "tileload.unroll.rows", B, DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  uint16_t ColStep = 1;
+  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(ColStep),
+                                   "tileload.unroll.cols", B, DTU, ColLoop, LI);
+
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+
+  // tileload.unroll.rows.header:
+  // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %40,
+  // %tileload.unroll.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  PHINode *VecPhi_Row_Loop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+  VecPhi_Row_Loop->addIncoming(VecZero, Start);
+
+  // tileload.unroll.cols.header:
+  // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.unroll.rows.body ], [
+  // %40, %tileload.unroll.cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecPhi->addIncoming(VecPhi_Row_Loop, RowBody);
+
+  // tileload.unroll.cols.body:
+  // %elt = load i32, i32* %ptr
+  // %mul = mul i16 %row.iv, 16
+  // %idx = add i16 %mul, %col.iv
+  // %vec2 = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idx
+  B.SetInsertPoint(ColBody->getTerminator());
+  Type *EltTy = V256I32Ty->getElementType();
+  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
+  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
+  Value *Offset =
+      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
+  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
+  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
+  Value *Elt = B.CreateLoad(EltTy, EltPtr);
+  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
+  VecPhi->addIncoming(ResVec, ColLoopLatch);
+  VecPhi_Row_Loop->addIncoming(ResVec, RowLatch);
+
+  return ResVec;
+}
+
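+// Scalarize llvm.x86.tilestored64.internal into a row/column loop nest that
+// extracts one i32 element per iteration from the vector feeding Tile's
+// bitcast and stores it to memory; Col and Stride are in dword units.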
+static void createTileStoreLoops(BasicBlock *Start, BasicBlock *End,
+                                 IRBuilderBase &B, DomTreeUpdater &DTU,
+                                 LoopInfo &LI, Value *Row, Value *Col,
+                                 Value *Ptr, Value *Stride, Value *Tile) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tilestore.unroll.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  uint16_t ColStep = 1;
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(ColStep),
+                 "tilestore.unroll.cols", B, DTU, ColLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+
+  Value *Vec = nullptr;
+  if (auto BitCast = dyn_cast<BitCastInst>(Tile))
+    Vec = BitCast->getOperand(0);
+  assert(Vec && Vec->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Type *EltTy = V256I32Ty->getElementType();
+
+  // cols.body:
+  B.SetInsertPoint(ColBody->getTerminator());
+  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
+  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
+  Value *Offset =
+      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
+  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
+  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
+  // %mul = mul i16 %row.iv, 16
+  // %idx = add i16 %mul, %col.iv
+  // %elt = extractelement <256 x i32> %vec, i16 %idx
+  // store i32 %elt, i32* %ptr
+  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *Elt = B.CreateExtractElement(Vec, Idx);
+
+  B.CreateStore(Elt, EltPtr);
+}
+
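+// Scalarize llvm.x86.tdpbssd.internal into a three-level loop nest (rows,
+// cols, inner K): each iteration sign-extends a 4 x i8 dword from LHS and
+// RHS, multiplies and reduce-adds them, and accumulates into the C vector.
+// Col and K are in dword units. Returns the updated accumulator vector.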
+static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
+                                    IRBuilderBase &B, DomTreeUpdater &DTU,
+                                    LoopInfo &LI, Value *Row, Value *Col,
+                                    Value *K, Value *Acc, Value *LHS,
+                                    Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbssd.unroll.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbssd.unroll.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbssd.unroll.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Value *VecC = nullptr, *VecA = nullptr, *VecB = nullptr;
+  if (auto BitCast = dyn_cast<BitCastInst>(Acc))
+    VecC = BitCast->getOperand(0);
+  assert(VecC && VecC->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  // TODO else create BitCast from x86amx to v256i32.
+  // Store x86amx to memory, and reload from memory
+  // to vector. However, at -O0 this doesn't happen.
+  if (auto BitCast = dyn_cast<BitCastInst>(LHS))
+    VecA = BitCast->getOperand(0);
+  assert(VecA && VecA->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  if (auto BitCast = dyn_cast<BitCastInst>(RHS))
+    VecB = BitCast->getOperand(0);
+  assert(VecB && VecB->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbssd.unroll.rows.header:
+  // %vec.phi.rows = phi <256 x i32> [ %vec_c, %continue ], [ %NewVecC,
+  // %tiledpbssd.unroll.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecPhi_Row_Loop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+  VecPhi_Row_Loop->addIncoming(VecC, Start);
+
+  // tiledpbssd.unroll.cols.header:
+  // %vec.phi.cols = phi <256 x i32> [ %vec.phi.rows,
+  // %tiledpbssd.unroll.rows.body ], [ %NewVecC, %tiledpbssd.unroll.cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecPhi_Col_Loop = B.CreatePHI(V256I32Ty, 2, "vec.phi.col");
+  VecPhi_Col_Loop->addIncoming(VecPhi_Row_Loop, RowBody);
+
+  // Generate PHI vector for C.
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecCPhi->addIncoming(VecPhi_Col_Loop, ColBody);
+
+  // Generate the multiply-accumulate in the inner body.
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecPhi_Row_Loop->addIncoming(NewVecC, RowLatch);
+  VecPhi_Col_Loop->addIncoming(NewVecC, ColLoopLatch);
+
+  return NewVecC;
+}
+
+namespace {
+class X86LowerAMXIntrinsics {
+  Function &Func;
+
+public:
+  X86LowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI)
+      : Func(F), DT(DT), LI(LI) {}
+  bool visit();
+
+private:
+  DominatorTree *DT;
+  LoopInfo *LI;
+  bool lowerTileLoad(Instruction *TileLoad);
+  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  bool lowerTileStore(Instruction *TileStore);
+  bool lowerTileZero(Instruction *TileZero);
+};
+
+bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+  Value *M, *N, *K, *C, *A, *B;
+  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
+                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                        m_Value(A), m_Value(B)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileDPBSSD;
+  IRBuilder<> Builder_Prepare(TileDPBSSD);
+  Builder_Prepare.SetInsertPoint(TileDPBSSD);
+  // The loop nest iterates with trip counts (m, n/4, k/4):
+  // %n_dword = udiv i16 %n, 4
+  // %k_dword = udiv i16 %k, 4
+  Value *N_DWord = Builder_Prepare.CreateUDiv(N, Builder_Prepare.getInt16(4));
+  Value *K_DWord = Builder_Prepare.CreateUDiv(K, Builder_Prepare.getInt16(4));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBSSD);
+  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
+                                        N_DWord, K_DWord, C, A, B);
+
+  // Delete the tdpbssd intrinsic and the bitcast instructions that use it.
+  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBSSD->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileLoad(Instruction *TileLoad) {
+  Value *M, *N, *Ptr, *Stride;
+  match(TileLoad, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
+                      m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileLoad;
+  IRBuilder<> Builder_Prepare(TileLoad);
+  Builder_Prepare.SetInsertPoint(TileLoad);
+  Value *N_DWord = Builder_Prepare.CreateUDiv(N, Builder_Prepare.getInt16(4));
+  Value *Stride_DWord =
+      Builder_Prepare.CreateUDiv(Stride, Builder_Prepare.getInt64(4));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileLoad);
+  Value *ResVec = createTileLoadLoops(Start, End, Builder, DTU, *LI, M, N_DWord,
+                                      Ptr, Stride_DWord);
+
+  // Delete the tileloadd64 intrinsic and the bitcast instructions that use it.
+  for (auto UI = TileLoad->use_begin(), UE = TileLoad->use_end(); UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileLoad->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileStore(Instruction *TileStore) {
+  Value *M, *N, *Ptr, *Stride, *Tile;
+  match(TileStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
+                       m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride),
+                       m_Value(Tile)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileStore;
+  IRBuilder<> Builder_Prepare(TileStore);
+  Builder_Prepare.SetInsertPoint(TileStore);
+  Value *N_DWord = Builder_Prepare.CreateUDiv(N, Builder_Prepare.getInt16(4));
+  Value *Stride_DWord =
+      Builder_Prepare.CreateUDiv(Stride, Builder_Prepare.getInt64(4));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileStore);
+  createTileStoreLoops(Start, End, Builder, DTU, *LI, M, N_DWord, Ptr,
+                       Stride_DWord, Tile);
+
+  TileStore->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
+  IRBuilder<> Builder(TileZero);
+  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(VecZero);
+      I->eraseFromParent();
+    }
+  }
+  TileZero->eraseFromParent();
+  return true;
+}
+
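+// Collect the AMX intrinsic calls by walking the blocks in post-order, then
+// lower tile loads, dpbssd, stores and zeros in turn.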
+bool X86LowerAMXIntrinsics::visit() {
+  bool C = false;
+  SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileLoads;
+  SmallVector<Instruction *, 8> TileStores;
+  SmallVector<Instruction *, 8> TileZeros;
+
+  for (BasicBlock *BB : post_order(&Func)) {
+    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
+         II != IE;) {
+      Instruction &Inst = *II++;
+      if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>())) {
+        // %amx1 = bitcast <256 x i32> %vec to x86_amx
+        // %res = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n,
+        //                         i16 %k, x86_amx %amx0, x86_amx %amx1, ...)
+        // %vec2 = bitcast x86_amx %res to <256 x i32>
+        TileDPBSSDs.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tileloadd64_internal>())) {
+        // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
+        //                                                   i8* %15, i64 %16)
+        // %18 = bitcast x86_amx %17 to <256 x i32>
+        TileLoads.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tilestored64_internal>())) {
+        // %89 = bitcast <256 x i32> %88 to x86_amx
+        // call void @llvm.x86.tilestored64.internal(i16 %84, i16 %85, i8* %86,
+        //                                           i64 %87, x86_amx %89)
+        TileStores.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tilezero_internal>())) {
+        // %amx = call x86_amx @llvm.x86.tilezero.internal(i16 %84, i16 %85)
+        // %vec = bitcast x86_amx %amx to <256 x i32>
+        TileZeros.push_back(&Inst);
+      }
+    }
+  }
+
+  for (auto *Inst : TileLoads) {
+    C |= lowerTileLoad(Inst);
+  }
+  for (auto *Inst : TileDPBSSDs) {
+    C |= lowerTileDPBSSD(Inst);
+  }
+  for (auto *Inst : TileStores) {
+    C |= lowerTileStore(Inst);
+  }
+  for (auto *Inst : TileZeros) {
+    C |= lowerTileZero(Inst);
+  }
+
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (!F.hasFnAttribute(Attribute::OptimizeNone))
+      return false;
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+    X86LowerAMXIntrinsics LAT(F, &DT, &LI);
+    bool C = LAT.visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addPreserved<LoopInfoWrapperPass>();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Lower AMX intrinsics";
+char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                    false, false)
+
+FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
+  return new X86LowerAMXIntrinsicsLegacyPass();
+}
Index: llvm/lib/Target/X86/X86.h
===================================================================
--- llvm/lib/Target/X86/X86.h
+++ llvm/lib/Target/X86/X86.h
@@ -169,6 +169,7 @@
 void initializeX86PreTileConfigPass(PassRegistry &);
 void initializeX86TileConfigPass(PassRegistry &);
 void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
+void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
Index: llvm/lib/Target/X86/CMakeLists.txt
===================================================================
--- llvm/lib/Target/X86/CMakeLists.txt
+++ llvm/lib/Target/X86/CMakeLists.txt
@@ -33,6 +33,7 @@
   X86DomainReassignment.cpp
   X86DiscriminateMemOps.cpp
   X86LowerAMXType.cpp
+  X86LowerAMXIntrinsics.cpp
   X86TileConfig.cpp
   X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -492,6 +492,8 @@
   /// The pass transform load/store <256 x i32> to AMX load/store intrinsics
   /// or split the data to two <128 x i32>.
   FunctionPass *createX86LowerAMXTypePass();
+
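+  /// The pass transforms amx intrinsics into scalar operations. It is only
+  /// enabled at -O0, for functions marked optnone.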
+  FunctionPass *createX86LowerAMXIntrinsicsPass();
 } // End llvm namespace
 
 #endif
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -258,7 +258,7 @@
 }
 
 __DEFAULT_FN_ATTRS_INT8
-static void __tile_dpbsud(__tile1024i *dst, __tile1024i src1,
+static void __tile_dpbssd(__tile1024i *dst, __tile1024i src1,
                           __tile1024i src2) {
   dst->tile = _tile_dpbssd_internal(src1.row, src2.col, src1.col, dst->tile,
                                     src1.tile, src2.tile);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
