[llvm] use vgpr16 for madmixfma (PR #159421)

Brox Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 17 11:48:27 PDT 2025


https://github.com/broxigarchen created https://github.com/llvm/llvm-project/pull/159421

None

>From ae4cd7fd6e149906a9b9f648108d4edccfcf614a Mon Sep 17 00:00:00 2001
From: guochen2 <guochen2 at amd.com>
Date: Tue, 16 Sep 2025 17:49:50 -0400
Subject: [PATCH] use vgpr16 for madmixfma

---
 llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 132 ++++++++++----------
 1 file changed, 68 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index f7279b664ed27..800a6001bfe5a 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -158,16 +158,26 @@ defm V_PK_MAXIMUM3_F16 : VOP3PInst<"v_pk_maximum3_f16", VOP3P_Profile<VOP_V2F16_
 }
 } // End isCommutable = 1, FPDPRounding = 1
 
+class MadFmaMixPatOp<dag op1, bit true16, bit isHi16 = 0> { // Builds the output-pattern operand for a 16-bit mix source.
+  dag ret = !if(true16, // true16: widen op1 into a VGPR_32 with REG_SEQUENCE.
+                !if(isHi16,
+                    (REG_SEQUENCE VGPR_32, (i16 (IMPLICIT_DEF)), lo16, op1, hi16), // op1 goes in hi16; lo16 is IMPLICIT_DEF.
+                    (REG_SEQUENCE VGPR_32, op1, lo16, (i16 (IMPLICIT_DEF)), hi16)), // op1 goes in lo16; hi16 is IMPLICIT_DEF.
+               op1); // Not true16: op1 is used unchanged.
+}
+
 // TODO: Make sure we're doing the right thing with denormals. Note
 // that FMA and MAD will differ.
-multiclass MadFmaMixPats<SDPatternOperator fma_like,
-                         Instruction mix_inst,
-                         Instruction mixlo_inst,
-                         Instruction mixhi_inst,
-                         ValueType VT = f16,
-                         ValueType vecVT = v2f16> {
+multiclass MadFmaMixPatsImpl<SDPatternOperator fma_like,
+                             Instruction mix_inst,
+                             Instruction mixlo_inst,
+                             Instruction mixhi_inst,
+                             ValueType VT,
+                             ValueType vecVT,
+                             bit true16> {
   defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
   defvar VOP3PMadMixModsExtPat = !if (!eq(VT, bf16), VOP3PMadMixBF16ModsExt, VOP3PMadMixModsExt);
+
   // At least one of the operands needs to be an fpextend of an f16
   // for this to be worthwhile, so we need three patterns here.
   // TODO: Could we use a predicate to inspect src1/2/3 instead?
@@ -175,19 +185,25 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
     (f32 (fma_like (f32 (VOP3PMadMixModsExtPat VT:$src0, i32:$src0_mods)),
                    (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_mods)),
                    (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_mods)))),
-    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+    (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+              $src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
+              $src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
               DSTCLAMP.NONE)>;
   def : GCNPat <
     (f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
                    (f32 (VOP3PMadMixModsExtPat VT:$src1, i32:$src1_mods)),
                    (f32 (VOP3PMadMixModsPat f32:$src2, i32:$src2_mods)))),
-    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+    (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+              $src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
+              $src2_mods, $src2,
               DSTCLAMP.NONE)>;
   def : GCNPat <
     (f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
                    (f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_mods)),
                    (f32 (VOP3PMadMixModsExtPat VT:$src2, i32:$src2_mods)))),
-    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+    (mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
+              $src1_mods, $src1,
+              $src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
               DSTCLAMP.NONE)>;
 
   def : GCNPat <
@@ -198,13 +214,13 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
       (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$hi_src0, i32:$hi_src0_modifiers)),
                         (f32 (VOP3PMadMixModsPat VT:$hi_src1, i32:$hi_src1_modifiers)),
                         (f32 (VOP3PMadMixModsPat VT:$hi_src2, i32:$hi_src2_modifiers))))))),
