[llvm] [X86][AMX] Move Stride close to its use (PR #174095)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 31 06:21:24 PST 2025


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/174095

Fixes: #174066

>From ee38765b36b1b5e9195b9a43207fb1740e007d81 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Wed, 31 Dec 2025 22:08:01 +0800
Subject: [PATCH] [X86][AMX] Move Stride close to its use

Fixes: #174066
---
 llvm/lib/Target/X86/X86LowerAMXType.cpp       |  4 +-
 llvm/test/CodeGen/X86/AMX/amx-combine.ll      |  2 +-
 .../X86/AMX/lat-transform-amx-bitcast.ll      | 39 +++++++++++++++++++
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index d93bcd31c5721..505cddeceb9bc 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -994,8 +994,6 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
     return false;
   std::tie(Row, Col) = getShape(II, OpNo);
   IRBuilder<> Builder(LD);
-  // Stride should be equal to col(measured by bytes)
-  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
   Value *I8Ptr;
 
   // To save compiling time, we create doninator tree when it is really
@@ -1015,6 +1013,8 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
   } else {
     I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
   }
+  // Stride should be equal to col(measured by bytes)
+  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
 
   Value *NewInst =
diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine.ll b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
index 72e072dd15761..b873a9463f2f5 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-combine.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
@@ -98,7 +98,6 @@ define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, pt
 ; CHECK-NEXT:    [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
 ; CHECK-NEXT:    [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
 ; CHECK-NEXT:    [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
-; CHECK-NEXT:    [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
 ; CHECK-NEXT:    [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
 ; CHECK-NEXT:    [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
@@ -111,6 +110,7 @@ define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, pt
 ; CHECK-NEXT:    [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
 ; CHECK-NEXT:    [[TMP5:%.*]] = sext i16 [[B_ROW]] to i64
 ; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
 ; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[B_ROW]], ptr [[TMP0]], i64 [[TMP1]])
 ; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP6]], x86_amx [[TMP4]], x86_amx [[TMP7]])
 ; CHECK-NEXT:    ret void
diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
index 0b419bb8573d5..b819873608ce3 100644
--- a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -354,6 +354,45 @@ exit:
   ret void
 }
 
+@b = dso_local local_unnamed_addr global i16 0, align 2 ; row count of the @e and @f tiles
+@c = dso_local local_unnamed_addr global i16 0, align 2 ; column bytes (and stride) of the @e and %h tiles
+@d = dso_local local_unnamed_addr global i16 0, align 2 ; column bytes of the @f tile; @d/4 is the %h tile's row count
+@e = dso_local local_unnamed_addr global <256 x i32> zeroinitializer, align 1024 ; vector loaded and cast to x86_amx
+@f = dso_local local_unnamed_addr global <256 x i32> zeroinitializer, align 1024 ; vector loaded and cast to x86_amx
+
+define void @pr166653(ptr %0) { ; Regression test: %h is loaded before its tile column (%2) is loaded, so the stride sext for %h's tileloadd must be emitted next to that tileloadd, not at the %h load.
+; CHECK-LABEL: @pr166653(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[H:%.*]] = load <256 x i32>, ptr [[TMP0:%.*]], align 1024
+; CHECK-NEXT:    store <256 x i32> [[H]], ptr [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = load i16, ptr @b, align 2
+; CHECK-NEXT:    [[TMP3:%.*]] = load i16, ptr @c, align 2
+; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr @d, align 2
+; CHECK-NEXT:    [[TMP5:%.*]] = udiv i16 [[TMP4]], 4
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i16 [[TMP3]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[TMP3]], ptr @e, i64 [[TMP6]])
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i16 [[TMP4]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[TMP4]], ptr @f, i64 [[TMP8]])
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i16 [[TMP3]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP3]], ptr [[TMP1]], i64 [[TMP10]])
+; CHECK-NEXT:    [[TMP12:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP2]], i16 [[TMP3]], i16 [[TMP4]], x86_amx [[TMP7]], x86_amx [[TMP9]], x86_amx [[TMP11]])
+; CHECK-NEXT:    ret void
+;
+entry: ; NOTE(review): the function is named pr166653 but the commit message cites #174066; confirm the intended issue number.
+  %h = load <256 x i32>, ptr %0, align 1024
+  %1 = load i16, ptr @b, align 2
+  %2 = load i16, ptr @c, align 2
+  %3 = load i16, ptr @d, align 2
+  %4 = load <256 x i32>, ptr @e, align 1024
+  %5 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %4)
+  %6 = load <256 x i32>, ptr @f, align 1024
+  %7 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %6)
+  %8 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %h)
+  %9 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %3, x86_amx %5, x86_amx %7, x86_amx %8)
+  ret void
+}
+
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)



More information about the llvm-commits mailing list