[PATCH] D153002: [X86][AMX] set Stride to Tile's Col when doing combine amxcast and store into tilestore: %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* %src_ptr, i64 64) %vec = call <256 x i8> @llvm.x86.cast.tile.to.vector.v256i8(x86_amx...
Bing Yu via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 15 01:29:32 PDT 2023
yubing created this revision.
Herald added subscribers: pengfei, hiraditya.
Herald added a project: All.
yubing requested review of this revision.
Herald added a project: LLVM.
Herald added a subscriber: llvm-commits.
...%tile) store <256 x i8> %vec, <256 x i8>* %dst_ptr, align 256 => %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* %src_ptr, i64 64) %stride = sext i16 32 to i64 call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* %dst_ptr, i64 %stride, x86_amx %tile)
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D153002
Files:
llvm/lib/Target/X86/X86LowerAMXType.cpp
llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
Index: llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
===================================================================
--- llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -74,7 +74,8 @@
; CHECK-NEXT: [[TMP1:%.*]] = sext i16 [[M]] to i64
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[TMP0]], i64 [[TMP1]], x86_amx [[T1]])
; CHECK-NEXT: [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP0]], align 1024
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 64, x86_amx [[T1]])
+; CHECK-NEXT: [[TMP3:%.*]] = sext i16 [[M]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 [[TMP3]], x86_amx [[T1]])
; CHECK-NEXT: ret <256 x i32> [[TMP2]]
;
entry:
@@ -129,7 +130,8 @@
; CHECK-NEXT: [[TMP8:%.*]] = ashr exact i64 [[TMP7]], 32
; CHECK-NEXT: [[TMP9:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP1:%.*]], i64 [[TMP8]])
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP0]], i64 0, i32 2
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP9]])
+; CHECK-NEXT: [[TMP11:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP11]], x86_amx [[TMP9]])
; CHECK-NEXT: ret void
;
%4 = load i16, ptr %0, align 64
@@ -159,7 +161,8 @@
; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP2]], i64 0, i32 2
; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP6]], ptr [[TMP14]], i64 64)
; CHECK-NEXT: [[TMP16:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP4]], i16 [[TMP6]], i16 [[TMP8]], x86_amx [[TMP11]], x86_amx [[TMP13]], x86_amx [[TMP15]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP16]])
+; CHECK-NEXT: [[TMP17:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP17]], x86_amx [[TMP16]])
; CHECK-NEXT: ret void
;
%4 = load i16, ptr %1, align 64
@@ -189,7 +192,8 @@
; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, ptr %pa, align 64
@@ -211,7 +215,8 @@
; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, ptr %pa, align 64
@@ -233,7 +238,8 @@
; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, ptr %pa, align 64
@@ -255,7 +261,8 @@
; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
; CHECK-NEXT: ret void
;
%t0 = load <256 x i32>, ptr %pa, align 64
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -934,9 +934,8 @@
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
IRBuilder<> Builder(ST);
- // Use the maximum column as stride. It must be the same with load
- // stride.
- Value *Stride = Builder.getInt64(64);
+ // Stride should be equal to Col (measured in bytes).
+ Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
Value *I8Ptr =
Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D153002.531642.patch
Type: text/x-patch
Size: 6386 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230615/356b36ae/attachment.bin>
More information about the llvm-commits
mailing list