[llvm] [X86][AMX] Combine constant zero vector and AMX cast to tilezero (PR #92384)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Thu May 16 04:21:46 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/92384

This problem was found while investigating #91207: an AMX cast of a constant zero vector was not being combined into tilezero.

>From 0a33db0fae8db4f3f430b26909583d8e2188f2df Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Thu, 16 May 2024 19:17:33 +0800
Subject: [PATCH] [X86][AMX] Combine constant zero vector and AMX cast to
 tilezero

Found this problem when investigating #91207
---
 llvm/lib/Target/X86/X86LowerAMXType.cpp     | 31 ++++++++
 llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll | 86 +++++----------------
 2 files changed, 49 insertions(+), 68 deletions(-)

diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index b69058787a4e2..539a9b2e4b6d2 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -709,6 +709,7 @@ class X86LowerAMXCast {
   X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
+  bool combineTilezero(IntrinsicInst *Cast);
   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
   bool combineAMXcast(TargetLibraryInfo *TLI);
   bool transformAMXCast(IntrinsicInst *AMXCast);
@@ -988,6 +989,27 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
   return EraseLoad;
 }
 
+// Fold a cast of a constant zero vector into tilezero; true if rewritten:
+//   %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
+//   --> %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
+bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
+  // Shape comes from the first user, which must exist and be an intrinsic.
+  if (Cast->use_empty())
+    return false;
+  Use &U = *(Cast->use_begin());
+  auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+  if (!II || !isAMXIntrinsic(II))
+    return false;
+
+  auto [Row, Col] = getShape(II, U.getOperandNo());
+  std::array<Value *, 2> Args = {Row, Col};
+  IRBuilder<> Builder(Cast);
+  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal,
+                                           std::nullopt, Args);
+  Cast->replaceAllUsesWith(NewInst);
+  return true;
+}
+
 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
   bool Change = false;
   for (auto *Cast : Casts) {
@@ -1011,6 +1033,14 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
       for (auto *Store : DeadStores)
         Store->eraseFromParent();
     } else { // x86_cast_vector_to_tile
+      //  %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
+      //  -->
+      //  %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
+      if (dyn_cast<ConstantAggregateZero>(Cast->getOperand(0))) {
+        Change |= combineTilezero(cast<IntrinsicInst>(Cast));
+        continue;
+      }
+
       SmallVector<Instruction *, 2> DeadLoads;
       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
       if (!Load || !Load->hasOneUse())
@@ -1024,6 +1054,7 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
         // Set the operand is null so that load instruction can be erased.
         Cast->setOperand(0, nullptr);
         Load->eraseFromParent();
+        Change = true;
       }
     }
   }
diff --git a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
index 7511e5953dac1..5efdba81e76be 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
@@ -52,26 +52,13 @@ declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_
 declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)
 
