LuoYuanke created this revision.
Herald added subscribers: llvm-commits, cfe-commits, hiraditya, mgorny, 
qcolombet, MatzeB.
Herald added projects: clang, LLVM.
LuoYuanke requested review of this revision.
Herald added a subscriber: jdoerfert.

Change-Id: I935e1080916ffcb72af54c2c83faa8b2e97d5cb0
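
The new internal AMX intrinsics carry the tile shape as explicit row/column
operands, and <256 x i32> is mapped to the AMX tile register class. Below is a
minimal IR sketch of how the intrinsics fit together, distilled from the
amx-type.ll and amx-spill.ll tests in this patch; the function name, shape
values and the 64-byte stride are illustrative only:

  define dso_local void @dot_product_sketch(i16 %row, i16 %col, i16 %k,
                                            i8* %a, i8* %b, i8* %c) {
    ; Load A (row x k), B (k x col) and the accumulator C (row x col) as tiles.
    %ta = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %row, i16 %k, i8* %a, i64 64)
    %tb = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %k, i16 %col, i8* %b, i64 64)
    %tc = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %c, i64 64)
    ; C += A * B on signed 8-bit elements, accumulating into 32-bit lanes.
    %td = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k, <256 x i32> %tc, <256 x i32> %ta, <256 x i32> %tb)
    ; Store the updated accumulator back to memory.
    call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %c, i64 64, <256 x i32> %td)
    ret void
  }

  declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
  declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>)
  declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>)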


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D87981

Files:
  clang/include/clang/Basic/BuiltinsX86_64.def
  clang/lib/Headers/amxintrin.h
  clang/test/CodeGen/AMX/amx_api.c
  llvm/include/llvm/CodeGen/LiveIntervalUnion.h
  llvm/include/llvm/CodeGen/LiveRegMatrix.h
  llvm/include/llvm/CodeGen/Passes.h
  llvm/include/llvm/CodeGen/TileShapeInfo.h
  llvm/include/llvm/CodeGen/VirtRegMap.h
  llvm/include/llvm/IR/Intrinsics.td
  llvm/include/llvm/IR/IntrinsicsX86.td
  llvm/lib/CodeGen/InlineSpiller.cpp
  llvm/lib/CodeGen/LiveIntervalUnion.cpp
  llvm/lib/CodeGen/LiveRegMatrix.cpp
  llvm/lib/CodeGen/VirtRegMap.cpp
  llvm/lib/IR/Function.cpp
  llvm/lib/Target/X86/CMakeLists.txt
  llvm/lib/Target/X86/X86.h
  llvm/lib/Target/X86/X86ExpandPseudo.cpp
  llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
  llvm/lib/Target/X86/X86ISelLowering.cpp
  llvm/lib/Target/X86/X86InstrAMX.td
  llvm/lib/Target/X86/X86InstrInfo.cpp
  llvm/lib/Target/X86/X86LowerAMXType.cpp
  llvm/lib/Target/X86/X86MachineFunctionInfo.h
  llvm/lib/Target/X86/X86RegisterInfo.cpp
  llvm/lib/Target/X86/X86RegisterInfo.h
  llvm/lib/Target/X86/X86RegisterInfo.td
  llvm/lib/Target/X86/X86Subtarget.h
  llvm/lib/Target/X86/X86TargetMachine.cpp
  llvm/lib/Target/X86/X86TileConfig.cpp
  llvm/test/CodeGen/X86/AMX/amx-config.ll
  llvm/test/CodeGen/X86/AMX/amx-spill.ll
  llvm/test/CodeGen/X86/AMX/amx-type.ll
  llvm/utils/TableGen/IntrinsicEmitter.cpp

Index: llvm/utils/TableGen/IntrinsicEmitter.cpp
===================================================================
--- llvm/utils/TableGen/IntrinsicEmitter.cpp
+++ llvm/utils/TableGen/IntrinsicEmitter.cpp
@@ -246,7 +246,8 @@
   IIT_SUBDIVIDE4_ARG = 45,
   IIT_VEC_OF_BITCASTS_TO_INT = 46,
   IIT_V128 = 47,
-  IIT_BF16 = 48
+  IIT_BF16 = 48,
+  IIT_V256 = 49
 };
 
 static void EncodeFixedValueType(MVT::SimpleValueType VT,
@@ -384,6 +385,7 @@
     case 32: Sig.push_back(IIT_V32); break;
     case 64: Sig.push_back(IIT_V64); break;
     case 128: Sig.push_back(IIT_V128); break;
+    case 256: Sig.push_back(IIT_V256); break;
     case 512: Sig.push_back(IIT_V512); break;
     case 1024: Sig.push_back(IIT_V1024); break;
     }
Index: llvm/test/CodeGen/X86/AMX/amx-type.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-amx-type %s -S | FileCheck %s
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.__tile_str = type { i16, i16, <256 x i32> }
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 {
+; CHECK-LABEL: @test_load(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>*
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>*
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>*
+; CHECK-NEXT:    [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = load <128 x i32>, <128 x i32>* [[TMP5]], align 64
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>*
+; CHECK-NEXT:    store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1
+; CHECK-NEXT:    store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64
+; CHECK-NEXT:    ret void
+;
+  %1 = bitcast i8* %in to <256 x i32>*
+  %2 = bitcast i8* %out to <256 x i32>*
+  %3 = load <256 x i32>, <256 x i32>* %1, align 64, !tbaa !8
+  store <256 x i32> %3, <256 x i32>* %2, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 {
+; CHECK-LABEL: @__tile_loadd(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32
+; CHECK-NEXT:    [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32
+; CHECK-NEXT:    [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = shl i64 %2, 32
+  %9 = ashr exact i64 %8, 32
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
+  %11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 {
+; CHECK-LABEL: @__tile_dpbsud(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
+; CHECK-NEXT:    [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
+; CHECK-NEXT:    [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
+; CHECK-NEXT:    [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
+; CHECK-NEXT:    [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]]
+; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1
+  %9 = load i16, i16* %8, align 2, !tbaa !7
+  %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  %11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8
+  %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
+  %13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8
+  %14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8
+  %16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3
+  store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 {
+; CHECK-LABEL: @__tile_stored(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8*
+; CHECK-NEXT:    [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]]
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8
+  %10 = shl i64 %1, 32
+  %11 = ashr exact i64 %10, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 12.0.0 (ssh://git-amr-1.devtools.intel.com:29418/dpd_icl-llvm_project_worldread f3c78a3f053379a2511e00e9ce2c13383ea3f835)"}
+!2 = !{!3, !4, i64 0}
+!3 = !{!"__tile_str", !4, i64 0, !4, i64 2, !5, i64 1024}
+!4 = !{!"short", !5, i64 0}
+!5 = !{!"omnipotent char", !6, i64 0}
+!6 = !{!"Simple C/C++ TBAA"}
+!7 = !{!3, !4, i64 2}
+!8 = !{!5, !5, i64 0}
Index: llvm/test/CodeGen/X86/AMX/amx-spill.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-spill.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    subq $2936, %rsp # imm = 0xB78
+; CHECK-NEXT:    .cfi_def_cfa_offset 2944
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %si, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movl $buf, %r8d
+; CHECK-NEXT:    movl $32, %eax
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    movabsq $64, %rcx
+; CHECK-NEXT:    tilestored %tmm1, 896(%rsp,%rcx) # 1024-byte Folded Spill
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm3
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm4
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm2
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm5
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm0
+; CHECK-NEXT:    testl %edi, %edi
+; CHECK-NEXT:    je .LBB0_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm6
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm7
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    jmp .LBB0_3
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    movl $buf2, %ecx
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm6
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm7
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm1
+; CHECK-NEXT:  .LBB0_3:
+; CHECK-NEXT:    tdpbssd %tmm7, %tmm6, %tmm1
+; CHECK-NEXT:    movabsq $64, %rax
+; CHECK-NEXT:    tileloadd 896(%rsp,%rax), %tmm7 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdpbssd %tmm7, %tmm1, %tmm3
+; CHECK-NEXT:    tdpbssd %tmm4, %tmm3, %tmm2
+; CHECK-NEXT:    tdpbssd %tmm5, %tmm2, %tmm0
+; CHECK-NEXT:    movl $buf, %eax
+; CHECK-NEXT:    movl $32, %ecx
+; CHECK-NEXT:    tilestored %tmm0, (%rax,%rcx)
+; CHECK-NEXT:    addq $2936, %rsp # imm = 0xB78
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    retq
+  %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %11 = icmp eq i32 %0, 0
+  br i1 %11, label %16, label %12
+
+12:                                               ; preds = %3
+  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  br label %20
+
+16:                                               ; preds = %3
+  %17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  br label %20
+
+20:                                               ; preds = %16, %12
+  %21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ]
+  %22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ]
+  %23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ]
+  %24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, <256 x i32> %21, <256 x i32> %22) #3
+  %25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3
+  %26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3
+  %27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
Index: llvm/test/CodeGen/X86/AMX/amx-config.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-config.ll
@@ -0,0 +1,72 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+; Function Attrs: nounwind uwtable
+define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movsbl %sil, %eax
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %si, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    testl %edi, %edi
+; CHECK-NEXT:    je .LBB0_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    movl $buf, %ecx
+; CHECK-NEXT:    jmp .LBB0_3
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    movl $buf2, %ecx
+; CHECK-NEXT:  .LBB0_3:
+; CHECK-NEXT:    movl $32, %edi
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm0
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm2
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm1
+; CHECK-NEXT:    tdpbssd %tmm2, %tmm0, %tmm1
+; CHECK-NEXT:    movl $buf, %ecx
+; CHECK-NEXT:    movl $32, %esi
+; CHECK-NEXT:    tilestored %tmm1, (%rcx,%rsi)
+; CHECK-NEXT:    retq
+  %4 = icmp eq i32 %0, 0
+  %5 = shl i16 %1, 8
+  %6 = ashr exact i16 %5, 8
+  br i1 %4, label %11, label %7
+
+7:                                                ; preds = %3
+  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  br label %15
+
+11:                                               ; preds = %3
+  %12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  br label %15
+
+15:                                               ; preds = %11, %7
+  %16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ]
+  %17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ]
+  %18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ]
+  %19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
Index: llvm/lib/Target/X86/X86TileConfig.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86TileConfig.cpp
@@ -0,0 +1,293 @@
+//===-- X86TileConfig.cpp - Tile Register Configure ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+
+#include "X86.h"
+#include "X86InstrBuilder.h"
+#include "X86MachineFunctionInfo.h"
+#include "X86RegisterInfo.h"
+#include "X86Subtarget.h"
+#include "llvm/CodeGen/LiveIntervals.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
+#include "llvm/CodeGen/VirtRegMap.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "tile-config"
+
+namespace {
+
+class X86TileConfig : public MachineFunctionPass {
+  // context
+  MachineFunction *MF = nullptr;
+  const X86Subtarget *ST = nullptr;
+  const TargetRegisterInfo *TRI;
+  const TargetInstrInfo *TII;
+  MachineDominatorTree *DomTree = nullptr;
+  MachineRegisterInfo *MRI = nullptr;
+  VirtRegMap *VRM = nullptr;
+  LiveIntervals *LIS = nullptr;
+
+  MachineInstr &getTileConfigPoint();
+  void tileConfig();
+
+public:
+  X86TileConfig() : MachineFunctionPass(ID) {}
+
+  /// Return the pass name.
+  StringRef getPassName() const override { return "Tile Configure"; }
+
+  /// X86TileConfig analysis usage.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Perform tile register configuration.
+  bool runOnMachineFunction(MachineFunction &mf) override;
+
+  MachineFunctionProperties getRequiredProperties() const override {
+    return MachineFunctionProperties().set(
+        MachineFunctionProperties::Property::NoPHIs);
+  }
+
+  static char ID;
+};
+
+} // end anonymous namespace
+
+char X86TileConfig::ID = 0;
+
+INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
+INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
+INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure",
+                    false, false)
+
+void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<MachineDominatorTree>();
+  AU.addRequired<LiveIntervals>();
+  AU.addPreserved<SlotIndexes>();
+  AU.addRequired<VirtRegMap>();
+  AU.setPreservesAll();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+static unsigned getTilePhysRegIndex(Register PhysReg) {
+  assert((PhysReg >= X86::TMM0 && PhysReg <= X86::TMM7) &&
+         "Tile register number is invalid");
+  return (PhysReg - X86::TMM0);
+}
+
+static MachineInstr *buildConfigMI(MachineBasicBlock &MBB,
+                                   MachineBasicBlock::iterator MI, int FrameIdx,
+                                   const TargetInstrInfo *TII) {
+  return addFrameReference(
+      BuildMI(MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)), FrameIdx);
+}
+
+static MachineInstr *
+storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
+                    Register SrcReg, unsigned BitSize, int FrameIdx, int Offset,
+                    const TargetInstrInfo *TII, const TargetRegisterClass *RC,
+                    const TargetRegisterInfo *TRI) {
+
+  unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit;
+  unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr;
+  MachineInstr *NewMI =
+      addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx,
+                        Offset)
+          .addReg(SrcReg);
+  MachineOperand &MO = NewMI->getOperand(5);
+  if (BitSize < TRI->getRegSizeInBits(*RC))
+    MO.setSubReg(SubIdx);
+  return NewMI;
+}
+
+static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB,
+                                         MachineBasicBlock::iterator MI,
+                                         int64_t Imm, unsigned BitSize,
+                                         int FrameIdx, int Offset,
+                                         const TargetInstrInfo *TII) {
+  unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi;
+  return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)),
+                           FrameIdx, Offset)
+      .addImm(Imm);
+}
+
+MachineInstr &
+X86TileConfig::getTileConfigPoint() {
+  DenseMap<Register, ShapeT> PhysShapeInfo;
+  MachineBasicBlock *MBB = nullptr;
+  DenseSet<const MachineInstr *> MIs;
+  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
+    unsigned VirtReg = Register::index2VirtReg(i);
+    if (MRI->reg_nodbg_empty(VirtReg))
+      continue;
+    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+    if (RC.getID() != X86::TILERegClassID)
+      continue;
+
+    // FIXME: The region split should be done before the Greedy RA.
+    // Here we assume only one config for all tile registers.
+    //
+    // Find the common dominator for all MIs that define tile registers.
+    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
+      if (MO.isUndef())
+        continue;
+      auto *MI = MO.getParent();
+      if (!MBB)
+        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
+      MBB = DomTree->findNearestCommonDominator(
+          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));
+    }
+    // Collect the instructions that define shape.
+    ShapeT Shape = VRM->getShape(VirtReg);
+    std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(), Shape.getCol()};
+    for (auto *ShapeMO : ShapeMOs) {
+      Register ShapeReg = ShapeMO->getReg();
+      for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
+        auto *ShapeMI = MO.getParent();
+        MIs.insert(ShapeMI);
+      }
+    }
+
+#if !defined(NDEBUG)
+    Register PhysReg = VRM->getPhys(VirtReg);
+    if (PhysShapeInfo.count(PhysReg))
+      assert(PhysShapeInfo[PhysReg] == VRM->getShape(VirtReg) &&
+             "The physical register is assigned to virtual registers"
+             "with different shape");
+#endif
+  }
+  // The shape defs should dominate the tile config MBB.
+  // TODO: Improve this for shapes that are immediates.
+  for (auto *MI : MIs) {
+    const MachineBasicBlock *ShapeMBB = MI->getParent();
+    if (DomTree->dominates(ShapeMBB, MBB))
+      continue;
+    if (MI->isMoveImmediate())
+      continue;
+    report_fatal_error("Failed to config tile register, "
+                       "please define the shape earlier");
+  }
+
+  // ldtilecfg should be inserted after the MI that defines the shape.
+  MachineBasicBlock::reverse_instr_iterator I, E;
+  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
+    auto *MI = &*I;
+    if (MIs.count(MI) && (!MI->isMoveImmediate()))
+      break;
+  }
+  MachineBasicBlock::iterator MII;
+  if (I == E)
+    MII = MBB->getFirstNonPHI();
+  else {
+    MII = MachineBasicBlock::iterator(&*I);
+    MII++;
+  }
+  return *MII;
+}
+
+void X86TileConfig::tileConfig() {
+  MachineInstr &MI = getTileConfigPoint();
+  MachineBasicBlock *MBB = MI.getParent();
+  // Allocate a stack buffer for the tile config.
+  unsigned Size = ST->getTileConfigSize();
+  Align Alignment = ST->getTileConfigAlignment();
+
+  int SS = MF->getFrameInfo().CreateStackObject(Size, Alignment, false);
+  BitVector PhysRegs(TRI->getNumRegs());
+
+  // Insert ldtilecfg into the MBB.
+  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
+    unsigned VirtReg = Register::index2VirtReg(i);
+    if (MRI->reg_nodbg_empty(VirtReg))
+      continue;
+    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+    if (RC.getID() != X86::TILERegClassID)
+      continue;
+    Register PhysReg = VRM->getPhys(VirtReg);
+    if (PhysRegs.test(PhysReg))
+      continue;
+    PhysRegs.set(PhysReg);
+    ShapeT Shape = VRM->getShape(VirtReg);
+    Register RowReg = Shape.getRow()->getReg();
+    Register ColReg = Shape.getCol()->getReg();
+
+    unsigned Index = getTilePhysRegIndex(PhysReg);
+    int RowOffset = 48 + Index;
+    int ColOffset = 16 + Index * 2;
+
+    unsigned BitSize = 8;
+    for (auto &Pair : {std::make_pair(RowReg, RowOffset),
+                       std::make_pair(ColReg, ColOffset)}) {
+      int64_t Imm;
+      int ImmCount = 0;
+      // All defs must be the same value; otherwise the MIs are invalid.
+      // An immediate operand is preferred.
+      for (const MachineOperand &MO : MRI->def_operands(Pair.first)) {
+        auto *Inst = MO.getParent();
+        if (Inst->isMoveImmediate()) {
+          ImmCount++;
+          Imm = Inst->getOperand(1).getImm();
+          break;
+        }
+      }
+      auto StoreConfig = [&](int Offset) {
+        MachineInstr *NewMI = nullptr;
+        if (ImmCount)
+          NewMI = storeImmToStackSlot(*MBB, MI, Imm, BitSize, SS, Offset, TII);
+        else {
+          const TargetRegisterClass *RC = MRI->getRegClass(Pair.first);
+          NewMI = storeRegToStackSlot(*MBB, MI, Pair.first, BitSize, SS, Offset,
+                                      TII, RC, TRI);
+        }
+        SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI);
+        if (!ImmCount) {
+          // Extend the live interval.
+          SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()};
+          LiveInterval &Int = LIS->getInterval(Pair.first);
+          LIS->extendToIndices(Int, EndPoints);
+        }
+      };
+      StoreConfig(Pair.second);
+      BitSize += 8;
+    }
+  }
+  MachineInstr *NewMI = buildConfigMI(*MBB, MI, SS, TII);
+  LIS->InsertMachineInstrInMaps(*NewMI);
+}
+
+bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) {
+  LLVM_DEBUG(dbgs() << "********** TILE REGISTER CONFIGURE **********\n"
+                    << "********** Function: " << mf.getName() << '\n');
+  MF = &mf;
+  MRI = &mf.getRegInfo();
+  ST = &mf.getSubtarget<X86Subtarget>();
+  TRI = ST->getRegisterInfo();
+  TII = mf.getSubtarget().getInstrInfo();
+  DomTree = &getAnalysis<MachineDominatorTree>();
+  VRM = &getAnalysis<VirtRegMap>();
+  LIS = &getAnalysis<LiveIntervals>();
+
+  if (VRM->isShapeMapEmpty())
+    return false;
+
+  tileConfig();
+  return true;
+}
+
+FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
Index: llvm/lib/Target/X86/X86TargetMachine.cpp
===================================================================
--- llvm/lib/Target/X86/X86TargetMachine.cpp
+++ llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -62,6 +62,7 @@
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
   initializeFixupBWInstPassPass(PR);
@@ -71,6 +72,7 @@
   initializeX86FixupSetCCPassPass(PR);
   initializeX86CallFrameOptimizationPass(PR);
   initializeX86CmovConverterPassPass(PR);
+  initializeX86TileConfigPass(PR);
   initializeX86ExpandPseudoPass(PR);
   initializeX86ExecutionDomainFixPass(PR);
   initializeX86DomainReassignmentPass(PR);
@@ -378,6 +380,7 @@
   void addPreEmitPass() override;
   void addPreEmitPass2() override;
   void addPreSched2() override;
+  bool addPreRewrite() override;
 
   std::unique_ptr<CSEConfigBase> getCSEConfig() const override;
 };
@@ -406,6 +409,7 @@
 
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
+  addPass(createX86LowerAMXTypePass());
 
   TargetPassConfig::addIRPasses();
 
@@ -564,6 +568,11 @@
   addPass(createX86LoadValueInjectionRetHardeningPass());
 }
 
+bool X86PassConfig::addPreRewrite() {
+  addPass(createX86TileConfigPass());
+  return true;
+}
+
 std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
   return getStandardCSEConfigForOpt(TM->getOptLevel());
 }
Index: llvm/lib/Target/X86/X86Subtarget.h
===================================================================
--- llvm/lib/Target/X86/X86Subtarget.h
+++ llvm/lib/Target/X86/X86Subtarget.h
@@ -457,6 +457,8 @@
   /// entry to the function and which must be maintained by every function.
   Align stackAlignment = Align(4);
 
+  Align tileConfigAlignment = Align(4);
+
   /// Max. memset / memcpy size that is turned into rep/movs, rep/stos ops.
   ///
   // FIXME: this is a known good value for Yonah. How about others?
@@ -540,6 +542,9 @@
     return &getInstrInfo()->getRegisterInfo();
   }
 
+  unsigned getTileConfigSize() const { return 64; }
+  Align getTileConfigAlignment() const { return tileConfigAlignment; }
+
   /// Returns the minimum alignment known to hold of the
   /// stack frame on entry to the function and which must be maintained by every
   /// function for this subtarget.
Index: llvm/lib/Target/X86/X86RegisterInfo.td
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.td
+++ llvm/lib/Target/X86/X86RegisterInfo.td
@@ -633,6 +633,6 @@
 def BNDR : RegisterClass<"X86", [v2i64], 128, (sequence "BND%u", 0, 3)>;
 
 // Tiles
-let isAllocatable = 0 in
-def TILE : RegisterClass<"X86", [untyped], 0,
+let CopyCost = -1 in // Don't allow copying tile registers.
+def TILE : RegisterClass<"X86", [v256i32], 8192,
                          (sequence "TMM%u", 0, 7)> {let Size = 8192;}
Index: llvm/lib/Target/X86/X86RegisterInfo.h
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.h
+++ llvm/lib/Target/X86/X86RegisterInfo.h
@@ -141,6 +141,11 @@
   Register getFramePtr() const { return FramePtr; }
   // FIXME: Move to FrameInfok
   unsigned getSlotSize() const { return SlotSize; }
+
+  bool getRegAllocationHints(Register VirtReg, ArrayRef<MCPhysReg> Order,
+                             SmallVectorImpl<MCPhysReg> &Hints,
+                             const MachineFunction &MF, const VirtRegMap *VRM,
+                             const LiveRegMatrix *Matrix) const override;
 };
 
 } // End llvm namespace
Index: llvm/lib/Target/X86/X86RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -16,6 +16,7 @@
 #include "X86FrameLowering.h"
 #include "X86MachineFunctionInfo.h"
 #include "X86Subtarget.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -24,6 +25,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetFrameLowering.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/LiveRegMatrix.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Type.h"
@@ -812,3 +814,75 @@
     StackReg = getX86SubSuperRegister(StackReg, 32);
   return StackReg;
 }
+
+static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
+                           const MachineRegisterInfo *MRI) {
+  if (VRM->hasShape(VirtReg))
+    return VRM->getShape(VirtReg);
+
+  const MachineOperand &Def = *MRI->def_begin(VirtReg);
+  MachineInstr *MI = const_cast<MachineInstr *>(Def.getParent());
+  unsigned OpCode = MI->getOpcode();
+  switch (OpCode) {
+  default:
+    llvm_unreachable("Unexpected machine instruction on tile register!");
+    break;
+  // We only collect the tile shape that is defined.
+  case X86::PTILELOADDV:
+  case X86::PTDPBSSDV:
+    MachineOperand &MO1 = MI->getOperand(1);
+    MachineOperand &MO2 = MI->getOperand(2);
+    ShapeT Shape(&MO1, &MO2, MRI);
+    VRM->assignVirt2Shape(VirtReg, Shape);
+    return Shape;
+  }
+}
+
+bool X86RegisterInfo::getRegAllocationHints(
+    Register VirtReg, ArrayRef<MCPhysReg> Order,
+    SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
+    const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
+  const MachineRegisterInfo *MRI = &MF.getRegInfo();
+  const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+  bool BaseImplRetVal = TargetRegisterInfo::getRegAllocationHints(
+      VirtReg, Order, Hints, MF, VRM, Matrix);
+
+  if (RC.getID() != X86::TILERegClassID)
+    return BaseImplRetVal;
+
+  ShapeT VirtShape = getTileShape(VirtReg, const_cast<VirtRegMap *>(VRM), MRI);
+  auto addHint = [&](MCPhysReg PhysReg) {
+    Register VReg = Matrix->getOneVReg(PhysReg);
+    if (VReg == MCRegister::NoRegister) { // Not allocated yet
+      Hints.push_back(PhysReg);
+      return;
+    }
+    ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI);
+    if (PhysShape == VirtShape)
+      Hints.push_back(PhysReg);
+  };
+
+  SmallSet<unsigned, 4> CopyHints;
+  CopyHints.insert(Hints.begin(), Hints.end());
+  Hints.clear();
+  for (auto Hint : CopyHints) {
+    if (RC.contains(Hint) && !MRI->isReserved(Hint))
+      addHint(Hint);
+  }
+  for (MCPhysReg PhysReg : Order) {
+    if (!MRI->isReserved(PhysReg))
+      addHint(PhysReg);
+  }
+
+#define DEBUG_TYPE "tile-hint"
+  LLVM_DEBUG({
+    dbgs() << "Hints for virtual register " << format_hex(VirtReg, 8) << "\n";
+    for (auto Hint : Hints) {
+      dbgs() << "tmm" << Hint << ",";
+    }
+    dbgs() << "\n";
+  });
+#undef DEBUG_TYPE
+
+  return true;
+}
Index: llvm/lib/Target/X86/X86MachineFunctionInfo.h
===================================================================
--- llvm/lib/Target/X86/X86MachineFunctionInfo.h
+++ llvm/lib/Target/X86/X86MachineFunctionInfo.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
 
 namespace llvm {
 
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -0,0 +1,270 @@
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lower-amx-type"
+
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+  const DataLayout &DL;
+  DenseSet<Instruction *> LDSet;
+  DenseSet<Instruction *> STSet;
+  DenseMap<Value *, std::pair<LoadInst *, LoadInst *>> LoadMap;
+
+public:
+  X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {}
+  bool Visit();
+  bool VisitLD();
+  bool VisitST();
+  void SplitST(Instruction *Inst);
+  void SplitLD(Instruction *Inst);
+};
+
+// Split a v256i32 load/store into two v128i32 operations, so that ISel can
+// lower them to a proper vector size.
+void X86LowerAMXType::SplitST(Instruction *Inst) {
+  StoreInst *ST = cast<StoreInst>(Inst);
+  IRBuilder<> Builder(ST);
+  LLVMContext &Ctx = Builder.getContext();
+  Type *Ty = ST->getValueOperand()->getType();
+  EVT VT = EVT::getEVT(Ty);
+  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
+  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
+
+  LoadInst *Lo, *Hi;
+  std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()];
+  Value *Ptr = ST->getPointerOperand();
+  PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace());
+  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
+  Builder.CreateAlignedStore(Lo, HalfPtr, Lo->getAlign(), ST->isVolatile());
+
+  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
+  Builder.CreateAlignedStore(Hi, HalfPtr, Hi->getAlign(), ST->isVolatile());
+}
+
+bool X86LowerAMXType::VisitST() {
+  if (STSet.empty())
+    return false;
+  for (auto *Inst : STSet) {
+    Value *Row, *Col;
+    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst->getOperand(0));
+    if (!II)
+      Row = Col = nullptr;
+    else {
+      switch (II->getIntrinsicID()) {
+      default:
+        Row = Col = nullptr;
+        break;
+      case Intrinsic::x86_tileloadd64_internal:
+      case Intrinsic::x86_tdpbssd_internal: {
+        Row = II->getArgOperand(0);
+        Col = II->getArgOperand(1);
+        break;
+      }
+      }
+    }
+    if (!Row) {
+      SplitST(Inst);
+      continue;
+    }
+    IRBuilder<> Builder(Inst);
+    LLVMContext &Ctx = Builder.getContext();
+    // Use the maximum column as the stride; it must match the load stride.
+    Value *Stride = Builder.getInt64(64);
+    Value *I8Ptr =
+        Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx));
+    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
+                                   Inst->getOperand(0)};
+
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+  }
+  return true;
+}
+
+void X86LowerAMXType::SplitLD(Instruction *Inst) {
+  LoadInst *LD = cast<LoadInst>(Inst);
+  IRBuilder<> Builder(LD);
+  LLVMContext &Ctx = Builder.getContext();
+  Type *Ty = LD->getType();
+  EVT VT = EVT::getEVT(Ty);
+  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
+  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
+
+  Value *Ptr = LD->getPointerOperand();
+  PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace());
+  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
+  assert(LD->getAlign() <= 512);
+  auto *Lo = Builder.CreateAlignedLoad(HalfTy, HalfPtr, LD->getAlign(),
+                                       LD->isVolatile());
+
+  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
+  auto *Hi = Builder.CreateAlignedLoad(HalfTy, HalfPtr, LD->getAlign(),
+                                       LD->isVolatile());
+
+  LoadMap[Inst] = std::make_pair(Lo, Hi);
+}
+
+bool X86LowerAMXType::VisitLD() {
+  if (LDSet.empty())
+    return false;
+  for (auto &Inst : LDSet) {
+    int Count = 0;
+    Value *NewInst = nullptr;
+    // The users should be either all AMX intrinsics or all regular LLVM
+    // instructions; mixing the two kinds of users is not supported.
+    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
+      Use &U = *I++;
+      const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U.getUser());
+      if (!II) {
+        Count++;
+        continue;
+      }
+      if (NewInst)
+        continue;
+      Value *Row, *Col;
+      switch (II->getIntrinsicID()) {
+      default:
+        report_fatal_error("Non-AMX intrinsic use tile type.");
+        break;
+      case Intrinsic::x86_tdpbssd_internal: {
+        unsigned OpNo = U.getOperandNo();
+        switch (OpNo) {
+        case 3:
+          Row = II->getArgOperand(0);
+          Col = II->getArgOperand(1);
+          break;
+        case 4:
+          Row = II->getArgOperand(0);
+          Col = II->getArgOperand(2);
+          break;
+        case 5:
+          Row = II->getArgOperand(2);
+          Col = II->getArgOperand(1);
+          break;
+        }
+        break;
+      }
+      case Intrinsic::x86_tilestored64_internal: {
+        Row = II->getArgOperand(0);
+        Col = II->getArgOperand(1);
+        break;
+      }
+      }
+      assert(Count == 0 && "Cannot mix AMX intrinsics and LLVM instructions");
+      // FIXME: The shape def should be ahead of the load.
+      IRBuilder<> Builder(Inst);
+      LLVMContext &Ctx = Builder.getContext();
+      // Use the maximum column as the stride.
+      Value *Stride = Builder.getInt64(64);
+      Value *I8Ptr =
+          Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx));
+      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+      NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
+                                        None, Args);
+
+      Inst->replaceAllUsesWith(NewInst);
+    }
+    if (!NewInst)
+      SplitLD(Inst);
+  }
+  return true;
+}
+
+bool X86LowerAMXType::Visit() {
+  bool C;
+  auto IsAMXType = [](FixedVectorType *VTy) {
+    if (!VTy)
+      return false;
+    if (!VTy->getScalarType()->isIntegerTy(32))
+      return false;
+    if (VTy->getNumElements() != 256)
+      return false;
+
+    return true;
+  };
+
+  for (BasicBlock &BB : Func) {
+    for (Instruction &Inst : BB) {
+      LoadInst *LD = dyn_cast<LoadInst>(&Inst);
+      // Check load instruction.
+      // %3 = load <256 x i32>, <256 x i32>* %1, align 64
+      if (LD) {
+        FixedVectorType *VTy = dyn_cast<FixedVectorType>(Inst.getType());
+        if (!IsAMXType(VTy))
+          continue;
+        LDSet.insert(&Inst);
+        continue;
+      }
+      // Check store instruction.
+      // store <256 x i32> %3, <256 x i32>* %2, align 64
+      StoreInst *ST = dyn_cast<StoreInst>(&Inst);
+      if (!ST)
+        continue;
+      FixedVectorType *VTy =
+          dyn_cast<FixedVectorType>(ST->getOperand(0)->getType());
+      if (!IsAMXType(VTy))
+        continue;
+      STSet.insert(&Inst);
+    }
+  }
+
+  C = VisitLD() | VisitST();
+  for (auto *Inst : STSet)
+    Inst->eraseFromParent();
+  for (auto *Inst : LDSet)
+    Inst->eraseFromParent();
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXTypeLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    X86LowerAMXType LAT(F);
+    bool C = LAT.Visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+private:
+  Function *F;
+};
+
+} // anonymous namespace
+
+static const char pass_name[] = "Lower AMX type for load/store";
+char X86LowerAMXTypeLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false,
+                      false)
+INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false,
+                    false)
+
+FunctionPass *llvm::createX86LowerAMXTypePass() {
+  return new X86LowerAMXTypeLegacyPass();
+}
Index: llvm/lib/Target/X86/X86InstrInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86InstrInfo.cpp
+++ llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3758,13 +3758,27 @@
   const MachineFunction &MF = *MBB.getParent();
   assert(MF.getFrameInfo().getObjectSize(FrameIdx) >= TRI->getSpillSize(*RC) &&
          "Stack slot too small for store");
-  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-  bool isAligned =
+  if (RC->getID() != X86::TILERegClassID) {
+    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+    bool isAligned =
       (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
       RI.canRealignStack(MF);
-  unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
-  addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
-    .addReg(SrcReg, getKillRegState(isKill));
+    unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+      .addReg(SrcReg, getKillRegState(isKill));
+  } else {
+    unsigned Opc = X86::TILESTORED;
+    // tilestored %tmm, (%sp, %idx)
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    MachineInstr *NewMI = BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri),
+                                  VirtReg).addImm(64);
+    NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+              .addReg(SrcReg, getKillRegState(isKill));
+    MachineOperand &MO = NewMI->getOperand(2);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+  }
 }
 
 void X86InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
@@ -3772,13 +3786,27 @@
                                         Register DestReg, int FrameIdx,
                                         const TargetRegisterClass *RC,
                                         const TargetRegisterInfo *TRI) const {
-  const MachineFunction &MF = *MBB.getParent();
-  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-  bool isAligned =
+  if (RC->getID() != X86::TILERegClassID) {
+    const MachineFunction &MF = *MBB.getParent();
+    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+    bool isAligned =
       (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
       RI.canRealignStack(MF);
-  unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
-  addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), FrameIdx);
+    unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), FrameIdx);
+  } else {
+    unsigned Opc = X86::TILELOADD;
+    // tileloadd (%sp, %idx), %tmm
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    MachineInstr *NewMI = BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri),
+                                  VirtReg).addImm(64);
+    NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
+                                      FrameIdx);
+    MachineOperand &MO = NewMI->getOperand(3);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+  }
 }
 
 bool X86InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
Index: llvm/lib/Target/X86/X86InstrAMX.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAMX.td
+++ llvm/lib/Target/X86/X86InstrAMX.td
@@ -23,6 +23,7 @@
     def STTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src),
                        "sttilecfg\t$src",
                        [(int_x86_sttilecfg addr:$src)]>, VEX, T8PD;
+    let mayLoad = 1 in
     def TILELOADD : I<0x4b, MRMSrcMemFSIB, (outs TILE:$dst),
                       (ins sibmem:$src),
                       "tileloadd\t{$src, $dst|$dst, $src}", []>,
@@ -34,6 +35,7 @@
     let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in
     def TILERELEASE : I<0x49, MRM_C0, (outs), (ins),
                         "tilerelease", [(int_x86_tilerelease)]>, VEX, T8PS;
+    let mayStore = 1 in
     def TILESTORED : I<0x4b, MRMDestMemFSIB, (outs),
                        (ins sibmem:$dst, TILE:$src),
                        "tilestored\t{$src, $dst|$dst, $src}", []>,
@@ -42,6 +44,11 @@
                      "tilezero\t$dst", []>,
                      VEX, T8XD;
 
+    def PTILESTOREDV : PseudoI<(outs), (ins GR16:$src1,
+                               GR16:$src2, opaquemem:$src3, TILE:$src4), []>;
+    def PTILELOADDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                              GR16:$src2, opaquemem:$src3), []>;
+
     let usesCustomInserter = 1 in {
       // Pseudo instructions, using immediates instead of tile registers.
       // To be translated to the actual instructions in X86ISelLowering.cpp
@@ -76,6 +83,11 @@
                       VEX_4V, T8PS;
     }
 
+    let Constraints = "$src4 = $dst" in
+    def PTDPBSSDV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
+
     let usesCustomInserter = 1 in {
       // Pseudo instructions, using immediates instead of tile registers.
       // To be translated to the actual instructions in X86ISelLowering.cpp
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1885,6 +1885,10 @@
     setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom);
   }
 
+  if (Subtarget.hasAMXTILE()) {
+    addRegisterClass(MVT::v256i32, &X86::TILERegClass);
+  }
+
   // We want to custom lower some of our intrinsics.
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
@@ -5271,6 +5275,12 @@
   // width.
   if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth())
     return false;
+
+  // Don't merge stores into an x86 AMX tile type; MVT::v256i32 is only
+  // mapped to the AMX tile register class for AMX intrinsics.
+  if (MemVT == MVT::v256i32)
+    return false;
+
   return true;
 }
 
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4480,6 +4480,45 @@
 
   switch (Opcode) {
   default: break;
+  case ISD::INTRINSIC_W_CHAIN: {
+    unsigned IntNo = Node->getConstantOperandVal(1);
+    switch (IntNo) {
+    default: break;
+    case Intrinsic::x86_tileloadd64_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      unsigned Opc = X86::PTILELOADDV;
+      // _tile_loadd_internal(row, col, buf, STRIDE)
+      SDValue Base = Node->getOperand(4);
+      SDValue Scale = getI8Imm(1, dl);
+      SDValue Index = Node->getOperand(5);
+      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32);
+      SDValue Segment = CurDAG->getRegister(0, MVT::i16);
+      SDValue Chain = Node->getOperand(0);
+      MachineSDNode *CNode;
+      SDValue Ops[] = { Node->getOperand(2),
+        Node->getOperand(3),
+        Base, Scale, Index, Disp, Segment, Chain };
+      CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
+    case Intrinsic::x86_tdpbssd_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      unsigned Opc = X86::PTDPBSSDV;
+      SDValue Ops[] = { Node->getOperand(2), Node->getOperand(3),
+        Node->getOperand(4), Node->getOperand(5),
+        Node->getOperand(6), Node->getOperand(7) };
+      MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl,
+                                                    {MVT::v256i32, MVT::Other},
+                                                    Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
+    }
+    break;
+  }
   case ISD::INTRINSIC_VOID: {
     unsigned IntNo = Node->getConstantOperandVal(1);
     switch (IntNo) {
@@ -4534,6 +4573,24 @@
 
       break;
     }
+    case Intrinsic::x86_tilestored64_internal: {
+      unsigned Opc = X86::PTILESTOREDV;
+      // _tile_stored_internal(row, col, buf, STRIDE, c)
+      SDValue Base = Node->getOperand(4);
+      SDValue Scale = getI8Imm(1, dl);
+      SDValue Index = Node->getOperand(5);
+      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32);
+      SDValue Segment = CurDAG->getRegister(0, MVT::i16);
+      SDValue Chain = Node->getOperand(0);
+      MachineSDNode *CNode;
+      SDValue Ops[] = { Node->getOperand(2),
+        Node->getOperand(3),
+        Base, Scale, Index, Disp, Segment,
+        Node->getOperand(6), Chain };
+      CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tileloadd64:
     case Intrinsic::x86_tileloaddt164:
     case Intrinsic::x86_tilestored64: {
Index: llvm/lib/Target/X86/X86ExpandPseudo.cpp
===================================================================
--- llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -468,6 +468,26 @@
   case TargetOpcode::ICALL_BRANCH_FUNNEL:
     ExpandICallBranchFunnel(&MBB, MBBI);
     return true;
+  case X86::PTILELOADDV: {
+    for (unsigned i = 2; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TILELOADD));
+    return true;
+  }
+  case X86::PTDPBSSDV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBSSD));
+    MI.tieOperands(0, 1);
+    return true;
+  }
+  case X86::PTILESTOREDV: {
+    for (int i = 1; i >= 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TILESTORED));
+    return true;
+  }
   }
   llvm_unreachable("Previous switch has a fallthrough?");
 }
Index: llvm/lib/Target/X86/X86.h
===================================================================
--- llvm/lib/Target/X86/X86.h
+++ llvm/lib/Target/X86/X86.h
@@ -76,6 +76,8 @@
 /// Return a pass that expands WinAlloca pseudo-instructions.
 FunctionPass *createX86WinAllocaExpander();
 
+FunctionPass *createX86TileConfigPass();
+
 /// Return a pass that inserts int3 at the end of the function if it ends with a
 /// CALL instruction. The pass does the same for each funclet as well. This
 /// ensures that the open interval of function start and end PCs contains all
@@ -162,6 +164,8 @@
 void initializeX86PartialReductionPass(PassRegistry &);
 void initializeX86SpeculativeLoadHardeningPassPass(PassRegistry &);
 void initializeX86SpeculativeExecutionSideEffectSuppressionPass(PassRegistry &);
+void initializeX86TileConfigPass(PassRegistry &);
+void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
Index: llvm/lib/Target/X86/CMakeLists.txt
===================================================================
--- llvm/lib/Target/X86/CMakeLists.txt
+++ llvm/lib/Target/X86/CMakeLists.txt
@@ -30,6 +30,8 @@
   X86CmovConversion.cpp
   X86DomainReassignment.cpp
   X86DiscriminateMemOps.cpp
+  X86LowerAMXType.cpp
+  X86TileConfig.cpp
   X86ExpandPseudo.cpp
   X86FastISel.cpp
   X86FixupBWInsts.cpp
Index: llvm/lib/IR/Function.cpp
===================================================================
--- llvm/lib/IR/Function.cpp
+++ llvm/lib/IR/Function.cpp
@@ -826,7 +826,8 @@
   IIT_SUBDIVIDE4_ARG = 45,
   IIT_VEC_OF_BITCASTS_TO_INT = 46,
   IIT_V128 = 47,
-  IIT_BF16 = 48
+  IIT_BF16 = 48,
+  IIT_V256 = 49
 };
 
 static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
@@ -920,6 +921,10 @@
     OutputTable.push_back(IITDescriptor::getVector(128, IsScalableVector));
     DecodeIITType(NextElt, Infos, Info, OutputTable);
     return;
+  case IIT_V256:
+    OutputTable.push_back(IITDescriptor::getVector(256, IsScalableVector));
+    DecodeIITType(NextElt, Infos, Info, OutputTable);
+    return;
   case IIT_V512:
     OutputTable.push_back(IITDescriptor::getVector(512, IsScalableVector));
     DecodeIITType(NextElt, Infos, Info, OutputTable);
Index: llvm/lib/CodeGen/VirtRegMap.cpp
===================================================================
--- llvm/lib/CodeGen/VirtRegMap.cpp
+++ llvm/lib/CodeGen/VirtRegMap.cpp
@@ -68,6 +68,7 @@
   Virt2PhysMap.clear();
   Virt2StackSlotMap.clear();
   Virt2SplitMap.clear();
+  Virt2ShapeMap.clear();
 
   grow();
   return false;
Index: llvm/lib/CodeGen/LiveRegMatrix.cpp
===================================================================
--- llvm/lib/CodeGen/LiveRegMatrix.cpp
+++ llvm/lib/CodeGen/LiveRegMatrix.cpp
@@ -54,6 +54,7 @@
 
 bool LiveRegMatrix::runOnMachineFunction(MachineFunction &MF) {
   TRI = MF.getSubtarget().getRegisterInfo();
+  MRI = &MF.getRegInfo();
   LIS = &getAnalysis<LiveIntervals>();
   VRM = &getAnalysis<VirtRegMap>();
 
@@ -221,3 +222,13 @@
   }
   return false;
 }
+
+Register LiveRegMatrix::getOneVReg(unsigned PhysReg) const {
+  LiveInterval *VRegInterval = nullptr;
+  for (MCRegUnitIterator Unit(PhysReg, TRI); Unit.isValid(); ++Unit) {
+    if ((VRegInterval = Matrix[*Unit].getOneVReg()))
+      return VRegInterval->reg();
+  }
+
+  return MCRegister::NoRegister;
+}
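The intended consumer of getOneVReg() is presumably the new X86 tile configuration pass: given a physical TMM register, it can recover one virtual register assigned to it and then that vreg's shape from the VirtRegMap. A minimal sketch under that assumption (the helper name is hypothetical):

  #include "llvm/CodeGen/LiveRegMatrix.h"
  #include "llvm/CodeGen/TileShapeInfo.h"
  #include "llvm/CodeGen/VirtRegMap.h"
  using namespace llvm;

  // Map a physical tile register back to the shape recorded for a virtual
  // register currently assigned to it; returns an invalid ShapeT otherwise.
  static ShapeT getAssignedTileShape(MCRegister PhysTMM,
                                     const LiveRegMatrix &LRM,
                                     const VirtRegMap &VRM) {
    Register VirtReg = LRM.getOneVReg(PhysTMM);
    if (!VirtReg.isValid() || !VRM.hasShape(VirtReg))
      return ShapeT();
    return VRM.getShape(VirtReg);
  }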
Index: llvm/lib/CodeGen/LiveIntervalUnion.cpp
===================================================================
--- llvm/lib/CodeGen/LiveIntervalUnion.cpp
+++ llvm/lib/CodeGen/LiveIntervalUnion.cpp
@@ -99,6 +99,16 @@
 }
 #endif //!NDEBUG
 
+LiveInterval *LiveIntervalUnion::getOneVReg() const {
+  if (empty())
+    return nullptr;
+  for (LiveSegments::const_iterator SI = Segments.begin(); SI.valid(); ++SI) {
+    // return the first valid live interval
+    return SI.value();
+  }
+  return nullptr;
+}
+
 // Scan the vector of interfering virtual registers in this union. Assume it's
 // quite small.
 bool LiveIntervalUnion::Query::isSeenInterference(LiveInterval *VirtReg) const {
Index: llvm/lib/CodeGen/InlineSpiller.cpp
===================================================================
--- llvm/lib/CodeGen/InlineSpiller.cpp
+++ llvm/lib/CodeGen/InlineSpiller.cpp
@@ -1556,6 +1556,8 @@
     VRM.assignVirt2Phys(New, VRM.getPhys(Old));
   else if (VRM.getStackSlot(Old) != VirtRegMap::NO_STACK_SLOT)
     VRM.assignVirt2StackSlot(New, VRM.getStackSlot(Old));
+  else if (VRM.hasShape(Old))
+    VRM.assignVirt2Shape(New, VRM.getShape(Old));
   else
     llvm_unreachable("VReg should be assigned either physreg or stackslot");
 }
Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -4977,3 +4977,26 @@
   def int_x86_tdpbf16ps : GCCBuiltin<"__builtin_ia32_tdpbf16ps">,
               Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>;
 }
+
+// AMX - internal intrinsics
+let TargetPrefix = "x86" in {
+  def int_x86_tileloadd64_internal :
+              GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
+                        []>;
+  def int_x86_tilezero_internal :
+              GCCBuiltin<"__builtin_ia32_tilezero_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_v256i32_ty], []>;
+  def int_x86_tdpbssd_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_v256i32_ty, llvm_v256i32_ty,
+                         llvm_v256i32_ty], []>;
+  def int_x86_tilestored64_internal :
+              GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
+              Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
+                             llvm_i64_ty, llvm_v256i32_ty], []>;
+}
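These internal intrinsics carry the row/col shape as explicit i16 operands so it survives into instruction selection. A sketch of how a front end or lowering pass could materialize a call to one of them with IRBuilder; the helper below is hypothetical and assumes i16 row/col, an i8* base, and an i64 stride, matching the signature above:

  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/IntrinsicsX86.h"
  #include "llvm/IR/Module.h"
  using namespace llvm;

  // Build a call to llvm.x86.tileloadd64.internal(row, col, base, stride),
  // which yields the <256 x i32> tile value consumed by the other intrinsics.
  static Value *emitTileLoadInternal(IRBuilder<> &B, Value *Row, Value *Col,
                                     Value *Base, Value *Stride) {
    Module *M = B.GetInsertBlock()->getModule();
    Function *Decl =
        Intrinsic::getDeclaration(M, Intrinsic::x86_tileloadd64_internal);
    return B.CreateCall(Decl, {Row, Col, Base, Stride});
  }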
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -289,6 +289,7 @@
 def llvm_v16i32_ty     : LLVMType<v16i32>;   // 16 x i32
 def llvm_v32i32_ty     : LLVMType<v32i32>;   // 32 x i32
 def llvm_v64i32_ty     : LLVMType<v64i32>;   // 64 x i32
+def llvm_v256i32_ty    : LLVMType<v256i32>;  // 256 x i32
 
 def llvm_v1i64_ty      : LLVMType<v1i64>;    //  1 x i64
 def llvm_v2i64_ty      : LLVMType<v2i64>;    //  2 x i64
Index: llvm/include/llvm/CodeGen/VirtRegMap.h
===================================================================
--- llvm/include/llvm/CodeGen/VirtRegMap.h
+++ llvm/include/llvm/CodeGen/VirtRegMap.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
 #include "llvm/Pass.h"
 #include <cassert>
 
@@ -60,6 +61,10 @@
     /// mapping.
     IndexedMap<unsigned, VirtReg2IndexFunctor> Virt2SplitMap;
 
+    /// Virt2ShapeMap - For X86 AMX, maps a tile virtual register to the
+    /// shape information bound to it.
+    DenseMap<unsigned, ShapeT> Virt2ShapeMap;
+
     /// createSpillSlot - Allocate a spill slot for RC from MFI.
     unsigned createSpillSlot(const TargetRegisterClass *RC);
 
@@ -107,6 +112,23 @@
     /// the specified physical register
     void assignVirt2Phys(Register virtReg, MCPhysReg physReg);
 
+    bool isShapeMapEmpty() const {
+      return Virt2ShapeMap.empty();
+    }
+
+    bool hasShape(Register virtReg) const {
+      return getShape(virtReg).isValid();
+    }
+
+    ShapeT getShape(Register virtReg) const {
+      assert(virtReg.isVirtual());
+      return Virt2ShapeMap.lookup(virtReg);
+    }
+
+    void assignVirt2Shape(Register virtReg, ShapeT shape) {
+      Virt2ShapeMap[virtReg.id()] = shape;
+    }
+
     /// clears the specified virtual register's, physical
     /// register mapping
     void clearVirt(Register virtReg) {
@@ -133,6 +155,9 @@
     /// records virtReg is a split live interval from SReg.
     void setIsSplitFromReg(Register virtReg, unsigned SReg) {
       Virt2SplitMap[virtReg.id()] = SReg;
+      if (hasShape(SReg)) {
+        Virt2ShapeMap[virtReg.id()] = getShape(SReg);
+      }
     }
 
     /// returns the live interval virtReg is split from.
Index: llvm/include/llvm/CodeGen/TileShapeInfo.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/CodeGen/TileShapeInfo.h
@@ -0,0 +1,100 @@
+#ifndef LLVM_CODEGEN_TILESHAPEINFO_H
+#define LLVM_CODEGEN_TILESHAPEINFO_H
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Register.h"
+#include <utility>
+
+namespace llvm {
+
+class ShapeT {
+public:
+  ShapeT(MachineOperand *Row, MachineOperand *Col,
+         const MachineRegisterInfo *MRI = nullptr)
+      : Row(Row), Col(Col), RowImm(InvalidImmShape), ColImm(InvalidImmShape) {
+    if (MRI)
+      deduceImm(MRI);
+  }
+  ShapeT()
+      : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
+        ColImm(InvalidImmShape) {}
+  bool operator==(const ShapeT &Shape) const {
+    MachineOperand *R = Shape.Row;
+    MachineOperand *C = Shape.Col;
+    if (!R || !C)
+      return false;
+    if (!Row || !Col)
+      return false;
+    if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
+      return true;
+    if ((RowImm != InvalidImmShape) && (Shape.getRowImm() != InvalidImmShape) &&
+        (ColImm != InvalidImmShape) && (Shape.getColImm() != InvalidImmShape)) {
+      return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
+    }
+    return false;
+  }
+
+  bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
+
+  ShapeT &operator=(const ShapeT &RHS) {
+    Row = RHS.Row;
+    Col = RHS.Col;
+    RowImm = RHS.RowImm;
+    ColImm = RHS.ColImm;
+    return *this;
+  }
+
+  MachineOperand *getRow() const { return Row; }
+
+  MachineOperand *getCol() const { return Col; }
+
+  int64_t getRowImm() const { return RowImm; }
+
+  int64_t getColImm() const { return ColImm; }
+
+  bool isValid() const { return (Row != nullptr) && (Col != nullptr); }
+
+  void deduceImm(const MachineRegisterInfo *MRI) {
+    // All defs must define the same immediate value; otherwise the MIs are
+    // malformed. Find that immediate.
+    // TODO: Handle copy propagation.
+    auto GetImm = [&](Register Reg) {
+      int64_t Imm = InvalidImmShape;
+      for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
+        auto *MI = DefMO.getParent();
+        if (MI->isMoveImmediate()) {
+          Imm = MI->getOperand(1).getImm();
+          break;
+        }
+      }
+      return Imm;
+    };
+    RowImm = GetImm(Row->getReg());
+    ColImm = GetImm(Col->getReg());
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  LLVM_DUMP_METHOD void dump() const {
+    if (Row)
+      Row->dump();
+    if (Col)
+      Col->dump();
+    dbgs() << "imm (" << RowImm << ", " << ColImm << ")\n";
+  }
+#endif
+private:
+  static constexpr int64_t InvalidImmShape = -1;
+  MachineOperand *Row;
+  MachineOperand *Col;
+  int64_t RowImm;
+  int64_t ColImm;
+};
+
+} // namespace llvm
+
+#endif
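ShapeT pairs the row/col MachineOperands of a tile def and, when possible, folds them to immediates via deduceImm(). A usage sketch (hypothetical helper, assuming a PTILELOADDV-style def with the row in operand 1 and the column in operand 2, as defined in X86InstrAMX.td above):

  #include "llvm/CodeGen/MachineRegisterInfo.h"
  #include "llvm/CodeGen/TileShapeInfo.h"
  #include "llvm/CodeGen/VirtRegMap.h"
  using namespace llvm;

  // Bind the shape of a tile-producing instruction to its defined virtual
  // register so the allocator and tile config pass can look it up later.
  static void recordTileShape(MachineInstr &MI, const MachineRegisterInfo &MRI,
                              VirtRegMap &VRM) {
    Register TileReg = MI.getOperand(0).getReg();
    if (!TileReg.isVirtual())
      return;
    ShapeT Shape(&MI.getOperand(1), &MI.getOperand(2), &MRI);
    if (!VRM.hasShape(TileReg))
      VRM.assignVirt2Shape(TileReg, Shape);
  }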
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -490,6 +490,8 @@
   /// The pass fixups statepoint machine instruction to replace usage of
   /// caller saved registers with stack slots.
   extern char &FixupStatepointCallerSavedID;
+
+  FunctionPass *createX86LowerAMXTypePass();
 } // End llvm namespace
 
 #endif
Index: llvm/include/llvm/CodeGen/LiveRegMatrix.h
===================================================================
--- llvm/include/llvm/CodeGen/LiveRegMatrix.h
+++ llvm/include/llvm/CodeGen/LiveRegMatrix.h
@@ -41,6 +41,7 @@
   const TargetRegisterInfo *TRI;
   LiveIntervals *LIS;
   VirtRegMap *VRM;
+  MachineRegisterInfo *MRI;
 
   // UserTag changes whenever virtual registers have been modified.
   unsigned UserTag = 0;
@@ -152,6 +153,8 @@
   /// Directly access the live interval unions per regunit.
   /// This returns an array indexed by the regunit number.
   LiveIntervalUnion *getLiveUnions() { return &Matrix[0]; }
+
+  Register getOneVReg(unsigned PhysReg) const;
 };
 
 } // end namespace llvm
Index: llvm/include/llvm/CodeGen/LiveIntervalUnion.h
===================================================================
--- llvm/include/llvm/CodeGen/LiveIntervalUnion.h
+++ llvm/include/llvm/CodeGen/LiveIntervalUnion.h
@@ -104,6 +104,9 @@
   void verify(LiveVirtRegBitSet& VisitedVRegs);
 #endif
 
+  // Get any virtual register that is assigned to this physical unit.
+  LiveInterval *getOneVReg() const;
+
   /// Query interferences between a single live virtual register and a live
   /// interval union.
   class Query {
Index: clang/test/CodeGen/AMX/amx_api.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/AMX/amx_api.c
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +avx512f  -target-feature +amx-int8  \
+// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
+
+#include <immintrin.h>
+
+char buf[1024];
+#define STRIDE 32
+
+char buf2[1024];
+
+void test_api(int cond, short row, short col) {
+//CHECK-LABEL: @test_api
+//CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal
+//CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal
+//CHECK: call void @llvm.x86.tilestored64.internal
+  __tile a = {row, 8};
+  __tile b = {8, col};
+  __tile c = {row, col};
+
+  if(cond) {
+    __tile_loadd(&a, buf, STRIDE);
+    __tile_loadd(&b, buf, STRIDE);
+    __tile_loadd(&c, buf, STRIDE);
+  } else {
+    __tile_loadd(&a, buf2, STRIDE);
+    __tile_loadd(&b, buf2, STRIDE);
+    __tile_loadd(&c, buf2, STRIDE);
+  }
+  __tile_dpbsud(&c, a, b);
+  __tile_stored(buf, STRIDE, c);
+}
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -66,6 +66,8 @@
   __builtin_ia32_tilerelease();
 }
 
+#undef __DEFAULT_FN_ATTRS
+
 /// Load tile rows from memory specifieid by "base" address and "stride" into
 /// destination tile "dst" using the tile configuration previously configured
 /// via "_tile_loadconfig".
@@ -219,7 +221,59 @@
 #define _tile_dpbf16ps(dst, src0, src1) \
   __builtin_ia32_tdpbf16ps((dst), (src0), (src1))
 
+#define __DEFAULT_FN_ATTRS \
+__attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
+
+/// This is the new intrinsic interface for AMX.
+typedef int _tile_data __attribute__((__vector_size__(1024), __aligned__(64)));
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_loadd_internal(short m, short n, const void *base, int stride) {
+  return __builtin_ia32_tileloadd64_internal(m, n, base,
+                                             (__SIZE_TYPE__)(stride));
+}
+
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_zero_internal(short m, short n, _tile_data tile) {
+  return __builtin_ia32_tilezero_internal(m, n, tile);
+}
+
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_dpbssd_internal(short m, short n, short k,
+                      _tile_data dst, _tile_data src1, _tile_data src2) {
+  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
+}
+
+static __inline__ void __DEFAULT_FN_ATTRS
+_tile_stored_internal(short m, short n, void *base,
+                      int stride, _tile_data tile) {
+  __builtin_ia32_tilestored64_internal(m, n, base, (__SIZE_TYPE__)(stride),
+                                       tile);
+}
+
+typedef struct __tile_str {
+  const short row;
+  const short col;
+  _tile_data tile;
+} __tile;
+
+__DEFAULT_FN_ATTRS
+void __tile_loadd(__tile *dst, const void *base, long stride) {
+  dst->tile = _tile_loadd_internal(dst->row, dst->col, base, stride);
+}
+
+__DEFAULT_FN_ATTRS
+void __tile_dpbsud(__tile *dst, __tile src1, __tile src2) {
+  dst->tile = _tile_dpbssd_internal(src1.row, src2.col, src1.col,
+                                    dst->tile, src1.tile, src2.tile);
+}
+
+__DEFAULT_FN_ATTRS
+void __tile_stored(void *base, long stride, __tile src) {
+  _tile_stored_internal(src.row, src.col, base, stride, src.tile);
+}
+
 #undef __DEFAULT_FN_ATTRS
 
+
 #endif /* __x86_64__ */
 #endif /* __AMXINTRIN_H */
Index: clang/include/clang/Basic/BuiltinsX86_64.def
===================================================================
--- clang/include/clang/Basic/BuiltinsX86_64.def
+++ clang/include/clang/Basic/BuiltinsX86_64.def
@@ -94,6 +94,11 @@
 TARGET_BUILTIN(__builtin_ia32_cvtusi2ss64, "V4fV4fUOiIi", "ncV:128:", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_directstore_u64, "vULi*ULi", "n", "movdiri")
 
+// AMX internal builtins
+TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
+TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUsV256i", "n", "amx-tile")
+TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 // AMX
 TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")