[llvm] use vgpr16 for madmixfma (PR #159421)
Brox Chen via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 17 11:48:27 PDT 2025
https://github.com/broxigarchen created https://github.com/llvm/llvm-project/pull/159421
(No description provided.)
From ae4cd7fd6e149906a9b9f648108d4edccfcf614a Mon Sep 17 00:00:00 2001
From: guochen2 <guochen2 at amd.com>
Date: Tue, 16 Sep 2025 17:49:50 -0400
Subject: [PATCH] use vgpr16 for madmixfma
---
llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 132 ++++++++++----------
1 file changed, 68 insertions(+), 64 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index f7279b664ed27..800a6001bfe5a 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -158,16 +158,26 @@ defm V_PK_MAXIMUM3_F16 : VOP3PInst<"v_pk_maximum3_f16", VOP3P_Profile<VOP_V2F16_
}
} // End isCommutable = 1, FPDPRounding = 1
+class MadFmaMixPatOp<dag op1, bit true16, bit isHi16 = 0> {
+ dag ret = !if(true16,
+ !if(isHi16,
+ (REG_SEQUENCE VGPR_32, (i16 (IMPLICIT_DEF)), lo16, op1, hi16),
+ (REG_SEQUENCE VGPR_32, op1, lo16, (i16 (IMPLICIT_DEF)), hi16)),
+ op1);
+}
+
// TODO: Make sure we're doing the right thing with denormals. Note
// that FMA and MAD will differ.
-multiclass MadFmaMixPats<SDPatternOperator fma_like,
- Instruction mix_inst,
- Instruction mixlo_inst,
- Instruction mixhi_inst,
- ValueType VT = f16,
- ValueType vecVT = v2f16> {
+multiclass MadFmaMixPatsImpl<SDPatternOperator fma_like,
+ Instruction mix_inst,
+ Instruction mixlo_inst,
+ Instruction mixhi_inst,
+ ValueType VT,
+ ValueType vecVT,
+ bit true16> {
defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
defvar VOP3PMadMixModsExtPat = !if (!eq(VT, bf16), VOP3PMadMixBF16ModsExt, VOP3PMadMixModsExt);
+
// At least one of the operands needs to be an fpextend of an f16
// for this to be worthwhile, so we need three patterns here.
// TODO: Could we use a predicate to inspect src1/2/3 instead?
@@ -175,19 +185,25 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(f32 (fma_like (f32 (VOP3PMadMixModsExtPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_mods)))),
- (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+ (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
+ $src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsExtPat VT:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsPat f32:$src2, i32:$src2_mods)))),
- (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+ (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
+ $src2_mods, $src2,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsExtPat VT:$src2, i32:$src2_mods)))),
- (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+ (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_mods, $src1,
+ $src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE)>;
def : GCNPat <
@@ -198,13 +214,13 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$hi_src0, i32:$hi_src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$hi_src1, i32:$hi_src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$hi_src2, i32:$hi_src2_modifiers))))))),
- (vecVT (mixhi_inst $hi_src0_modifiers, $hi_src0,
- $hi_src1_modifiers, $hi_src1,
- $hi_src2_modifiers, $hi_src2,
+ (vecVT (mixhi_inst $hi_src0_modifiers, MadFmaMixPatOp<(VT $hi_src0), true16>.ret,
+ $hi_src1_modifiers, MadFmaMixPatOp<(VT $hi_src1), true16>.ret,
+ $hi_src2_modifiers, MadFmaMixPatOp<(VT $hi_src2), true16>.ret,
DSTCLAMP.ENABLE,
- (mixlo_inst $lo_src0_modifiers, $lo_src0,
- $lo_src1_modifiers, $lo_src1,
- $lo_src2_modifiers, $lo_src2,
+ (mixlo_inst $lo_src0_modifiers, MadFmaMixPatOp<(VT $lo_src0), true16>.ret,
+ $lo_src1_modifiers, MadFmaMixPatOp<(VT $lo_src1), true16>.ret,
+ $lo_src2_modifiers, MadFmaMixPatOp<(VT $lo_src2), true16>.ret,
DSTCLAMP.ENABLE,
(i32 (IMPLICIT_DEF)))))
>;
@@ -233,9 +249,9 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))),
- (mixlo_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
+ (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+ $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE,
(i32 (IMPLICIT_DEF)))
>;
@@ -243,18 +259,15 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
// FIXME: Special case handling for maxhi (especially for clamp)
// because dealing with the write to high half of the register is
// difficult.
- foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
- let True16Predicate = p in {
-
def : GCNPat <
(build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
- (vecVT (mixhi_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
+ (vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+ $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE,
- VGPR_32:$elt0))
+ MadFmaMixPatOp<(VT $elt0), true16>.ret))
>;
def : GCNPat <
@@ -263,51 +276,42 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
- (vecVT (mixhi_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
+ (vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+ $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+ $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.ENABLE,
- VGPR_32:$elt0))
+ MadFmaMixPatOp<(VT $elt0), true16>.ret))
>;
+}
- } // end True16Predicate
+multiclass MadFmaMixPats<SDPatternOperator fma_like,
+ Instruction mix_inst,
+ Instruction mixlo_inst,
+ Instruction mixhi_inst,
+ ValueType VT = f16,
+ ValueType vecVT = v2f16> {
- let True16Predicate = UseRealTrue16Insts in {
- def : GCNPat <
- (build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
- (vecVT (mixlo_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
- DSTCLAMP.NONE,
- (REG_SEQUENCE VGPR_32, (VT (IMPLICIT_DEF)), lo16, $elt1, hi16)))
- >;
+ defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
- def : GCNPat <
- (build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
- (vecVT (mixhi_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
- DSTCLAMP.NONE,
- (REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
- >;
+ foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
+ let True16Predicate = p in
+ defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 0>;
+
+ let True16Predicate = UseRealTrue16Insts in {
+ defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 1>;
+
+ def : GCNPat <
+ (build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
+ (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
+ (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
+ (vecVT (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), /*true16*/1>.ret,
+ $src1_modifiers, MadFmaMixPatOp<(VT $src1), /*true16*/1>.ret,
+ $src2_modifiers, MadFmaMixPatOp<(VT $src2), /*true16*/1>.ret,
+ DSTCLAMP.NONE,
+ MadFmaMixPatOp<(VT $elt1), /*true16*/1, /*isHi16*/1>.ret))
+ >;
+ }
- def : GCNPat <
- (build_vector
- VT:$elt0,
- (AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
- (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
- (vecVT (mixhi_inst $src0_modifiers, $src0,
- $src1_modifiers, $src1,
- $src2_modifiers, $src2,
- DSTCLAMP.ENABLE,
- (REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
- >;
- } // end True16Predicate
}
class MinimumMaximumByMinimum3Maximum3VOP3P<SDPatternOperator node,
More information about the llvm-commits
mailing list