[llvm-branch-commits] [llvm] AMDGPU: Handle folding vector splats of split f64 inline immediates (PR #140878)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 21 03:34:27 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
Author: Matt Arsenault (arsenm)
Changes
Recognize a reg_sequence with 32-bit elements that produces a 64-bit
splat value. This enables folding f64 constants into MFMA operands.
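As a rough, hypothetical illustration (the function name and arguments below are illustrative, not taken from the PR's test file), this is the kind of pattern the change targets: a `<4 x double>` splat accumulator whose reg_sequence of 32-bit halves can now be recognized and folded into the MFMA operand as a single inline immediate.

```llvm
; Sketch only. Before this change, the splat-of-1.0 accumulator was
; materialized into AGPRs as interleaved 0x3ff00000 / 0x00000000 halves;
; with it, the operand can fold to the inline immediate 1.0.
declare <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double, double, <4 x double>, i32 immarg, i32 immarg, i32 immarg)

define amdgpu_kernel void @fold_f64_splat_example(ptr addrspace(1) %out, double %a, double %b) {
  %mai = call <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double %a, double %b, <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, i32 0, i32 0, i32 0)
  store <4 x double> %mai, ptr addrspace(1) %out
  ret void
}
```

Compiling something along these lines with `llc -mtriple=amdgcn -mcpu=gfx90a` should now show `1.0` directly as the accumulator operand of `v_mfma_f64_16x16x4f64`, as the updated CHECK lines in the test below reflect.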
---
Full diff: https://github.com/llvm/llvm-project/pull/140878.diff
2 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/SIFoldOperands.cpp (+70-33)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll (+6-35)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
index eb7fb94e25f5c..70e3974bb22b4 100644
--- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
Register UseReg) const;
- std::pair<MachineOperand *, const TargetRegisterClass *>
+ std::pair<int64_t, const TargetRegisterClass *>
isRegSeqSplat(MachineInstr &RegSeg) const;
- MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
- MachineOperand *SplatVal,
- const TargetRegisterClass *SplatRC) const;
+ bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+ int64_t SplatVal,
+ const TargetRegisterClass *SplatRC) const;
bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
unsigned UseOpIdx,
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
return getRegSeqInit(*Def, Defs);
}
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
if (!SrcRC)
return {};
- // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
- // every other other element is 0 for 64-bit immediates)
+ bool TryToMatchSplat64 = false;
+
int64_t Imm;
for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
const MachineOperand *Op = Defs[I].first;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
Imm = SubImm;
continue;
}
- if (Imm != SubImm)
+
+ if (Imm != SubImm) {
+ if (I == 1 && (E & 1) == 0) {
+ // If we have an even number of inputs, there's a chance this is a
+ // 64-bit element splat broken into 32-bit pieces.
+ TryToMatchSplat64 = true;
+ break;
+ }
+
return {}; // Can only fold splat constants
+ }
+ }
+
+ if (!TryToMatchSplat64)
+ return {Defs[0].first->getImm(), SrcRC};
+
+ // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+ // (i.e. recognize every other element is 0 for 64-bit immediates)
+ int64_t SplatVal64;
+ for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+ const MachineOperand *Op0 = Defs[I].first;
+ const MachineOperand *Op1 = Defs[I + 1].first;
+
+ if (!Op0->isImm() || !Op1->isImm())
+ return {};
+
+ unsigned SubReg0 = Defs[I].second;
+ unsigned SubReg1 = Defs[I + 1].second;
+
+ // Assume we're going to generally encounter reg_sequences with sorted
+ // subreg indexes, so reject any that aren't consecutive.
+ if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+ TRI->getChannelFromSubReg(SubReg1))
+ return {};
+
+ int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+ if (I == 0)
+ SplatVal64 = MergedVal;
+ else if (SplatVal64 != MergedVal)
+ return {};
}
- return {Defs[0].first, SrcRC};
+ const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+ MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+ return {SplatVal64, RC64};
}
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
- MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+ MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
const TargetRegisterClass *SplatRC) const {
const MCInstrDesc &Desc = UseMI->getDesc();
if (UseOpIdx >= Desc.getNumOperands())
- return nullptr;
+ return false;
// Filter out unhandled pseudos.
if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
- return nullptr;
+ return false;
int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
if (RCID == -1)
- return nullptr;
+ return false;
+
+ const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
// Special case 0/-1, since when interpreted as a 64-bit element both halves
- // have the same bits. Effectively this code does not handle 64-bit element
- // operands correctly, as the incoming 64-bit constants are already split into
- // 32-bit sequence elements.
- //
- // TODO: We should try to figure out how to interpret the reg_sequence as a
- // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
- // constants.
- if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
- const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+ // have the same bits. These are the only cases where a splat has the same
+ // interpretation for 32-bit and 64-bit splats.
+ if (SplatVal != 0 && SplatVal != -1) {
// We need to figure out the scalar type read by the operand. e.g. the MFMA
// operand will be AReg_128, and we want to check if it's compatible with an
// AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
break;
default:
- return nullptr;
+ return false;
}
if (!TRI->getCommonSubClass(OpRC, SplatRC))
- return nullptr;
+ return false;
}
- if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
- return nullptr;
+ MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+ if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+ return false;
- return SplatVal;
+ return true;
}
bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
Register RegSeqDstReg = UseMI->getOperand(0).getReg();
unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
- MachineOperand *SplatVal;
+ int64_t SplatVal;
const TargetRegisterClass *SplatRC;
std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
MachineInstr *RSUseMI = RSUse->getParent();
unsigned OpNo = RSUseMI->getOperandNo(RSUse);
- if (SplatVal) {
- if (MachineOperand *Foldable =
- tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
- FoldableDef SplatDef(*Foldable, SplatRC);
+ if (SplatRC) {
+ if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+ FoldableDef SplatDef(SplatVal, SplatRC);
appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
continue;
}
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
index 5d5dc01439fe4..a9cffd6e1c943 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
@@ -165,19 +165,9 @@ bb:
}
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
}
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
}
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
``````````
https://github.com/llvm/llvm-project/pull/140878