-define void @PR90954(ptr %0, ptr %1, i32 %2) {
+define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind {
 ; CHECK-LABEL: PR90954:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    pushq %rbp
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    .cfi_offset %rbp, -16
-; CHECK-NEXT:    movq %rsp, %rbp
-; CHECK-NEXT:    .cfi_def_cfa_register %rbp
-; CHECK-NEXT:    pushq %r15
 ; CHECK-NEXT:    pushq %r14
-; CHECK-NEXT:    pushq %r13
-; CHECK-NEXT:    pushq %r12
 ; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    andq $-1024, %rsp # imm = 0xFC00
-; CHECK-NEXT:    subq $5120, %rsp # imm = 0x1400
-; CHECK-NEXT:    .cfi_offset %rbx, -56
-; CHECK-NEXT:    .cfi_offset %r12, -48
-; CHECK-NEXT:    .cfi_offset %r13, -40
-; CHECK-NEXT:    .cfi_offset %r14, -32
-; CHECK-NEXT:    .cfi_offset %r15, -24
+; CHECK-NEXT:    subq $2912, %rsp # imm = 0xB60
 ; CHECK-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; CHECK-NEXT:    vmovups %zmm0, {{[0-9]+}}(%rsp)
 ; CHECK-NEXT:    movb $1, {{[0-9]+}}(%rsp)
@@ -87,29 +74,26 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) {
 ; CHECK-NEXT:    movw $64, %cx
 ; CHECK-NEXT:    movw $16, %di
 ; CHECK-NEXT:    movb $1, %r8b
-; CHECK-NEXT:    movl $64, %r9d
-; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %r10
-; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %r11
-; CHECK-NEXT:    xorl %ebx, %ebx
-; CHECK-NEXT:    xorl %r14d, %r14d
+; CHECK-NEXT:    xorl %r9d, %r9d
+; CHECK-NEXT:    xorl %r10d, %r10d
 ; CHECK-NEXT:    jmp .LBB1_1
 ; CHECK-NEXT:    .p2align 4, 0x90
 ; CHECK-NEXT:  .LBB1_5: # in Loop: Header=BB1_1 Depth=1
-; CHECK-NEXT:    incq %r14
-; CHECK-NEXT:    addl %edx, %ebx
+; CHECK-NEXT:    incq %r10
+; CHECK-NEXT:    addl %edx, %r9d
 ; CHECK-NEXT:  .LBB1_1: # =>This Loop Header: Depth=1
 ; CHECK-NEXT:    # Child Loop BB1_2 Depth 2
-; CHECK-NEXT:    movslq %ebx, %r15
-; CHECK-NEXT:    leaq (%rsi,%r15,4), %r15
-; CHECK-NEXT:    xorl %r12d, %r12d
-; CHECK-NEXT:    xorl %r13d, %r13d
+; CHECK-NEXT:    movslq %r9d, %r11
+; CHECK-NEXT:    leaq (%rsi,%r11,4), %r11
+; CHECK-NEXT:    xorl %ebx, %ebx
+; CHECK-NEXT:    xorl %r14d, %r14d
 ; CHECK-NEXT:    jmp .LBB1_2
 ; CHECK-NEXT:    .p2align 4, 0x90
 ; CHECK-NEXT:  .LBB1_4: # in Loop: Header=BB1_2 Depth=2
-; CHECK-NEXT:    tilestored %tmm1, (%r15,%rax)
-; CHECK-NEXT:    incq %r13
-; CHECK-NEXT:    addq $64, %r15
-; CHECK-NEXT:    decq %r12
+; CHECK-NEXT:    tilestored %tmm1, (%r11,%rax)
+; CHECK-NEXT:    incq %r14
+; CHECK-NEXT:    addq $64, %r11
+; CHECK-NEXT:    decq %rbx
 ; CHECK-NEXT:    je .LBB1_5
 ; CHECK-NEXT:  .LBB1_2: # Parent Loop BB1_1 Depth=1
 ; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
@@ -118,46 +102,12 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) {
 ; CHECK-NEXT:    testb %r8b, %r8b
 ; CHECK-NEXT:    jne .LBB1_4
 ; CHECK-NEXT:  # %bb.3: # in Loop: Header=BB1_2 Depth=2
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    tileloadd (%r10,%r9), %tmm1
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    vmovaps %zmm0, {{[0-9]+}}(%rsp)
-; CHECK-NEXT:    tileloadd (%r11,%r9), %tmm2
+; CHECK-NEXT:    tilezero %tmm1
+; CHECK-NEXT:    tilezero %tmm2
 ; CHECK-NEXT:    tdpbf16ps %tmm2, %tmm1, %tmm0
-; CHECK-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT:    movabsq $64, %rax
-; CHECK-NEXT:    tilestored %tmm0, 3072(%rsp,%rax) # 1024-byte Folded Spill
+; CHECK-NEXT:    movabsq $64, %rbp
+; CHECK-NEXT:    tilestored %tmm0, 896(%rsp,%rbp) # 1024-byte Folded Spill
 ; CHECK-NEXT:    tileloadd {{[-0-9]+}}(%r{{[sb]}}p), %tmm1 # 1024-byte Folded Reload
-; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
 ; CHECK-NEXT:    jmp .LBB1_4
   %4 = shl i32 %2, 4
   %5 = icmp eq i64 0, 0



More information about the llvm-commits mailing list