[llvm-branch-commits] [llvm] AMDGPU: Handle folding vector splats of split f64 inline immediates (PR #140878)
    Matt Arsenault via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Tue Jun 17 18:53:18 PDT 2025
    
    
  
https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/140878
>From 5d5a54cff58c1096973d3a9c28f728ca0afe3889 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 19 May 2025 21:51:06 +0200
Subject: [PATCH] AMDGPU: Handle folding vector splats of split f64
 inline immediates
Recognize a reg_sequence whose 32-bit elements together form a 64-bit
splat value. This enables folding f64 constants into mfma operands.
---
 llvm/lib/Target/AMDGPU/SIFoldOperands.cpp     | 103 ++++++++++++------
 .../CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll |  41 +------
 2 files changed, 76 insertions(+), 68 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
index 565abad9b0366..8121098a97bd1 100644
--- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -966,15 +966,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -986,38 +986,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fall back to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
@@ -1031,17 +1068,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1119,7 +1157,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1130,10 +1168,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
index 5d5dc01439fe4..a9cffd6e1c943 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
@@ -165,19 +165,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
    
    
More information about the llvm-branch-commits
mailing list