Index: llvm/utils/TableGen/IntrinsicEmitter.cpp
--- llvm/utils/TableGen/IntrinsicEmitter.cpp
+++ llvm/utils/TableGen/IntrinsicEmitter.cpp
@@ -248,7 +248,8 @@
   IIT_V128 = 47,
   IIT_BF16 = 48,
   IIT_STRUCT9 = 49,
-  IIT_V256 = 50
+  IIT_V256 = 50,
+  IIT_AMX  = 51
 static void EncodeFixedValueType(MVT::SimpleValueType VT,
@@ -276,6 +277,7 @@
   case MVT::token: return Sig.push_back(IIT_TOKEN);
   case MVT::Metadata: return Sig.push_back(IIT_METADATA);
   case MVT::x86mmx: return Sig.push_back(IIT_MMX);
+  case MVT::x86amx: return Sig.push_back(IIT_AMX);
   // MVT::OtherVT is used to mean the empty struct type here.
   case MVT::Other: return Sig.push_back(IIT_EMPTYSTRUCT);
   // MVT::isVoid is used to represent varargs here.
Index: llvm/utils/TableGen/CodeGenTarget.cpp
--- llvm/utils/TableGen/CodeGenTarget.cpp
+++ llvm/utils/TableGen/CodeGenTarget.cpp
@@ -76,6 +76,7 @@
   case MVT::f128:     return "MVT::f128";
   case MVT::ppcf128:  return "MVT::ppcf128";
   case MVT::x86mmx:   return "MVT::x86mmx";
+  case MVT::x86amx:   return "MVT::x86amx";
   case MVT::Glue:     return "MVT::Glue";
   case MVT::isVoid:   return "MVT::isVoid";
   case MVT::v1i1:     return "MVT::v1i1";
Index: llvm/test/CodeGen/X86/AMX/amx-type.ll
--- llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -8,18 +8,103 @@
 @buf = dso_local global [1024 x i8] zeroinitializer, align 16
 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+; test bitcast x86_amx to <256 x i32>
+define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_user_empty(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret void
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) #3
+  %t2 = bitcast x86_amx %t1 to <256 x i32>
+  ret void
+; test bitcast <256 x i32> to x86_amx
+define dso_local void @test_user_empty2(<256 x i32> %in) #2 {
+; CHECK-LABEL: @test_user_empty2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret void
+  %t = bitcast <256 x i32> %in to x86_amx
+  ret void
+define dso_local <256 x i32> @test_amx_load_bitcast(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_load_bitcast(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <256 x i32>* [[IN]] to i8*
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP1]]) [[ATTR3:#.*]]
+; CHECK-NEXT:    ret <256 x i32> [[T1]]
+  %t1 = load <256 x i32>, <256 x i32>* %in, align 64
+  %t2 = bitcast <256 x i32> %t1 to x86_amx
+  call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2) #3
+  ret <256 x i32> %t1
+define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_bitcast_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <256 x i32>* [[OUT:%.*]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP0]], i64 64, x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = load <256 x i32>, <256 x i32>* [[OUT]], align 1024
+; CHECK-NEXT:    ret <256 x i32> [[TMP1]]
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s) #3
+  %t2 = bitcast x86_amx %t1 to <256 x i32>
+  store <256 x i32> %t2, <256 x i32>* %out
+  ret <256 x i32> %t2
+define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) #2 {
+; CHECK-LABEL: @test_src_add(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 1024
+; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[TMP1]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP2]]) [[ATTR3]]
+; CHECK-NEXT:    ret void
+  %add = add <256 x i32> %y, %x
+  %t = bitcast <256 x i32> %add to x86_amx
+  call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t) #3
+  ret void
+define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) #2 {
+; CHECK-LABEL: @test_src_add2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 1024
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 64, x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[TMP2]], [[X:%.*]]
+; CHECK-NEXT:    ret void
+  %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s) #3
+  %t2 = bitcast x86_amx %t1 to <256 x i32>
+  %add = add <256 x i32> %t2, %x
+  ret void
 define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 {
 ; CHECK-LABEL: @test_load(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>*
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>*
-; CHECK-NEXT:    [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = load <128 x i32>, <128 x i32>* [[TMP5]], align 64
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>*
-; CHECK-NEXT:    store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1
-; CHECK-NEXT:    store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 64, [[TBAA2:!tbaa !.*]]
+; CHECK-NEXT:    store <256 x i32> [[TMP3]], <256 x i32>* [[TMP2]], align 64, [[TBAA2]]
 ; CHECK-NEXT:    ret void
   %1 = bitcast i8* %in to <256 x i32>*
@@ -29,18 +114,33 @@
   ret void
+define dso_local <256 x i32> @foo(<256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %0, <256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %1) local_unnamed_addr #0 {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[X:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0:%.*]], align 1024, [[TBAA5:!tbaa !.*]]
+; CHECK-NEXT:    [[Y:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1:%.*]], align 1024, [[TBAA5]]
+; CHECK-NEXT:    [[ADD:%.*]] = add <256 x i32> [[Y]], [[X]]
+; CHECK-NEXT:    ret <256 x i32> [[ADD]]
+  %x = load <256 x i32>, <256 x i32>* %0, align 1024, !tbaa !2
+  %y = load <256 x i32>, <256 x i32>* %1, align 1024, !tbaa !2
+  %add = add <256 x i32> %y, %x
+  ret <256 x i32> %add
 define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 {
 ; CHECK-LABEL: @__tile_loadd(
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8:!tbaa !.*]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32
 ; CHECK-NEXT:    [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32
-; CHECK-NEXT:    [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]]
+; CHECK-NEXT:    [[TMP10:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
 ; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, x86_amx [[TMP10]])
 ; CHECK-NEXT:    ret void
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
@@ -49,32 +149,33 @@
   %7 = load i16, i16* %6, align 2, !tbaa !7
   %8 = shl i64 %2, 32
   %9 = ashr exact i64 %8, 32
-  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
-  %11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
-  store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8
+  %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
+  %11 = bitcast x86_amx %10 to <256 x i32>
+  %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  store <256 x i32> %11, <256 x i32>* %12, align 64, !tbaa !8
   ret void
 define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 {
 ; CHECK-LABEL: @__tile_dpbsud(
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA8]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
 ; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
 ; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
-; CHECK-NEXT:    [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
+; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
 ; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
 ; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
-; CHECK-NEXT:    [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
-; CHECK-NEXT:    [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]]
+; CHECK-NEXT:    [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
+; CHECK-NEXT:    [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]]) [[ATTR3]]
 ; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]])
 ; CHECK-NEXT:    ret void
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
@@ -85,27 +186,31 @@
   %9 = load i16, i16* %8, align 2, !tbaa !7
   %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
   %11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8
-  %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
-  %13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8
-  %14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
-  %15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8
-  %16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3
-  store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8
+  %12 = bitcast <256 x i32> %11 to x86_amx
+  %13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
+  %14 = load <256 x i32>, <256 x i32>* %13, align 64, !tbaa !8
+  %15 = bitcast <256 x i32> %14 to x86_amx
+  %16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %17 = load <256 x i32>, <256 x i32>* %16, align 64, !tbaa !8
+  %18 = bitcast <256 x i32> %17 to x86_amx
+  %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18) #3
+  %20 = bitcast x86_amx %19 to <256 x i32>
+  store <256 x i32> %20, <256 x i32>* %10, align 64, !tbaa !8
   ret void
 define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 {
 ; CHECK-LABEL: @__tile_stored(
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
 ; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8*
-; CHECK-NEXT:    [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
 ; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32
 ; CHECK-NEXT:    [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32
-; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]]
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], x86_amx [[TMP10]]) [[ATTR3]]
 ; CHECK-NEXT:    ret void
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
@@ -114,15 +219,16 @@
   %7 = load i16, i16* %6, align 2, !tbaa !7
   %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
   %9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8
-  %10 = shl i64 %1, 32
-  %11 = ashr exact i64 %10, 32
-  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3
+  %10 = bitcast <256 x i32> %9 to x86_amx
+  %11 = shl i64 %1, 32
+  %12 = ashr exact i64 %11, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10) #3
   ret void
-declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
-declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
-declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
 attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
 attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
Index: llvm/test/CodeGen/X86/AMX/amx-spill.ll
--- llvm/test/CodeGen/X86/AMX/amx-spill.ll
+++ llvm/test/CodeGen/X86/AMX/amx-spill.ll
@@ -70,43 +70,43 @@
 ; CHECK-NEXT:    tilerelease
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
-  %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %6 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %7 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
   %11 = icmp eq i32 %0, 0
   br i1 %11, label %16, label %12
 12:                                               ; preds = %3
-  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %15 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
   br label %20
 16:                                               ; preds = %3
-  %17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
-  %18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
-  %19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %17 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %18 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %19 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
   br label %20
 20:                                               ; preds = %16, %12
-  %21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ]
-  %22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ]
-  %23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ]
-  %24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, <256 x i32> %21, <256 x i32> %22) #3
-  %25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3
-  %26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3
-  %27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3
-  tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3
+  %21 = phi x86_amx [ %17, %16 ], [ %13, %12 ]
+  %22 = phi x86_amx [ %18, %16 ], [ %14, %12 ]
+  %23 = phi x86_amx [ %19, %16 ], [ %15, %12 ]
+  %24 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, x86_amx %23, x86_amx %21, x86_amx %22) #3
+  %25 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %6, x86_amx %24, x86_amx %5) #3
+  %26 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %8, x86_amx %25, x86_amx %7) #3
+  %27 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, x86_amx %10, x86_amx %26, x86_amx %9) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %27) #3
   ret void
-declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
-declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
-declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
 attributes #3 = { nounwind }
Index: llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll
--- llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll
+++ llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll
@@ -37,23 +37,23 @@
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
-  %a1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64)
+  %a1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64)
   %addr = getelementptr inbounds i8, i8* %A_mem, i64 1024
-  %a2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64)
-  %c1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64)
+  %a2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64)
+  %c1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64)
   %caddr = getelementptr inbounds i8, i8* %C_mem, i64 1024
-  %c2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64)
+  %c2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64)
   br label %dotpd
-  %b = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64)
-  %dp1 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c1, <256 x i32> %a1, <256 x i32> %b)
-  call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, <256 x i32> %dp1)
-  %dp2 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c2, <256 x i32> %a2, <256 x i32> %b)
-  call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, <256 x i32> %dp2)
+  %b = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64)
+  %dp1 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c1, x86_amx %a1, x86_amx %b)
+  call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, x86_amx %dp1)
+  %dp2 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c2, x86_amx %a2, x86_amx %b)
+  call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, x86_amx %dp2)
   ret void
-declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
-declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>)
-declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>)
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
Index: llvm/test/CodeGen/X86/AMX/amx-config.ll
--- llvm/test/CodeGen/X86/AMX/amx-config.ll
+++ llvm/test/CodeGen/X86/AMX/amx-config.ll
@@ -47,31 +47,31 @@
   br i1 %4, label %11, label %7
 7:                                                ; preds = %3
-  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
-  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
   br label %15
 11:                                               ; preds = %3
-  %12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
-  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
-  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %12 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
   br label %15
 15:                                               ; preds = %11, %7
-  %16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ]
-  %17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ]
-  %18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ]
-  %19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3
-  tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3
+  %16 = phi x86_amx [ %12, %11 ], [ %8, %7 ]
+  %17 = phi x86_amx [ %13, %11 ], [ %9, %7 ]
+  %18 = phi x86_amx [ %14, %11 ], [ %10, %7 ]
+  %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, x86_amx %18, x86_amx %16, x86_amx %17) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %19) #3
   ret void
-declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
-declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
-declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
 attributes #3 = { nounwind }
Index: llvm/test/CodeGen/X86/AMX/amx-across-func.ll
--- llvm/test/CodeGen/X86/AMX/amx-across-func.ll
+++ llvm/test/CodeGen/X86/AMX/amx-across-func.ll
@@ -71,20 +71,20 @@
 ; CHECK-NEXT:    .cfi_def_cfa_offset 8
 ; CHECK-NEXT:    tilerelease
 ; CHECK-NEXT:    retq
-  %3 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4
-  %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4
+  %3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4
+  %4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4
   tail call void (...) @foo() #4
-  %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4
-  %6 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, <256 x i32> %5, <256 x i32> %3, <256 x i32> %4) #4
-  tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, <256 x i32> %6) #4
+  %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4
+  %6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, x86_amx %5, x86_amx %3, x86_amx %4) #4
+  tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, x86_amx %6) #4
   ret void
 declare dso_local void @foo(...) local_unnamed_addr #3
-declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4
-declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #4
-declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #4
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #4
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #4
 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
 attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
