[llvm] 516e326 - [X86][AMX] Set stride to the tile's Col when combining amxcast and store into tilestore

Bing1 Yu via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 19 20:55:45 PDT 2023


Author: Bing1 Yu
Date: 2023-06-20T11:55:25+08:00
New Revision: 516e32678d87fea013aa972444b987d95eaef8aa

URL: https://github.com/llvm/llvm-project/commit/516e32678d87fea013aa972444b987d95eaef8aa
DIFF: https://github.com/llvm/llvm-project/commit/516e32678d87fea013aa972444b987d95eaef8aa.diff

LOG: [X86][AMX] Set stride to the tile's Col when combining amxcast and store into tilestore

%tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* %src_ptr, i64 64)
%vec = call <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx %tile)
store <256 x i8> %vec, <256 x i8>* %dst_ptr, align 256
=>
%tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* %src_ptr, i64 64)
%stride = sext i16 32 to i64
call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* %dst_ptr, i64 32, x86_amx %tile)
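
For intuition: llvm.x86.tilestored64.internal(Row, Col, Ptr, Stride, Tile) stores
Row rows of Col bytes each, with consecutive rows placed Stride bytes apart. The
previously hard-coded stride of 64 is only correct when rows are 64 bytes apart;
a dense <256 x i8> destination is 8 rows of 32 contiguous bytes, so the stride
must be 32. A scalar model of the store semantics (illustrative only, not code
from this patch) is:

#include <cstdint>

// Scalar model of llvm.x86.tilestored64.internal(Row, Col, Dst, Stride, Tile):
// row R of the tile (Col bytes, row-major in TileData) is written starting at
// Dst + R * Stride.
static void tilestoreModel(unsigned Row, unsigned Col, char *Dst,
                           uint64_t Stride, const char *TileData) {
  for (unsigned R = 0; R < Row; ++R)
    for (unsigned C = 0; C < Col; ++C)
      Dst[R * Stride + C] = TileData[R * Col + C];
}

// With Row = 8 and Col = 32, Stride = 32 exactly covers the 256-byte vector,
// while Stride = 64 would touch Dst[7 * 64 + 31] = Dst[479], past its end.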

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D153002

Added: 
    

Modified: 
    llvm/include/llvm/IR/IntrinsicsX86.td
    llvm/lib/Target/X86/X86LowerAMXType.cpp
    llvm/test/CodeGen/X86/AMX/amx-combine.ll
    llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index d8d8bc59987e1..ab735daee93dd 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5420,8 +5420,10 @@ let TargetPrefix = "x86" in {
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  // The vector size can be smaller than the AMX register size (1024 bytes).
   def int_x86_cast_vector_to_tile:
       DefaultAttrsIntrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>;
+  // The vector size can be smaller than the AMX register size (1024 bytes).
   def int_x86_cast_tile_to_vector:
       DefaultAttrsIntrinsic<[llvm_anyvector_ty], [llvm_x86amx_ty], [IntrNoMem]>;
 

diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index fa9d367cf7628..0416f0f0d2ec9 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -934,9 +934,8 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
   Value *Row = II->getOperand(0);
   Value *Col = II->getOperand(1);
   IRBuilder<> Builder(ST);
-  // Use the maximum column as stride. It must be the same with load
-  // stride.
-  Value *Stride = Builder.getInt64(64);
+  // Stride should be equal to Col (measured in bytes).
+  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
   Value *I8Ptr =
       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
@@ -962,8 +961,8 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
     return false;
   std::tie(Row, Col) = getShape(II, OpNo);
   IRBuilder<> Builder(LD);
-  // Use the maximun column as stride.
-  Value *Stride = Builder.getInt64(64);
+  // Stride should be equal to Col (measured in bytes).
+  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
   Value *I8Ptr;
 
   // To save compiling time, we create dominator tree when it is really
@@ -1089,8 +1088,14 @@ bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
 
   EraseInst(Vec2TileInsts);
   EraseInst(Tile2VecInsts);
+  LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
+                       "Vec2Tile and Tile2Vec:\n";
+             Func.dump());
   Change |= combineLdSt(LiveCasts);
   EraseInst(LiveCasts);
+  LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine "
+                       "AMXCast and load/store:\n";
+             Func.dump());
 
   // Handle the A->B->A cast, and there is an intervening PHI node.
   for (BasicBlock &BB : Func) {
@@ -1118,6 +1123,9 @@ bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
     Instruction *I = DeadInst.pop_back_val();
     Change |= DCEInstruction(I, DeadInst, TLI);
   }
+  LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after "
+                       "optimizeAMXCastFromPhi:\n";
+             Func.dump());
   return Change;
 }
 

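For reference, a condensed standalone sketch of the call that the patched
combineCastStore now builds (assuming the surrounding pass context; function
and parameter names here are illustrative, not the verbatim source): the
stride operand of tilestored64 is the sign-extended Col value rather than the
constant 64, and combineLoadCast does the same for tileloadd64.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsX86.h"

using namespace llvm;

// Emit the combined tile store: Row/Col come from the intrinsic that defines
// the tile, and the stride is sext(Col) instead of a hard-coded 64.
static void emitCombinedTileStore(IRBuilder<> &Builder, Value *Row, Value *Col,
                                  Value *DstI8Ptr, Value *Tile) {
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {},
                          {Row, Col, DstI8Ptr, Stride, Tile});
}
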
diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine.ll b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
index 0dc1f4b159838..fe21d64eb7a3a 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-combine.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
@@ -97,18 +97,21 @@ define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, pt
 ; CHECK-NEXT:    [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
 ; CHECK-NEXT:    [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
 ; CHECK-NEXT:    [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
 ; CHECK-NEXT:    [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
 ; CHECK-NEXT:    [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
 ; CHECK-NEXT:    [[A_COL_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 2
 ; CHECK-NEXT:    [[A_COL:%.*]] = load i16, ptr [[A_COL_PTR]], align 2
-; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[A_COL]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = udiv i16 [[A_COL]], 4
 ; CHECK-NEXT:    [[A_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 64
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 64)
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i16 [[A_COL]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 [[TMP3]])
 ; CHECK-NEXT:    [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 64)
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[B_ROW]], ptr [[TMP0]], i64 64)
-; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP3]], x86_amx [[TMP2]], x86_amx [[TMP4]])
+; CHECK-NEXT:    [[TMP5:%.*]] = sext i16 [[B_ROW]] to i64
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 [[TMP5]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[B_ROW]], ptr [[TMP0]], i64 [[TMP1]])
+; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP6]], x86_amx [[TMP4]], x86_amx [[TMP7]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -130,8 +133,38 @@ entry:
   ret void
 }
 
+define void @combine_v256i8amcast_with_store(i8* %src_ptr, <256 x i8>* %dst_ptr) {
+; CHECK-LABEL: @combine_v256i8amcast_with_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TILE:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, ptr [[SRC_PTR:%.*]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, ptr [[DST_PTR:%.*]], i64 32, x86_amx [[TILE]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* %src_ptr, i64 64)
+  %vec = call <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx %tile)
+  store <256 x i8> %vec, <256 x i8>* %dst_ptr, align 256
+  ret void
+}
+
+define void @combine_v256i8amcast_with_load(i8* %src_ptr, <256 x i8>* %dst_ptr) {
+; CHECK-LABEL: @combine_v256i8amcast_with_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, ptr [[SRC_PTR:%.*]], i64 32)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, ptr [[DST_PTR:%.*]], i64 32, x86_amx [[TMP0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vec = load <256 x i8>, ptr %src_ptr, align 256
+  %tile = call x86_amx @llvm.x86.cast.vector.to.tile.v256i8(<256 x i8> %vec)
+  call void @llvm.x86.tilestored64.internal(i16 8, i16 32, <256 x i8>* %dst_ptr, i64 32, x86_amx %tile)
+  ret void
+}
+
 declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
+declare x86_amx @llvm.x86.cast.vector.to.tile.v256i8(<256 x i8>)
+declare <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx)
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
 declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)

diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
index dc3b15e7c5503..391727d54a03a 100644
--- a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -74,7 +74,8 @@ define dso_local <256 x i32> @test_amx_bitcast_store(ptr %out, i16 %m, i16 %n, p
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i16 [[M]] to i64
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[TMP0]], i64 [[TMP1]], x86_amx [[T1]])
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP0]], align 1024
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 64, x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = sext i16 [[M]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 [[TMP3]], x86_amx [[T1]])
 ; CHECK-NEXT:    ret <256 x i32> [[TMP2]]
 ;
 entry:
@@ -129,7 +130,8 @@ define dso_local void @__tile_loadd(ptr nocapture %0, ptr %1, i64 %2) local_unna
 ; CHECK-NEXT:    [[TMP8:%.*]] = ashr exact i64 [[TMP7]], 32
 ; CHECK-NEXT:    [[TMP9:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP1:%.*]], i64 [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP0]], i64 0, i32 2
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP11]], x86_amx [[TMP9]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = load i16, ptr %0, align 64
@@ -153,13 +155,17 @@ define dso_local void @__tile_dpbssd(ptr nocapture %0, ptr nocapture readonly by
 ; CHECK-NEXT:    [[TMP8:%.*]] = load i16, ptr [[TMP7]], align 2
 ; CHECK-NEXT:    [[TMP9:%.*]] = udiv i16 [[TMP8]], 4
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP0:%.*]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64)
-; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP1]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP8]], ptr [[TMP12]], i64 64)
-; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP6]], ptr [[TMP14]], i64 64)
-; CHECK-NEXT:    [[TMP16:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP4]], i16 [[TMP6]], i16 [[TMP8]], x86_amx [[TMP11]], x86_amx [[TMP13]], x86_amx [[TMP15]])
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP16]])
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP14:%.*]] = sext i16 [[TMP8]] to i64
+; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP8]], ptr [[TMP13]], i64 [[TMP14]])
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP17:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT:    [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP6]], ptr [[TMP16]], i64 [[TMP17]])
+; CHECK-NEXT:    [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP4]], i16 [[TMP6]], i16 [[TMP8]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]])
+; CHECK-NEXT:    [[TMP20:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP20]], x86_amx [[TMP19]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = load i16, ptr %1, align 64
@@ -185,11 +191,15 @@ define dso_local void @__tile_dpbssd(ptr nocapture %0, ptr nocapture readonly by
 define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, ptr %pb) {
 ; CHECK-LABEL: @__tile_dpbsud(
 ; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N]], ptr [[PB:%.*]], i64 [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 [[TMP6]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP8]], x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -207,11 +217,15 @@ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, p
 define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, ptr %pb) {
 ; CHECK-LABEL: @__tile_dpbusd(
 ; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N]], ptr [[PB:%.*]], i64 [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 [[TMP6]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP8]], x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -229,11 +243,15 @@ define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, p
 define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, ptr %pb) {
 ; CHECK-LABEL: @__tile_dpbuud(
 ; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N]], ptr [[PB:%.*]], i64 [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 [[TMP6]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP8]], x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -251,11 +269,15 @@ define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, p
 define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, ptr %pc, ptr %pa, ptr %pb) {
 ; CHECK-LABEL: @__tile_dpbf16ps(
 ; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
-; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i16 [[K]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], ptr [[PA:%.*]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i16 [[N:%.*]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N]], ptr [[PB:%.*]], i64 [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 [[TMP6]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP8]], x86_amx [[T6]])
 ; CHECK-NEXT:    ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -276,10 +298,11 @@ define dso_local void @__tile_stored(ptr %0, i64 %1, ptr nocapture readonly byva
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], ptr [[TMP2]], i64 0, i32 1
 ; CHECK-NEXT:    [[TMP6:%.*]] = load i16, ptr [[TMP5]], align 2
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP7]], i64 64)
-; CHECK-NEXT:    [[TMP9:%.*]] = shl i64 [[TMP1:%.*]], 32
-; CHECK-NEXT:    [[TMP10:%.*]] = ashr exact i64 [[TMP9]], 32
-; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP0:%.*]], i64 [[TMP10]], x86_amx [[TMP8]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP7]], i64 [[TMP8]])
+; CHECK-NEXT:    [[TMP10:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP11:%.*]] = ashr exact i64 [[TMP10]], 32
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP0:%.*]], i64 [[TMP11]], x86_amx [[TMP9]])
 ; CHECK-NEXT:    ret void
 ;
   %4 = load i16, ptr %2, align 64


        

