[llvm] 942ec5c - [X86][AMX] combine tile cast and load/store instruction.

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 28 00:08:58 PDT 2022


Author: Luo, Yuanke
Date: 2022-04-28T14:55:21+08:00
New Revision: 942ec5c36d922b29d0e7200082b533018d42421a

URL: https://github.com/llvm/llvm-project/commit/942ec5c36d922b29d0e7200082b533018d42421a
DIFF: https://github.com/llvm/llvm-project/commit/942ec5c36d922b29d0e7200082b533018d42421a.diff

LOG: [X86][AMX] combine tile cast and load/store instruction.

The `llvm.x86.cast.tile.to.vector` intrinsic is lowered to
`llvm.x86.tilestored64.internal` followed by a `load <256 x i32>`, and
`llvm.x86.cast.vector.to.tile` is lowered to a `store <256 x i32>`
followed by `llvm.x86.tileloadd64.internal`. When the result of
`llvm.x86.cast.tile.to.vector` is consumed by a `store <256 x i32>`, or
the operand of `llvm.x86.cast.vector.to.tile` comes from a
`load <256 x i32>`, the cast and the memory access can be folded into a
single `llvm.x86.tilestored64.internal` or
`llvm.x86.tileloadd64.internal`, respectively.
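
For illustration only, a rough sketch of the two patterns (the %row, %col,
%p, %p8 and %t names are placeholders, with %p8 standing for %p bitcast to
i8*; they are not taken from this patch):

  ; cast-to-vector whose result feeds a vector store
  %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)
  store <256 x i32> %v, <256 x i32>* %p, align 64
  ; --> folded into a single tile store using the maximum stride of 64
  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p8, i64 64, x86_amx %t)

  ; vector load whose result feeds a cast-to-tile
  %v2 = load <256 x i32>, <256 x i32>* %p, align 64
  %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v2)
  ; --> folded into a single tile load using the maximum stride of 64
  %t2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %p8, i64 64)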

Differential Revision: https://reviews.llvm.org/D124378

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86LowerAMXType.cpp
    llvm/test/CodeGen/X86/AMX/amx-combine.ll
    llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 95cb783634b2b..0ad3e6c9d81cd 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,7 +74,7 @@ static bool isAMXCast(Instruction *II) {
          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
 }
 
-static bool isAMXIntrinsic(User *I) {
+static bool isAMXIntrinsic(Value *I) {
   auto *II = dyn_cast<IntrinsicInst>(I);
   if (!II)
     return false;
@@ -908,6 +908,99 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(
   return true;
 }
 
+// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
+// store <256 x i32> %43, <256 x i32>* %p, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
+//                                           i64 64, x86_amx %42)
+static void combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
+  Value *Tile = Cast->getOperand(0);
+  // TODO: If it is a cast intrinsic or phi node, we can propagate the
+  // shape information through the def-use chain.
+  if (!isAMXIntrinsic(Tile))
+    return;
+  auto *II = cast<IntrinsicInst>(Tile);
+  // Tile is the output of an AMX intrinsic. The first operand of the
+  // intrinsic is the row; the second operand is the column.
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as the stride. It must match the stride of the
+  // corresponding tile load.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+}
+
+// %65 = load <256 x i32>, <256 x i32>* %p, align 64
+// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
+// -->
+// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+//                                                   i8* %p, i64 64)
+static void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+  Value *Row = nullptr, *Col = nullptr;
+  Use &U = *(Cast->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  auto *II = cast<IntrinsicInst>(U.getUser());
+  // TODO: If it is a cast intrinsic or phi node, we can propagate the
+  // shape information through the def-use chain.
+  if (!isAMXIntrinsic(II))
+    return;
+  std::tie(Row, Col) = getShape(II, OpNo);
+  IRBuilder<> Builder(LD);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  Cast->replaceAllUsesWith(NewInst);
+}
+
+static bool combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
+  bool Change = false;
+  for (auto *Cast : Casts) {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cast);
+    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
+    // store <256 x i32> %43, <256 x i32>* %p, align 64
+    // -->
+    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
+    //                                           i64 64, x86_amx %42)
+    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
+      SmallVector<Instruction *, 2> DeadStores;
+      for (User *U : Cast->users()) {
+        StoreInst *Store = dyn_cast<StoreInst>(U);
+        if (!Store)
+          continue;
+        combineCastStore(cast<IntrinsicInst>(Cast), Store);
+        DeadStores.push_back(Store);
+        Change = true;
+      }
+      for (auto *Store : DeadStores)
+        Store->eraseFromParent();
+    } else { // x86_cast_vector_to_tile
+      SmallVector<Instruction *, 2> DeadLoads;
+      LoadInst *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
+      if (!Load || !Load->hasOneUse())
+        continue;
+      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
+      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
+      // -->
+      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+      //                                                   i8* %p, i64 64)
+      combineLoadCast(cast<IntrinsicInst>(Cast), Load);
+      // Set the operand to null so that the load instruction can be erased.
+      Cast->setOperand(0, nullptr);
+      Load->eraseFromParent();
+    }
+  }
+  return Change;
+}
+
 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
   bool Change = false;
   // Collect tile cast instruction.
@@ -949,17 +1042,22 @@ bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
 
+  SmallVector<Instruction *, 8> LiveCasts;
   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
     for (auto *Inst : Insts) {
       if (Inst->use_empty()) {
         Inst->eraseFromParent();
         Change = true;
+      } else {
+        LiveCasts.push_back(Inst);
       }
     }
   };
 
   EraseInst(Vec2TileInsts);
   EraseInst(Tile2VecInsts);
+  Change |= combineLdSt(LiveCasts);
+  EraseInst(LiveCasts);
 
   // Handle the A->B->A cast, and there is an intervening PHI node.
   for (BasicBlock &BB : Func) {

diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine.ll b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
index 2fd095f73cb92..6f349cd5a8bcf 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-combine.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
@@ -3,12 +3,9 @@
 
 define void @combine_store(<256 x i32> *%p) {
 ; CHECK-LABEL: @combine_store(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[TMP2]], i64 64, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP3]], <256 x i32>* [[P:%.*]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[P:%.*]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[TMP1]], i64 64, x86_amx [[T1]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
@@ -24,7 +21,8 @@ define <256 x i32> @combine_store_2user(<256 x i32> *%p) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[TMP2]], i64 64, x86_amx [[T1]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP3]], <256 x i32>* [[P:%.*]], align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[P:%.*]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[TMP4]], i64 64, x86_amx [[T1]])
 ; CHECK-NEXT:    ret <256 x i32> [[TMP3]]
 ;
   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
@@ -35,15 +33,27 @@ define <256 x i32> @combine_store_2user(<256 x i32> *%p) {
 
 define void @combine_load(<256 x i32> *%p, i8 *%p2) {
 ; CHECK-LABEL: @combine_load(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[P:%.*]], align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T1]], <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* [[TMP2]], i64 64)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[P2:%.*]], i64 64, x86_amx [[TMP3]])
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[P:%.*]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* [[TMP1]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    ret void
+;
+  %t1 = load <256 x i32>, <256 x i32>* %p, align 64
+  %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
+  call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* %p2, i64 64, x86_amx %t2)
+  ret void
+}
+
+define void @combine_cast_across_store(<256 x i32> *%p, i8 *%p2) {
+; CHECK-LABEL: @combine_cast_across_store(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[P:%.*]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* [[TMP1]], i64 64)
+; CHECK-NEXT:    store <256 x i32> zeroinitializer, <256 x i32>* [[P]], align 64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* [[P2:%.*]], i64 64, x86_amx [[TMP2]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = load <256 x i32>, <256 x i32>* %p, align 64
+  store <256 x i32> zeroinitializer, <256 x i32>* %p, align 64
   %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1)
   call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* %p2, i64 64, x86_amx %t2)
   ret void

diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
index 6080344f358d8..129d515ff078d 100644
--- a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -77,7 +77,8 @@ define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m,
 ; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[M]] to i64
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP3]], <256 x i32>* [[OUT:%.*]], align 1024
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[OUT:%.*]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP4]], i64 64, x86_amx [[T1]])
 ; CHECK-NEXT:    ret <256 x i32> [[TMP3]]
 ;
 entry:
@@ -127,20 +128,16 @@ entry:
 
 define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr {
 ; CHECK-LABEL: @__tile_loadd(
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
-; CHECK-NEXT:    [[TMP9:%.*]] = shl i64 [[TMP2:%.*]], 32
-; CHECK-NEXT:    [[TMP10:%.*]] = ashr exact i64 [[TMP9]], 32
-; CHECK-NEXT:    [[TMP11:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP1:%.*]], i64 [[TMP10]])
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[TMP8]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP12]], i64 [[TMP13]], x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP14:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
-; CHECK-NEXT:    store <256 x i32> [[TMP14]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32
+; CHECK-NEXT:    [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32
+; CHECK-NEXT:    [[TMP10:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, x86_amx [[TMP10]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
@@ -158,41 +155,25 @@ define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i6
 
 define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
 ; CHECK-LABEL: @__tile_dpbssd(
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP6:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP7:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 64
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP11:%.*]] = load i16, i16* [[TMP10]], align 2
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP13:%.*]] = load i16, i16* [[TMP12]], align 2
-; CHECK-NEXT:    [[TMP14:%.*]] = udiv i16 [[TMP13]], 4
-; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP16:%.*]] = load <256 x i32>, <256 x i32>* [[TMP15]], align 64
-; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP7]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[TMP16]], <256 x i32>* [[TMP7]], align 1024
-; CHECK-NEXT:    [[TMP18:%.*]] = sext i16 [[TMP11]] to i64
-; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP17]], i64 [[TMP18]])
-; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP21:%.*]] = load <256 x i32>, <256 x i32>* [[TMP20]], align 64
-; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <256 x i32>* [[TMP6]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[TMP21]], <256 x i32>* [[TMP6]], align 1024
-; CHECK-NEXT:    [[TMP23:%.*]] = sext i16 [[TMP13]] to i64
-; CHECK-NEXT:    [[TMP24:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP13]], i8* [[TMP22]], i64 [[TMP23]])
-; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP26:%.*]] = load <256 x i32>, <256 x i32>* [[TMP25]], align 64
-; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <256 x i32>* [[TMP5]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[TMP26]], <256 x i32>* [[TMP5]], align 1024
-; CHECK-NEXT:    [[TMP28:%.*]] = sext i16 [[TMP11]] to i64
-; CHECK-NEXT:    [[TMP29:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP14]], i16 [[TMP11]], i8* [[TMP27]], i64 [[TMP28]])
-; CHECK-NEXT:    [[TMP30:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP9]], i16 [[TMP11]], i16 [[TMP13]], x86_amx [[TMP19]], x86_amx [[TMP24]], x86_amx [[TMP29]])
-; CHECK-NEXT:    [[TMP31:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    [[TMP32:%.*]] = sext i16 [[TMP11]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP31]], i64 [[TMP32]], x86_amx [[TMP30]])
-; CHECK-NEXT:    [[TMP33:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP33]], <256 x i32>* [[TMP15]], align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2
+; CHECK-NEXT:    [[TMP10:%.*]] = udiv i16 [[TMP9]], 4
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64)
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP14]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP15]], i64 64)
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP18:%.*]] = bitcast <256 x i32>* [[TMP17]] to i8*
+; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP10]], i16 [[TMP7]], i8* [[TMP18]], i64 64)
+; CHECK-NEXT:    [[TMP20:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP13]], x86_amx [[TMP16]], x86_amx [[TMP19]])
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP21]], i64 64, x86_amx [[TMP20]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
@@ -218,32 +199,16 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
 
 define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbsud(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
-; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
-; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
-; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
-; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -260,32 +225,16 @@ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbusd(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
-; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
-; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
-; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
-; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -302,32 +251,16 @@ define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbuud(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
-; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
-; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
-; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
-; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -344,32 +277,16 @@ define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <
 
 define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
 ; CHECK-LABEL: @__tile_dpbf16ps(
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64
-; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP7:%.*]] = sext i16 [[K]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]])
-; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]])
-; CHECK-NEXT:    [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP13:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]])
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    [[TMP16:%.*]] = sext i16 [[N]] to i64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]])
-; CHECK-NEXT:    [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
-; CHECK-NEXT:    store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64)
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
@@ -386,20 +303,16 @@ define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc,
 
 define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
 ; CHECK-LABEL: @__tile_stored(
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP10:%.*]] = load <256 x i32>, <256 x i32>* [[TMP9]], align 64
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[TMP10]], <256 x i32>* [[TMP4]], align 1024
-; CHECK-NEXT:    [[TMP12:%.*]] = sext i16 [[TMP8]] to i64
-; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP11]], i64 [[TMP12]])
-; CHECK-NEXT:    [[TMP14:%.*]] = shl i64 [[TMP1:%.*]], 32
-; CHECK-NEXT:    [[TMP15:%.*]] = ashr exact i64 [[TMP14]], 32
-; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP0:%.*]], i64 [[TMP15]], x86_amx [[TMP13]])
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8*
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], x86_amx [[TMP10]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0


        

