[llvm] [X86][AMX] Move Stride close to its use (PR #174095)
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 31 06:21:24 PST 2025
https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/174095
Fixes: #174066
From ee38765b36b1b5e9195b9a43207fb1740e007d81 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Wed, 31 Dec 2025 22:08:01 +0800
Subject: [PATCH] [X86][AMX] Move Stride close to its use
Fixes: #174066
---
llvm/lib/Target/X86/X86LowerAMXType.cpp | 4 +-
llvm/test/CodeGen/X86/AMX/amx-combine.ll | 2 +-
.../X86/AMX/lat-transform-amx-bitcast.ll | 39 +++++++++++++++++++
3 files changed, 42 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index d93bcd31c5721..505cddeceb9bc 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -994,8 +994,6 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
return false;
std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(LD);
- // Stride should be equal to col(measured by bytes)
- Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
Value *I8Ptr;
// To save compiling time, we create doninator tree when it is really
@@ -1015,6 +1013,8 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
} else {
I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
}
+ // Stride should be equal to col(measured by bytes)
+ Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst =
diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine.ll b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
index 72e072dd15761..b873a9463f2f5 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-combine.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine.ll
@@ -98,7 +98,6 @@ define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, pt
; CHECK-NEXT: [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
; CHECK-NEXT: [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
; CHECK-NEXT: [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
-; CHECK-NEXT: [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
; CHECK-NEXT: [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
; CHECK-NEXT: store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
; CHECK-NEXT: [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
@@ -111,6 +110,7 @@ define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, pt
; CHECK-NEXT: [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[B_ROW]] to i64
; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 [[TMP5]])
+; CHECK-NEXT: [[TMP1:%.*]] = sext i16 [[B_ROW]] to i64
; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[B_ROW]], ptr [[TMP0]], i64 [[TMP1]])
; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP6]], x86_amx [[TMP4]], x86_amx [[TMP7]])
; CHECK-NEXT: ret void
diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
index 0b419bb8573d5..b819873608ce3 100644
--- a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -354,6 +354,45 @@ exit:
ret void
}
+@b = dso_local local_unnamed_addr global i16 0, align 2
+@c = dso_local local_unnamed_addr global i16 0, align 2
+@d = dso_local local_unnamed_addr global i16 0, align 2
+@e = dso_local local_unnamed_addr global <256 x i32> zeroinitializer, align 1024
+@f = dso_local local_unnamed_addr global <256 x i32> zeroinitializer, align 1024
+
+define void @pr166653(ptr %0) {
+; CHECK-LABEL: @pr166653(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT: [[H:%.*]] = load <256 x i32>, ptr [[TMP0:%.*]], align 1024
+; CHECK-NEXT: store <256 x i32> [[H]], ptr [[TMP1]], align 1024
+; CHECK-NEXT: [[TMP2:%.*]] = load i16, ptr @b, align 2
+; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr @c, align 2
+; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr @d, align 2
+; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[TMP4]], 4
+; CHECK-NEXT: [[TMP6:%.*]] = sext i16 [[TMP3]] to i64
+; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[TMP3]], ptr @e, i64 [[TMP6]])
+; CHECK-NEXT: [[TMP8:%.*]] = sext i16 [[TMP4]] to i64
+; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP2]], i16 [[TMP4]], ptr @f, i64 [[TMP8]])
+; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[TMP3]] to i64
+; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP3]], ptr [[TMP1]], i64 [[TMP10]])
+; CHECK-NEXT: [[TMP12:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP2]], i16 [[TMP3]], i16 [[TMP4]], x86_amx [[TMP7]], x86_amx [[TMP9]], x86_amx [[TMP11]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %h = load <256 x i32>, ptr %0, align 1024
+ %1 = load i16, ptr @b, align 2
+ %2 = load i16, ptr @c, align 2
+ %3 = load i16, ptr @d, align 2
+ %4 = load <256 x i32>, ptr @e, align 1024
+ %5 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %4)
+ %6 = load <256 x i32>, ptr @f, align 1024
+ %7 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %6)
+ %8 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %h)
+ %9 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %3, x86_amx %5, x86_amx %7, x86_amx %8)
+ ret void
+}
+
declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
More information about the llvm-commits
mailing list