-    (vecVT (mixhi_inst $hi_src0_modifiers, $hi_src0,
-                       $hi_src1_modifiers, $hi_src1,
-                       $hi_src2_modifiers, $hi_src2,
+    (vecVT (mixhi_inst $hi_src0_modifiers, MadFmaMixPatOp<(VT $hi_src0), true16>.ret,
+                       $hi_src1_modifiers, MadFmaMixPatOp<(VT $hi_src1), true16>.ret,
+                       $hi_src2_modifiers, MadFmaMixPatOp<(VT $hi_src2), true16>.ret,
                        DSTCLAMP.ENABLE,
-                       (mixlo_inst $lo_src0_modifiers, $lo_src0,
-                                   $lo_src1_modifiers, $lo_src1,
-                                   $lo_src2_modifiers, $lo_src2,
+                       (mixlo_inst $lo_src0_modifiers, MadFmaMixPatOp<(VT $lo_src0), true16>.ret,
+                                   $lo_src1_modifiers, MadFmaMixPatOp<(VT $lo_src1), true16>.ret,
+                                   $lo_src2_modifiers, MadFmaMixPatOp<(VT $lo_src2), true16>.ret,
                                    DSTCLAMP.ENABLE,
                                    (i32 (IMPLICIT_DEF)))))
   >;
@@ -233,9 +249,9 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
     (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
                            (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
                            (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))),
-    (mixlo_inst $src0_modifiers, $src0,
-                $src1_modifiers, $src1,
-                $src2_modifiers, $src2,
+    (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+                $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+                $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
                 DSTCLAMP.NONE,
                 (i32 (IMPLICIT_DEF)))
   >;
@@ -243,18 +259,15 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
   // FIXME: Special case handling for maxhi (especially for clamp)
   // because dealing with the write to high half of the register is
   // difficult.
-  foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
-  let True16Predicate = p in {
-
   def : GCNPat <
     (build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
                                                    (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
                                                    (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
-    (vecVT (mixhi_inst $src0_modifiers, $src0,
-                       $src1_modifiers, $src1,
-                       $src2_modifiers, $src2,
+    (vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+                       $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+                       $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
                        DSTCLAMP.NONE,
-                       VGPR_32:$elt0))
+                       MadFmaMixPatOp<(VT $elt0), true16>.ret))
   >;
 
   def : GCNPat <
@@ -263,51 +276,42 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
       (AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
                                      (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
                                      (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
-    (vecVT (mixhi_inst $src0_modifiers, $src0,
-                       $src1_modifiers, $src1,
-                       $src2_modifiers, $src2,
+    (vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
+                       $src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
+                       $src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
                        DSTCLAMP.ENABLE,
-                       VGPR_32:$elt0))
+                       MadFmaMixPatOp<(VT $elt0), true16>.ret))
   >;
+}
 
-  } // end True16Predicate
+multiclass MadFmaMixPats<SDPatternOperator fma_like, // Instantiates MadFmaMixPatsImpl once per True16 mode.
+                         Instruction mix_inst,
+                         Instruction mixlo_inst,
+                         Instruction mixhi_inst,
+                         ValueType VT = f16,
+                         ValueType vecVT = v2f16> {
 
-  let True16Predicate = UseRealTrue16Insts in {
-  def : GCNPat <
-    (build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
-                           (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
-                           (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
-    (vecVT (mixlo_inst $src0_modifiers, $src0,
-                $src1_modifiers, $src1,
-                $src2_modifiers, $src2,
-                DSTCLAMP.NONE,
-                (REG_SEQUENCE VGPR_32, (VT (IMPLICIT_DEF)), lo16, $elt1, hi16)))
-  >;
+  defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods); // bf16 sources use the BF16 modifier matcher.
 
-  def : GCNPat <
-    (build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
-                                              (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
-                                              (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
-    (vecVT (mixhi_inst $src0_modifiers, $src0,
-                       $src1_modifiers, $src1,
-                       $src2_modifiers, $src2,
-                       DSTCLAMP.NONE,
-                       (REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
-  >;
+  foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in // true16 = 0: MadFmaMixPatOp leaves 16-bit operands unwrapped.
+  let True16Predicate = p in
+  defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 0>;
+
+  let True16Predicate = UseRealTrue16Insts in { // true16 = 1: MadFmaMixPatOp wraps 16-bit operands in a VGPR_32 REG_SEQUENCE.
+    defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 1>;
+
+    def : GCNPat < // build_vector(fpround(mix), elt1) -> mixlo, with $elt1 placed in hi16 of the last operand.
+      (build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
+                             (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
+                             (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
+      (vecVT (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), /*true16*/1>.ret,
+                  $src1_modifiers, MadFmaMixPatOp<(VT $src1), /*true16*/1>.ret,
+                  $src2_modifiers, MadFmaMixPatOp<(VT $src2), /*true16*/1>.ret,
+                  DSTCLAMP.NONE,
+                  MadFmaMixPatOp<(VT $elt1), /*true16*/1, /*isHi16*/1>.ret))
+    >;
+  }
 
-  def : GCNPat <
-    (build_vector
-      VT:$elt0,
-      (AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
-                                     (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
-                                     (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
-    (vecVT (mixhi_inst $src0_modifiers, $src0,
-                       $src1_modifiers, $src1,
-                       $src2_modifiers, $src2,
-                       DSTCLAMP.ENABLE,
-                       (REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
-  >;
-  } // end True16Predicate
 }
 
 class MinimumMaximumByMinimum3Maximum3VOP3P<SDPatternOperator node,



More information about the llvm-commits mailing list