Index: llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1120,6 +1120,10 @@
   // Fold away bit casts of the stored value by storing the original type.
   if (auto *BC = dyn_cast<BitCastInst>(V)) {
     V = BC->getOperand(0);
+    // Don't fold the bitcast away when either type is x86_amx: the pass that
+    // lowers x86_amx needs to see these bitcasts in order to rewrite them.
+    if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+      return false;
     if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
       combineStoreToNewValue(IC, SI, V);
       return true;
Index: llvm/lib/Target/X86/
--- llvm/lib/Target/X86/
+++ llvm/lib/Target/X86/
@@ -637,7 +637,7 @@
 // Tiles
 let CopyCost = -1 in // Don't allow copying of tile registers
-def TILE : RegisterClass<"X86", [v256i32], 8192,
+def TILE : RegisterClass<"X86", [x86amx], 8192,
                          (sequence "TMM%u", 0, 7)> {let Size = 8192;}
 def TILECFG : RegisterClass<"X86", [untyped], 512, (add TMMCFG)> {
   let CopyCost = -1;  // Don't allow copying of tile config registers.
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -6,20 +6,19 @@
-/// \file Pass to transform <256 x i32>
-/// <256 x i32> is mapped to AMX tile register on X86, AMX instruction set only
-/// provides simple operation on tile register. The basic elementwise operation
-/// is not supported by AMX. Since we define the AMX tile as vector <256 x i32>
+/// \file Pass to transform <256 x i32> load/store
+/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
+/// provides simple operation on x86_amx. The basic elementwise operation
+/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
 /// and only AMX intrinsics can operate on the type, we need transform
-/// load/store <256 x i32> instruction to AMX load/store. Besides, we split
-/// <256 x i32> to 2 <128 x i32> if the vector is not used or defined by AMX
-/// intrinsics, so that in instruction selection it can be lowered to proper
-/// size which HW can support.
+/// load/store <256 x i32> instruction to AMX load/store. Otherwise we are not
+/// able to lower the bitcast instruction to X86 instruction.
 #include "X86.h"
-#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
@@ -30,231 +29,306 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 using namespace llvm;
+using namespace PatternMatch;
 #define DEBUG_TYPE "lower-amx-type"
-namespace {
-class X86LowerAMXType {
-  Function &Func;
-  const DataLayout &DL;
-  DenseSet<Instruction *> LDSet;
-  DenseSet<Instruction *> STSet;
-  DenseMap<Value *, std::pair<LoadInst *, LoadInst *>> LoadMap;
-  X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {}
-  bool visit();
-  bool visitLD();
-  bool visitST();
-  void splitST(Instruction *Inst);
-  void splitLD(Instruction *Inst);
-// Split v256i32 load/store to 2 v128i32, so that ISel can
-// lower it to proper vector size.
-void X86LowerAMXType::splitST(Instruction *Inst) {
-  StoreInst *ST = dyn_cast<StoreInst>(Inst);
-  IRBuilder<> Builder(ST);
-  LLVMContext &Ctx = Builder.getContext();
-  Type *Ty = ST->getValueOperand()->getType();
-  EVT VT = EVT::getEVT(Ty);
-  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
-  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
-  LoadInst *Lo, *Hi;
-  std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()];
-  Value *Ptr = ST->getPointerOperand();
-  PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace());
-  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
-  // The HW require the alignment for AMX tile is 64, but front-end generate
-  // code for the vector alignment which is the vector size.
-  uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
-  Align Alignment = std::min(Lo->getAlign(), Align(HalfTySize));
-  Builder.CreateAlignedStore(Lo, HalfPtr, Alignment, ST->isVolatile());
+static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
+  Function &F = *BB->getParent();
+  Module *M = BB->getModule();
+  const DataLayout &DL = M->getDataLayout();
-  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
-  Builder.CreateAlignedStore(Hi, HalfPtr, Alignment, ST->isVolatile());
+  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
+  auto AllocaAlignment = DL.getPrefTypeAlign(V256I32Ty);
+  unsigned AllocaAS = DL.getAllocaAddrSpace();
+  AllocaInst *AllocaRes =
+      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
+  AllocaRes->setAlignment(AllocaAlignment);
+  return AllocaRes;
-bool X86LowerAMXType::visitST() {
-  if (STSet.empty())
-    return false;
-  for (auto *Inst : STSet) {
-    Value *Row, *Col;
-    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst->getOperand(0));
-    if (!II)
-      Row = Col = nullptr;
-    else {
-      switch (II->getIntrinsicID()) {
-      default:
-        Row = Col = nullptr;
-        break;
-      case Intrinsic::x86_tileloadd64_internal:
-      case Intrinsic::x86_tdpbssd_internal: {
-        Row = II->getArgOperand(0);
-        Col = II->getArgOperand(1);
-        break;
-      }
-      }
-    }
-    if (!Row) {
-      splitST(Inst);
-      continue;
+static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+  Value *Row = nullptr, *Col = nullptr;
+  switch (II->getIntrinsicID()) {
+  default:
+    llvm_unreachable("Expect amx intrinsics");
+  case Intrinsic::x86_tileloadd64_internal:
+  case Intrinsic::x86_tilestored64_internal: {
+    Row = II->getArgOperand(0);
+    Col = II->getArgOperand(1);
+    break;
+  }
+  // a * b + c
+  // The shape depends on which operand.
+  case Intrinsic::x86_tdpbssd_internal: {
+    switch (OpNo) {
+    case 3:
+      Row = II->getArgOperand(0);
+      Col = II->getArgOperand(1);
+      break;
+    case 4:
+      Row = II->getArgOperand(0);
+      Col = II->getArgOperand(2);
+      break;
+    case 5:
+      Row = II->getArgOperand(2);
+      Col = II->getArgOperand(1);
+      break;
-    IRBuilder<> Builder(Inst);
-    LLVMContext &Ctx = Builder.getContext();
-    // Use the maximun column as stride. It must be the same with load stride.
-    Value *Stride = Builder.getInt64(64);
-    Value *I8Ptr =
-        Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx));
-    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
-                                   Inst->getOperand(0)};
-    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+    break;
-  return true;
+  }
+  return std::make_pair(Row, Col);
-void X86LowerAMXType::splitLD(Instruction *Inst) {
-  LoadInst *LD = dyn_cast<LoadInst>(Inst);
-  IRBuilder<> Builder(LD);
-  LLVMContext &Ctx = Builder.getContext();
-  Type *Ty = LD->getType();
-  EVT VT = EVT::getEVT(Ty);
-  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
-  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
+// %src = load <256 x i32>, <256 x i32>* %addr, align 64
+// %2 = bitcast <256 x i32> %src to x86_amx
+// -->
+// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+// i8* %addr, i64 %stride64)
+static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
+  Value *Row = nullptr, *Col = nullptr;
+  Use &U = *(Bitcast->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  auto *II = cast<IntrinsicInst>(U.getUser());
+  std::tie(Row, Col) = getShape(II, OpNo);
+  IRBuilder<> Builder(Bitcast);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
-  Value *Ptr = LD->getPointerOperand();
-  PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace());
-  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
-  // The HW require the alignment for AMX tile is 64, but front-end generate
-  // code for the vector alignment which is the vector size.
-  uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
-  Align Alignment = std::min(LD->getAlign(), Align(HalfTySize));
-  auto *Lo =
-      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  Bitcast->replaceAllUsesWith(NewInst);
-  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
-  auto *Hi =
-      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
+// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
+//                                                    %stride);
+// %13 = bitcast x86_amx %src to <256 x i32>
+// store <256 x i32> %13, <256 x i32>* %addr, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+//                                           %stride64, %13)
+static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
-  LoadMap[Inst] = std::make_pair(Lo, Hi);
+  auto *Tile = Bitcast->getOperand(0);
+  auto *II = cast<IntrinsicInst>(Tile);
+  // Tile is output from AMX intrinsic. The first operand of the
+  // intrinsic is row, the second operand of the intrinsic is column.
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as stride. It must be the same with load
+  // stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+  if (Bitcast->hasOneUse())
+    return;
+  // %13 = bitcast x86_amx %src to <256 x i32>
+  // store <256 x i32> %13, <256 x i32>* %addr, align 64
+  // %add = <256 x i32> %13, <256 x i32> %src2
+  // -->
+  // %13 = bitcast x86_amx %src to <256 x i32>
+  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+  //                                           %stride64, %13)
+  // %14 = load <256 x i32>, %addr
+  // %add = <256 x i32> %14, <256 x i32> %src2
+  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
+  Bitcast->replaceAllUsesWith(Vec);
-bool X86LowerAMXType::visitLD() {
-  if (LDSet.empty())
-    return false;
-  for (auto &Inst : LDSet) {
-    int Count = 0;
-    Value *NewInst = nullptr;
-    // The user should be all AMX intrinsics or all LLVM instruction.
-    // Don't support it is used by both AMX intrinsics and LLVM instructions.
-    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
-      Use &U = *I++;
-      const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U.getUser());
-      if (!II) {
-        Count++;
-        continue;
-      }
-      if (NewInst)
-        continue;
-      Value *Row, *Col;
-      switch (II->getIntrinsicID()) {
-      default:
-        report_fatal_error("Non-AMX intrinsic use tile type.");
-        break;
-      case Intrinsic::x86_tdpbssd_internal: {
-        unsigned OpNo = U.getOperandNo();
-        switch (OpNo) {
-        case 3:
-          Row = II->getArgOperand(0);
-          Col = II->getArgOperand(1);
-          break;
-        case 4:
-          Row = II->getArgOperand(0);
-          Col = II->getArgOperand(2);
-          break;
-        case 5:
-          Row = II->getArgOperand(2);
-          Col = II->getArgOperand(1);
-          break;
-        }
-        break;
-      }
-      case Intrinsic::x86_tilestored64_internal: {
-        Row = II->getArgOperand(0);
-        Col = II->getArgOperand(1);
-        break;
-      }
-      }
-      assert(Count == 0 && "Can NOT mix amx intrinsic and LLVM instruction");
-      // FIXME: The shape def should be ahead of load.
-      IRBuilder<> Builder(Inst);
-      LLVMContext &Ctx = Builder.getContext();
-      // Use the maximun column as stride.
-      Value *Stride = Builder.getInt64(64);
-      Value *I8Ptr =
-          Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx));
-      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+// transform bitcast to <store, load> instructions.
+static bool transformBitcast(BitCastInst *Bitcast) {
+  IRBuilder<> Builder(Bitcast);
+  AllocaInst *AllocaAddr;
+  Value *I8Ptr, *Stride;
+  auto *Src = Bitcast->getOperand(0);
-      NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
-                                        None, Args);
+  auto Prepare = [&]() {
+    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    Stride = Builder.getInt64(64);
+  };
-      Inst->replaceAllUsesWith(NewInst);
-    }
-    if (!NewInst)
-      splitLD(Inst);
+  if (Bitcast->getType()->isX86_AMXTy()) {
+    // %2 = bitcast <256 x i32> %src to x86_amx
+    // -->
+    // %addr = alloca <256 x i32>, align 1024
+    // store <256 x i32> %src, <256 x i32>* %addr, align 1024
+    // %addr2 = bitcast <256 x i32>* to i8*
+    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+    //                                                  i8* %addr2,
+    //                                                  i64 64)
+    Use &U = *(Bitcast->use_begin());
+    unsigned OpNo = U.getOperandNo();
+    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+    if (!II)
+      return false; // May be bitcast from x86amx to <256 x i32>.
+    Prepare();
+    Builder.CreateStore(Src, AllocaAddr);
+    // TODO: we can pick a constant operand for the shape.
+    Value *Row = nullptr, *Col = nullptr;
+    std::tie(Row, Col) = getShape(II, OpNo);
+    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+    Value *NewInst = Builder.CreateIntrinsic(
+        Intrinsic::x86_tileloadd64_internal, None, Args);
+    Bitcast->replaceAllUsesWith(NewInst);
+  } else {
+    // %2 = bitcast x86_amx %src to <256 x i32>
+    // -->
+    // %addr = alloca <256 x i32>, align 1024
+    // %addr2 = bitcast <256 x i32>* to i8*
+    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
+    //                                           i8* %addr2, i64 %stride)
+    // %2 = load <256 x i32>, <256 x i32>* %addr, align 1024
+    auto *II = dyn_cast<IntrinsicInst>(Src);
+    if (!II)
+      return false; // May be bitcast from <256 x i32> to x86amx.
+    Prepare();
+    Value *Row = II->getOperand(0);
+    Value *Col = II->getOperand(1);
+    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
+    Bitcast->replaceAllUsesWith(NewInst);
   return true;
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+  X86LowerAMXType(Function &F) : Func(F) {}
+  bool visit();
 bool X86LowerAMXType::visit() {
-  bool C;
-  auto IsAMXType = [](FixedVectorType *VTy) {
-    if (!VTy)
-      return false;
-    if (!VTy->getScalarType()->isIntegerTy(32))
-      return false;
-    if (VTy->getNumElements() != 256)
-      return false;
+  SmallVector<Instruction *, 8> DeadInsts;
-    return true;
-  };
+  for (BasicBlock *BB : post_order(&Func)) {
+    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
+         II != IE;) {
+      Instruction &Inst = *II++;
+      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
+      if (!Bitcast)
+        continue;
-  for (BasicBlock &BB : Func) {
-    for (Instruction &Inst : BB) {
-      LoadInst *LD = dyn_cast<LoadInst>(&Inst);
-      // Check load instruction.
-      // %3 = load <256 x i32>, <256 x i32>* %1, align 64
-      if (LD) {
-        FixedVectorType *VTy = dyn_cast<FixedVectorType>(Inst.getType());
-        if (!IsAMXType(VTy))
+      Value *Src = Bitcast->getOperand(0);
+      if (Bitcast->getType()->isX86_AMXTy()) {
+        if (Bitcast->user_empty()) {
+          DeadInsts.push_back(Bitcast);
-        LDSet.insert(&Inst);
-        continue;
+        }
+        LoadInst *LD = dyn_cast<LoadInst>(Src);
+        if (!LD || !isa<IntrinsicInst>(Bitcast->use_begin()->getUser())) {
+          if (transformBitcast(Bitcast))
+            DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        // If the load has multiple users, duplicate it as an AMX load.
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = bitcast <256 x i32> %src to x86_amx
+        // %add = add <256 x i32> %src, <256 x i32> %src2
+        // -->
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+        //                                            i8* %addr, i64 %stride64)
+        // %add = add <256 x i32> %src, <256 x i32> %src2
+        // If load has one user, the load will be eliminated in DAG ISel.
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = bitcast <256 x i32> %src to x86_amx
+        // -->
+        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+        //                                            i8* %addr, i64 %stride64)
+        combineLoadBitcast(LD, Bitcast);
+        DeadInsts.push_back(Bitcast);
+        if (LD->hasOneUse())
+          DeadInsts.push_back(LD);
+      } else if (Src->getType()->isX86_AMXTy()) {
+        if (Bitcast->user_empty()) {
+          DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        StoreInst *ST = nullptr;
+        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
+             UI != UE;) {
+          Value *I = (UI++)->getUser();
+          ST = dyn_cast<StoreInst>(I);
+          if (ST)
+            break;
+        }
+        if (!ST || !isa<IntrinsicInst>(Src)) {
+          if (transformBitcast(Bitcast))
+            DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        // If bitcast (%13) has one use, combine bitcast and store to amx store.
+        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
+        //                                                    %stride);
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // store <256 x i32> %13, <256 x i32>* %addr, align 64
+        // -->
+        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+        //                                           %stride64, %13)
+        //
+        // If bitcast (%13) has multi-use, transform as below.
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // store <256 x i32> %13, <256 x i32>* %addr, align 64
+        // %add = <256 x i32> %13, <256 x i32> %src2
+        // -->
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+        //                                           %stride64, %13)
+        // %14 = load <256 x i32>, %addr
+        // %add = <256 x i32> %14, <256 x i32> %src2
+        //
+        combineBitcastStore(Bitcast, ST);
+        // Delete user first.
+        DeadInsts.push_back(ST);
+        DeadInsts.push_back(Bitcast);
-      // Check store instruction.
-      // store <256 x i32> %3, <256 x i32>* %2, align 64
-      StoreInst *ST = dyn_cast<StoreInst>(&Inst);
-      if (!ST)
-        continue;
-      FixedVectorType *VTy =
-          dyn_cast<FixedVectorType>(ST->getOperand(0)->getType());
-      if (!IsAMXType(VTy))
-        continue;
-      STSet.insert(&Inst);
-  C = visitLD() | visitST();
-  for (auto *Inst : STSet)
-    Inst->eraseFromParent();
-  for (auto *Inst : LDSet)
-    Inst->eraseFromParent();
+  bool C = !DeadInsts.empty();
+  SmallSet<Instruction *, 8> DeletedInst;
+  auto DeleteInst = [&](Instruction *Inst) {
+    SmallVector<Instruction *, 4> DeadIs;
+    DeadIs.push_back(Inst);
+    while (!DeadIs.empty()) {
+      auto *Inst = DeadIs.back();
+      DeadIs.pop_back();
+      if (DeletedInst.count(Inst))
+        continue;
+      for (auto I = Inst->op_begin(), E = Inst->op_end(); I != E;) {
+        Instruction *Op = dyn_cast<Instruction>(*I);
+        if (Op && Op->hasOneUse())
+          DeadIs.push_back(Op);
+        ++I;
+      }
+      Inst->eraseFromParent();
+      DeletedInst.insert(Inst);
+    }
+  };
+  for (auto *Inst : DeadInsts)
+    DeleteInst(Inst);
   return C;
 } // anonymous namespace
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1898,7 +1898,7 @@
   if (Subtarget.hasAMXTILE()) {
-    addRegisterClass(MVT::v256i32, &X86::TILERegClass);
+    addRegisterClass(MVT::x86amx, &X86::TILERegClass);
   // We want to custom lower some of our intrinsics.
@@ -5346,11 +5346,6 @@
   if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth())
     return false;
-  // Don't merge to x86 amx tile, as we only map MVT::v256i32
-  // to x86 amx tile on amx intrinsics.
-  if (MemVT == MVT::v256i32)
-    return false;
   return true;
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4618,7 +4618,7 @@
-      CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
+      CNode = CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
       ReplaceNode(Node, CNode);
@@ -4637,7 +4637,7 @@
       MachineSDNode *CNode =
-          CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
       ReplaceNode(Node, CNode);
Index: llvm/lib/IR/Type.cpp
--- llvm/lib/IR/Type.cpp
+++ llvm/lib/IR/Type.cpp
@@ -49,6 +49,7 @@
   case LabelTyID     : return getLabelTy(C);
   case MetadataTyID  : return getMetadataTy(C);
   case X86_MMXTyID   : return getX86_MMXTy(C);
+  case X86_AMXTyID   : return getX86_AMXTy(C);
   case TokenTyID     : return getTokenTy(C);
     return nullptr;
@@ -81,6 +82,14 @@
       Ty->getPrimitiveSizeInBits().getFixedSize() == 64)
     return true;
+  //  8192-bit fixed width vector types can be losslessly converted to x86amx.
+  if (((isa<FixedVectorType>(this)) && Ty->isX86_AMXTy()) &&
+      getPrimitiveSizeInBits().getFixedSize() == 8192)
+    return true;
+  if ((isX86_AMXTy() && isa<FixedVectorType>(Ty)) &&
+      Ty->getPrimitiveSizeInBits().getFixedSize() == 8192)
+    return true;
   // At this point we have only various mismatches of the first class types
   // remaining and ptr->ptr. Just select the lossless conversions. Everything
   // else is not lossless. Conservatively assume we can't losslessly convert
@@ -120,6 +129,7 @@
   case Type::FP128TyID: return TypeSize::Fixed(128);
   case Type::PPC_FP128TyID: return TypeSize::Fixed(128);
   case Type::X86_MMXTyID: return TypeSize::Fixed(64);
+  case Type::X86_AMXTyID: return TypeSize::Fixed(8192);
   case Type::IntegerTyID:
     return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth());
   case Type::FixedVectorTyID:
@@ -179,6 +189,7 @@
 Type *Type::getFP128Ty(LLVMContext &C) { return &C.pImpl->FP128Ty; }
 Type *Type::getPPC_FP128Ty(LLVMContext &C) { return &C.pImpl->PPC_FP128Ty; }
 Type *Type::getX86_MMXTy(LLVMContext &C) { return &C.pImpl->X86_MMXTy; }
+Type *Type::getX86_AMXTy(LLVMContext &C) { return &C.pImpl->X86_AMXTy; }
 IntegerType *Type::getInt1Ty(LLVMContext &C) { return &C.pImpl->Int1Ty; }
 IntegerType *Type::getInt8Ty(LLVMContext &C) { return &C.pImpl->Int8Ty; }
@@ -223,6 +234,10 @@
   return getX86_MMXTy(C)->getPointerTo(AS);
+PointerType *Type::getX86_AMXPtrTy(LLVMContext &C, unsigned AS) {
+  return getX86_AMXTy(C)->getPointerTo(AS);
 PointerType *Type::getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS) {
   return getIntNTy(C, N)->getPointerTo(AS);
Index: llvm/lib/IR/LLVMContextImpl.h
--- llvm/lib/IR/LLVMContextImpl.h
+++ llvm/lib/IR/LLVMContextImpl.h
@@ -1418,7 +1418,7 @@
   // Basic type instances.
   Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
-  Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
+  Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy, X86_AMXTy;
   IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
   BumpPtrAllocator Alloc;
Index: llvm/lib/IR/LLVMContextImpl.cpp
--- llvm/lib/IR/LLVMContextImpl.cpp
+++ llvm/lib/IR/LLVMContextImpl.cpp
@@ -35,6 +35,7 @@
     FP128Ty(C, Type::FP128TyID),
     PPC_FP128Ty(C, Type::PPC_FP128TyID),
     X86_MMXTy(C, Type::X86_MMXTyID),
+    X86_AMXTy(C, Type::X86_AMXTyID),
     Int1Ty(C, 1),
     Int8Ty(C, 8),
     Int16Ty(C, 16),
Index: llvm/lib/IR/Function.cpp
--- llvm/lib/IR/Function.cpp
+++ llvm/lib/IR/Function.cpp
@@ -764,6 +764,7 @@
     case Type::FP128TyID:     Result += "f128";     break;
     case Type::PPC_FP128TyID: Result += "ppcf128";  break;
     case Type::X86_MMXTyID:   Result += "x86mmx";   break;
+    case Type::X86_AMXTyID:   Result += "x86amx";   break;
     case Type::IntegerTyID:
       Result += "i" + utostr(cast<IntegerType>(Ty)->getBitWidth());
@@ -848,7 +849,8 @@
   IIT_V128 = 47,
   IIT_BF16 = 48,
   IIT_STRUCT9 = 49,
-  IIT_V256 = 50
+  IIT_V256 = 50,
+  IIT_AMX  = 51
 static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
@@ -871,6 +873,9 @@
   case IIT_MMX:
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::MMX, 0));
+  case IIT_AMX:
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::AMX, 0));
+    return;
   case IIT_TOKEN:
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::Token, 0));
@@ -1108,6 +1113,7 @@
   case IITDescriptor::Void: return Type::getVoidTy(Context);
   case IITDescriptor::VarArg: return Type::getVoidTy(Context);
   case IITDescriptor::MMX: return Type::getX86_MMXTy(Context);
+  case IITDescriptor::AMX: return Type::getX86_AMXTy(Context);
   case IITDescriptor::Token: return Type::getTokenTy(Context);
   case IITDescriptor::Metadata: return Type::getMetadataTy(Context);
   case IITDescriptor::Half: return Type::getHalfTy(Context);
@@ -1287,6 +1293,7 @@
     case IITDescriptor::Void: return !Ty->isVoidTy();
     case IITDescriptor::VarArg: return true;
     case IITDescriptor::MMX:  return !Ty->isX86_MMXTy();
+    case IITDescriptor::AMX:  return !Ty->isX86_AMXTy();
     case IITDescriptor::Token: return !Ty->isTokenTy();
     case IITDescriptor::Metadata: return !Ty->isMetadataTy();
     case IITDescriptor::Half: return !Ty->isHalfTy();
Index: llvm/lib/IR/DataLayout.cpp
--- llvm/lib/IR/DataLayout.cpp
+++ llvm/lib/IR/DataLayout.cpp
@@ -810,6 +810,8 @@
     Alignment = PowerOf2Ceil(Alignment);
     return Align(Alignment);
+  case Type::X86_AMXTyID:
+    return Align(64);
     llvm_unreachable("Bad type for getAlignment!!!");
Index: llvm/lib/IR/Core.cpp
--- llvm/lib/IR/Core.cpp
+++ llvm/lib/IR/Core.cpp
@@ -512,6 +512,8 @@
     return LLVMVectorTypeKind;
   case Type::X86_MMXTyID:
     return LLVMX86_MMXTypeKind;
+  case Type::X86_AMXTyID:
+    return LLVMX86_AMXTypeKind;
   case Type::TokenTyID:
     return LLVMTokenTypeKind;
   case Type::ScalableVectorTyID:
@@ -623,6 +625,9 @@
 LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getX86_MMXTy(*unwrap(C));
+LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef) Type::getX86_AMXTy(*unwrap(C));
 LLVMTypeRef LLVMHalfType(void) {
   return LLVMHalfTypeInContext(LLVMGetGlobalContext());
@@ -648,6 +653,9 @@
 LLVMTypeRef LLVMX86MMXType(void) {
   return LLVMX86MMXTypeInContext(LLVMGetGlobalContext());
+LLVMTypeRef LLVMX86AMXType(void) {
+  return LLVMX86AMXTypeInContext(LLVMGetGlobalContext());
 /*--.. Operations on function types ........................................--*/
Index: llvm/lib/IR/ConstantFold.cpp
--- llvm/lib/IR/ConstantFold.cpp
+++ llvm/lib/IR/ConstantFold.cpp
@@ -535,7 +535,7 @@
     return UndefValue::get(DestTy);
-  if (V->isNullValue() && !DestTy->isX86_MMXTy() &&
+  if (V->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() &&
       opc != Instruction::AddrSpaceCast)
     return Constant::getNullValue(DestTy);
Index: llvm/lib/IR/AsmWriter.cpp
--- llvm/lib/IR/AsmWriter.cpp
+++ llvm/lib/IR/AsmWriter.cpp
@@ -609,6 +609,7 @@
   case Type::LabelTyID:     OS << "label"; return;
   case Type::MetadataTyID:  OS << "metadata"; return;
   case Type::X86_MMXTyID:   OS << "x86_mmx"; return;
+  case Type::X86_AMXTyID:   OS << "x86_amx"; return;
   case Type::TokenTyID:     OS << "token"; return;
   case Type::IntegerTyID:
     OS << 'i' << cast<IntegerType>(Ty)->getBitWidth();
Index: llvm/lib/CodeGen/ValueTypes.cpp
--- llvm/lib/CodeGen/ValueTypes.cpp
+++ llvm/lib/CodeGen/ValueTypes.cpp
@@ -164,6 +164,7 @@
   case MVT::Other:     return "ch";
   case MVT::Glue:      return "glue";
   case MVT::x86mmx:    return "x86mmx";
+  case MVT::x86amx:    return "x86amx";
   case MVT::Metadata:  return "Metadata";
   case MVT::Untyped:   return "Untyped";
   case MVT::exnref:    return "exnref";
@@ -195,6 +196,7 @@
   case MVT::f128:    return Type::getFP128Ty(Context);
   case MVT::ppcf128: return Type::getPPC_FP128Ty(Context);
   case MVT::x86mmx:  return Type::getX86_MMXTy(Context);
+  case MVT::x86amx:  return Type::getX86_AMXTy(Context);
   case MVT::v1i1:
     return FixedVectorType::get(Type::getInt1Ty(Context), 1);
   case MVT::v2i1:
@@ -501,6 +503,7 @@
   case Type::DoubleTyID:    return MVT(MVT::f64);
   case Type::X86_FP80TyID:  return MVT(MVT::f80);
   case Type::X86_MMXTyID:   return MVT(MVT::x86mmx);
+  case Type::X86_AMXTyID:   return MVT(MVT::x86amx);
   case Type::FP128TyID:     return MVT(MVT::f128);
   case Type::PPC_FP128TyID: return MVT(MVT::ppcf128);
   case Type::PointerTyID:   return MVT(MVT::iPTR);
Index: llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
--- llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -913,6 +913,7 @@
     case Type::LabelTyID:     Code = bitc::TYPE_CODE_LABEL;     break;
     case Type::MetadataTyID:  Code = bitc::TYPE_CODE_METADATA;  break;
     case Type::X86_MMXTyID:   Code = bitc::TYPE_CODE_X86_MMX;   break;
+    case Type::X86_AMXTyID:   Code = bitc::TYPE_CODE_X86_AMX;   break;
     case Type::TokenTyID:     Code = bitc::TYPE_CODE_TOKEN;     break;
     case Type::IntegerTyID:
       // INTEGER: [width]
Index: llvm/lib/Bitcode/Reader/BitcodeReader.cpp
--- llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1763,6 +1763,9 @@
     case bitc::TYPE_CODE_X86_MMX:   // X86_MMX
       ResultTy = Type::getX86_MMXTy(Context);
+    case bitc::TYPE_CODE_X86_AMX:   // X86_AMX
+      ResultTy = Type::getX86_AMXTy(Context);
+      break;
     case bitc::TYPE_CODE_TOKEN:     // TOKEN
       ResultTy = Type::getTokenTy(Context);
Index: llvm/lib/AsmParser/LLLexer.cpp
--- llvm/lib/AsmParser/LLLexer.cpp
+++ llvm/lib/AsmParser/LLLexer.cpp
@@ -840,6 +840,7 @@
   TYPEKEYWORD("label",     Type::getLabelTy(Context));
   TYPEKEYWORD("metadata",  Type::getMetadataTy(Context));
   TYPEKEYWORD("x86_mmx",   Type::getX86_MMXTy(Context));
+  TYPEKEYWORD("x86_amx",   Type::getX86_AMXTy(Context));
   TYPEKEYWORD("token",     Type::getTokenTy(Context));
Index: llvm/lib/Analysis/ConstantFolding.cpp
--- llvm/lib/Analysis/ConstantFolding.cpp
+++ llvm/lib/Analysis/ConstantFolding.cpp
@@ -105,9 +105,9 @@
          "Invalid constantexpr bitcast!");
   // Catch the obvious splat cases.
-  if (C->isNullValue() && !DestTy->isX86_MMXTy())
+  if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
     return Constant::getNullValue(DestTy);
-  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() &&
+  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() &&
       !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
     return Constant::getAllOnesValue(DestTy);
@@ -358,12 +358,13 @@
     // Catch the obvious splat cases (since all-zeros can coerce non-integral
     // pointers legally).
-    if (C->isNullValue() && !DestTy->isX86_MMXTy())
+    if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
       return Constant::getNullValue(DestTy);
     if (C->isAllOnesValue() &&
         (DestTy->isIntegerTy() || DestTy->isFloatingPointTy() ||
          DestTy->isVectorTy()) &&
-        !DestTy->isX86_MMXTy() && !DestTy->isPtrOrPtrVectorTy())
+        !DestTy->isX86_AMXTy() && !DestTy->isX86_MMXTy() &&
+        !DestTy->isPtrOrPtrVectorTy())
       // Get ones when the input is trivial, but
       // only for supported types inside getAllOnesValue.
       return Constant::getAllOnesValue(DestTy);
@@ -575,14 +576,16 @@
     C = FoldBitCast(C, MapTy->getPointerTo(AS), DL);
     if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) {
-      if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
+      if (Res->isNullValue() && !LoadTy->isX86_MMXTy() &&
+          !LoadTy->isX86_AMXTy())
         // Materializing a zero can be done trivially without a bitcast
         return Constant::getNullValue(LoadTy);
       Type *CastTy = LoadTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(LoadTy) : LoadTy;
       Res = FoldBitCast(Res, CastTy, DL);
       if (LoadTy->isPtrOrPtrVectorTy()) {
         // For vector of pointer, we needed to first convert to a vector of integer, then do vector inttoptr
-        if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
+        if (Res->isNullValue() && !LoadTy->isX86_MMXTy() &&
+            !LoadTy->isX86_AMXTy())
           return Constant::getNullValue(LoadTy);
         if (DL.isNonIntegralPointerType(LoadTy->getScalarType()))
           // Be careful not to replace a load of an addrspace value with an inttoptr here
Index: llvm/include/llvm/Support/MachineValueType.h
--- llvm/include/llvm/Support/MachineValueType.h
+++ llvm/include/llvm/Support/MachineValueType.h
@@ -247,9 +247,10 @@
       exnref         = 161,   // WebAssembly's exnref type
       funcref        = 162,   // WebAssembly's funcref type
       externref      = 163,   // WebAssembly's externref type
+      x86amx         = 164,   // This is an X86 AMX value
       FIRST_VALUETYPE =  1,   // This is always the beginning of the list.
-      LAST_VALUETYPE = 164,   // This always remains at the end of the list.
+      LAST_VALUETYPE = 165,   // This always remains at the end of the list.
       // This is the current maximum for LAST_VALUETYPE.
       // MVT::MAX_ALLOWED_VALUETYPE is used for asserts and to size bit vectors
@@ -966,6 +967,7 @@
       case v256i32:
       case v128i64:
       case v256f32:
+      case x86amx:
       case v128f64:  return TypeSize::Fixed(8192);
       case v512i32:
       case v256i64:
Index: llvm/include/llvm/IR/Type.h
--- llvm/include/llvm/IR/Type.h
+++ llvm/include/llvm/IR/Type.h
@@ -65,6 +65,7 @@
     LabelTyID,     ///< Labels
     MetadataTyID,  ///< Metadata
     X86_MMXTyID,   ///< MMX vectors (64 bits, X86 specific)
+    X86_AMXTyID,   ///< AMX vectors (8192 bits, X86 specific)
     TokenTyID,     ///< Tokens
     // Derived types... see DerivedTypes.h file.
@@ -182,6 +183,9 @@
   /// Return true if this is X86 MMX.
   bool isX86_MMXTy() const { return getTypeID() == X86_MMXTyID; }
+  /// Return true if this is X86 AMX.
+  bool isX86_AMXTy() const { return getTypeID() == X86_AMXTyID; }
   /// Return true if this is a FP type or a vector of FP.
   bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); }
@@ -252,7 +256,7 @@
   /// includes all first-class types except struct and array types.
   bool isSingleValueType() const {
     return isFloatingPointTy() || isX86_MMXTy() || isIntegerTy() ||
-           isPointerTy() || isVectorTy();
+           isPointerTy() || isVectorTy() || isX86_AMXTy();
   /// Return true if the type is an aggregate type. This means it is valid as
@@ -268,8 +272,8 @@
   bool isSized(SmallPtrSetImpl<Type*> *Visited = nullptr) const {
     // If it's a primitive, it is always sized.
     if (getTypeID() == IntegerTyID || isFloatingPointTy() ||
-        getTypeID() == PointerTyID ||
-        getTypeID() == X86_MMXTyID)
+        getTypeID() == PointerTyID || getTypeID() == X86_MMXTyID ||
+        getTypeID() == X86_AMXTyID)
       return true;
     // If it is not something that can have a size (e.g. a function or label),
     // it doesn't have a size.
@@ -405,6 +409,7 @@
   static Type *getFP128Ty(LLVMContext &C);
   static Type *getPPC_FP128Ty(LLVMContext &C);
   static Type *getX86_MMXTy(LLVMContext &C);
+  static Type *getX86_AMXTy(LLVMContext &C);
   static Type *getTokenTy(LLVMContext &C);
   static IntegerType *getIntNTy(LLVMContext &C, unsigned N);
   static IntegerType *getInt1Ty(LLVMContext &C);
@@ -460,6 +465,7 @@
   static PointerType *getFP128PtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getPPC_FP128PtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getX86_MMXPtrTy(LLVMContext &C, unsigned AS = 0);
+  static PointerType *getX86_AMXPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS = 0);
   static PointerType *getInt1PtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getInt8PtrTy(LLVMContext &C, unsigned AS = 0);
Index: llvm/include/llvm/IR/
--- llvm/include/llvm/IR/
+++ llvm/include/llvm/IR/
@@ -5041,6 +5041,22 @@
               Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
                         [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>,
+  // AMX - internal intrinsics
+  def int_x86_tileloadd64_internal :
+              GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
+                        []>;
+  def int_x86_tdpbssd_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
+  def int_x86_tilestored64_internal :
+              GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
+              Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
+                             llvm_i64_ty, llvm_x86amx_ty], []>;
@@ -5055,20 +5071,4 @@
               Intrinsic<[llvm_i8_ty], [], []>;
   def int_x86_senduipi : GCCBuiltin<"__builtin_ia32_senduipi">,
               Intrinsic<[], [llvm_i64_ty], []>;
-// AMX - internal intrinsics
-  def int_x86_tileloadd64_internal :
-              GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
-              Intrinsic<[llvm_v256i32_ty],
-                        [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
-                        []>;
-  def int_x86_tdpbssd_internal :
-              GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
-              Intrinsic<[llvm_v256i32_ty],
-                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
-                         llvm_v256i32_ty, llvm_v256i32_ty,
-                         llvm_v256i32_ty], []>;
-  def int_x86_tilestored64_internal :
-              GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
-              Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
-                             llvm_i64_ty, llvm_v256i32_ty], []>;
Index: llvm/include/llvm/IR/
--- llvm/include/llvm/IR/
+++ llvm/include/llvm/IR/
@@ -255,6 +255,8 @@
 def llvm_x86mmx_ty     : LLVMType<x86mmx>;
 def llvm_ptrx86mmx_ty  : LLVMPointerType<llvm_x86mmx_ty>;         // <1 x i64>*
+def llvm_x86amx_ty     : LLVMType<x86amx>;
 def llvm_v2i1_ty       : LLVMType<v2i1>;     //   2 x i1
 def llvm_v4i1_ty       : LLVMType<v4i1>;     //   4 x i1
 def llvm_v8i1_ty       : LLVMType<v8i1>;     //   8 x i1
Index: llvm/include/llvm/IR/Intrinsics.h
--- llvm/include/llvm/IR/Intrinsics.h
+++ llvm/include/llvm/IR/Intrinsics.h
@@ -125,7 +125,8 @@
-      VecOfBitcastsToInt
+      VecOfBitcastsToInt,
+      AMX
     } Kind;
     union {
Index: llvm/include/llvm/IR/DataLayout.h
--- llvm/include/llvm/IR/DataLayout.h
+++ llvm/include/llvm/IR/DataLayout.h
@@ -690,6 +690,8 @@
   case Type::PPC_FP128TyID:
   case Type::FP128TyID:
     return TypeSize::Fixed(128);
+  case Type::X86_AMXTyID:
+    return TypeSize::Fixed(8192);
   // In memory objects this is always aligned to a higher boundary, but
   // only 80 bits contain information.
   case Type::X86_FP80TyID:
Index: llvm/include/llvm/CodeGen/
--- llvm/include/llvm/CodeGen/
+++ llvm/include/llvm/CodeGen/
@@ -196,6 +196,7 @@
 def exnref : ValueType<0  , 161>;   // WebAssembly's exnref type
 def funcref : ValueType<0  , 162>;   // WebAssembly's funcref type
 def externref : ValueType<0  , 163>;   // WebAssembly's externref type
+def x86amx : ValueType<8192, 164>;   // X86 AMX value
 def token  : ValueType<0  , 248>;   // TokenTy
Index: llvm/include/llvm/Bitcode/LLVMBitCodes.h
--- llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -168,7 +168,8 @@
+  TYPE_CODE_X86_AMX = 24 // X86 AMX
 enum OperandBundleTagCode {
Index: llvm/include/llvm-c/Core.h
--- llvm/include/llvm-c/Core.h
+++ llvm/include/llvm-c/Core.h
@@ -160,6 +160,7 @@
   LLVMVectorTypeKind,    /**< Fixed width SIMD vector type */
   LLVMMetadataTypeKind,  /**< Metadata */
   LLVMX86_MMXTypeKind,   /**< X86 MMX */
+  LLVMX86_AMXTypeKind,   /**< X86 AMX */
   LLVMTokenTypeKind,     /**< Tokens */
   LLVMScalableVectorTypeKind, /**< Scalable SIMD vector type */
   LLVMBFloatTypeKind     /**< 16 bit brain floating point type */
@@ -1493,6 +1494,11 @@
 LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C);
+ * Create a X86 AMX type in a context.
+ */
+LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C);
  * Create a token type in a context.
@@ -1510,6 +1516,7 @@
 LLVMTypeRef LLVMVoidType(void);
 LLVMTypeRef LLVMLabelType(void);
 LLVMTypeRef LLVMX86MMXType(void);
+LLVMTypeRef LLVMX86AMXType(void);
  * @}
Index: clang/test/CodeGen/X86/amx_api.c
--- clang/test/CodeGen/X86/amx_api.c
+++ clang/test/CodeGen/X86/amx_api.c
@@ -11,8 +11,8 @@
 // This is an example code and integration test.
 void test_api(int cond, short row, short col) {
   //CHECK-LABEL: @test_api
-  //CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal
-  //CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal
+  //CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  //CHECK: call x86_amx @llvm.x86.tdpbssd.internal
   //CHECK: call void @llvm.x86.tilestored64.internal
   __tile1024i a = {row, 8};
   __tile1024i b = {8, col};
@@ -33,19 +33,22 @@
 void test_tile_loadd(short row, short col) {
   //CHECK-LABEL: @test_tile_loadd
-  //CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal
+  //CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
   __tile1024i a = {row, col};
   __tile_loadd(&a, buf, STRIDE);
 void test_tile_dpbsud(__tile1024i a, __tile1024i b, __tile1024i c) {
   //CHECK-LABEL: @test_tile_dpbsud
-  //CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal
+  //CHECK: call x86_amx @llvm.x86.tdpbssd.internal
+  //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
   __tile_dpbsud(&c, a, b);
 void test_tile_stored(__tile1024i c) {
   //CHECK-LABEL: @test_tile_stored
-  //CHECK: call void @llvm.x86.tilestored64.internal
+  //CHECK: {{%.*}} = bitcast <256 x i32> {{%.*}} to x86_amx
+  //CHECK-NEXT: call void @llvm.x86.tilestored64.internal
   __tile_stored(buf, STRIDE, c);
cfe-commits mailing list

Reply via email to