[llvm] [AMDGPU] Support V_FMA_MIX*_BF16 instructions on gfx1250 (PR #150381)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 24 00:15:43 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mc

Author: Changpeng Fang (changpeng)

<details>
<summary>Changes</summary>



---

Patch is 112.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150381.diff


12 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPU.td (+10) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+104-31) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h (+6-2) 
- (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+6-4) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+4) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+84-54) 
- (added) llvm/test/CodeGen/AMDGPU/mad-mix-bf16.ll (+634) 
- (added) llvm/test/CodeGen/AMDGPU/mad-mix-hi-bf16.ll (+189) 
- (added) llvm/test/CodeGen/AMDGPU/mad-mix-lo-bf16.ll (+540) 
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_vop3p.s (+168) 
- (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_vop3p.txt (+126) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 2a36f3dea34ce..d6298c4ebf24a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -149,6 +149,12 @@ def FeatureFmaMixInsts : SubtargetFeature<"fma-mix-insts",
   "Has v_fma_mix_f32, v_fma_mixlo_f16, v_fma_mixhi_f16 instructions"
 >;
 
+def FeatureFmaMixBF16Insts : SubtargetFeature<"fma-mix-bf16-insts",
+  "HasFmaMixBF16Insts",
+  "true",
+  "Has v_fma_mix_f32_bf16, v_fma_mixlo_bf16, v_fma_mixhi_bf16 instructions"
+>;
+
 def FeatureIEEEMinimumMaximumInsts : SubtargetFeature<"ieee-minimum-maximum-insts",
   "HasIEEEMinimumMaximumInsts",
   "true",
@@ -2007,6 +2013,7 @@ def FeatureISAVersion12_50 : FeatureSet<
    FeatureBF16ConversionInsts,
    FeatureBF16PackedInsts,
    FeatureCvtPkF16F32Inst,
+   FeatureFmaMixBF16Insts,
    FeatureMin3Max3PKF16,
    FeatureMinimum3Maximum3PKF16,
    FeaturePrngInst,
@@ -2599,6 +2606,9 @@ def HasMovrel : Predicate<"Subtarget->hasMovrel()">,
 def HasFmaMixInsts : Predicate<"Subtarget->hasFmaMixInsts()">,
   AssemblerPredicate<(all_of FeatureFmaMixInsts)>;
 
+def HasFmaMixBF16Insts : Predicate<"Subtarget->hasFmaMixBF16Insts()">,
+  AssemblerPredicate<(all_of FeatureFmaMixBF16Insts)>;
+
 def HasDLInsts : Predicate<"Subtarget->hasDLInsts()">,
   AssemblerPredicate<(all_of FeatureDLInsts)>;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 5a2416debb417..0ca2286c11c94 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3861,58 +3861,114 @@ bool AMDGPUDAGToDAGISel::SelectVOP3OpSelMods(SDValue In, SDValue &Src,
   return SelectVOP3Mods(In, Src, SrcMods);
 }
 
+// Match a lowered fpext from bf16 to f32: a bit operation that puts a 16-bit
+// value in the high half of the f32, with the low 16 bits zeroed:
+// 1. (f32 (bitcast (build_vector (i16 0), (i16 (bitcast bf16:val)))))
+// 2. (f32 (bitcast (and i32:val, 0xffff0000))) -> IsExtractHigh = true
+// 3. (f32 (bitcast (shl i32:val, 16)))          -> IsExtractHigh = false
+// Returns val, or an empty SDValue if nothing matches.
+static SDValue matchBF16FPExtendLike(SDValue Op, bool &IsExtractHigh) {
+  if (Op.getValueType() != MVT::f32 || Op.getOpcode() != ISD::BITCAST)
+    return SDValue();
+  Op = Op.getOperand(0);
+
+  IsExtractHigh = false;
+  if (Op.getValueType() == MVT::v2i16 && Op.getOpcode() == ISD::BUILD_VECTOR) {
+    auto Low16 = dyn_cast<ConstantSDNode>(Op.getOperand(0));
+    if (!Low16 || !Low16->isZero())
+      return SDValue();
+    Op = stripBitcast(Op.getOperand(1));
+    if (Op.getValueType() != MVT::bf16)
+      return SDValue();
+    return Op;
+  }
+
+  if (Op.getValueType() != MVT::i32)
+    return SDValue();
+
+  if (Op.getOpcode() == ISD::AND) {
+    if (auto Mask = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      if (Mask->getZExtValue() == 0xffff0000) {
+        IsExtractHigh = true;
+        return Op.getOperand(0);
+      }
+    }
+    return SDValue();
+  }
+
+  if (Op.getOpcode() == ISD::SHL) {
+    if (auto Amt = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      if (Amt->getZExtValue() == 16)
+        return Op.getOperand(0);
+    }
+  }
+
+  return SDValue();
+}
+
 // The return value is not whether the match is possible (which it always is),
 // but whether or not it a conversion is really used.
 bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
-                                                   unsigned &Mods) const {
+                                                   unsigned &Mods,
+                                                   MVT VT) const {
   Mods = 0;
   SelectVOP3ModsImpl(In, Src, Mods);
 
+  bool IsExtractHigh = false;
   if (Src.getOpcode() == ISD::FP_EXTEND) {
     Src = Src.getOperand(0);
-    assert(Src.getValueType() == MVT::f16);
-    Src = stripBitcast(Src);
+  } else if (VT == MVT::bf16) {
+    SDValue B16 = matchBF16FPExtendLike(Src, IsExtractHigh);
+    if (!B16)
+      return false;
+    Src = B16;
+  } else
+    return false;
 
-    // Be careful about folding modifiers if we already have an abs. fneg is
-    // applied last, so we don't want to apply an earlier fneg.
-    if ((Mods & SISrcMods::ABS) == 0) {
-      unsigned ModsTmp;
-      SelectVOP3ModsImpl(Src, Src, ModsTmp);
+  if (Src.getValueType() != VT &&
+      (VT != MVT::bf16 || Src.getValueType() != MVT::i32))
+    return false;
 
-      if ((ModsTmp & SISrcMods::NEG) != 0)
-        Mods ^= SISrcMods::NEG;
+  Src = stripBitcast(Src);
 
-      if ((ModsTmp & SISrcMods::ABS) != 0)
-        Mods |= SISrcMods::ABS;
-    }
+  // Be careful about folding modifiers if we already have an abs. fneg is
+  // applied last, so we don't want to apply an earlier fneg.
+  if ((Mods & SISrcMods::ABS) == 0) {
+    unsigned ModsTmp;
+    SelectVOP3ModsImpl(Src, Src, ModsTmp);
+
+    if ((ModsTmp & SISrcMods::NEG) != 0)
+      Mods ^= SISrcMods::NEG;
 
-    // op_sel/op_sel_hi decide the source type and source.
-    // If the source's op_sel_hi is set, it indicates to do a conversion from fp16.
-    // If the sources's op_sel is set, it picks the high half of the source
-    // register.
+    if ((ModsTmp & SISrcMods::ABS) != 0)
+      Mods |= SISrcMods::ABS;
+  }
 
-    Mods |= SISrcMods::OP_SEL_1;
-    if (isExtractHiElt(Src, Src)) {
-      Mods |= SISrcMods::OP_SEL_0;
+  // op_sel/op_sel_hi decide the source type and source.
+  // If the source's op_sel_hi is set, it indicates to do a conversion from
+  // fp16. If the sources's op_sel is set, it picks the high half of the source
+  // register.
 
-      // TODO: Should we try to look for neg/abs here?
-    }
+  Mods |= SISrcMods::OP_SEL_1;
+  if (IsExtractHigh ||
+      (Src.getValueSizeInBits() == 16 && isExtractHiElt(Src, Src))) {
+    Mods |= SISrcMods::OP_SEL_0;
 
-    // Prevent unnecessary subreg COPY to VGPR_16
-    if (Src.getOpcode() == ISD::TRUNCATE &&
-        Src.getOperand(0).getValueType() == MVT::i32) {
-      Src = Src.getOperand(0);
-    }
-    return true;
+    // TODO: Should we try to look for neg/abs here?
   }
 
-  return false;
+  // Prevent unnecessary subreg COPY to VGPR_16
+  if (Src.getOpcode() == ISD::TRUNCATE &&
+      Src.getOperand(0).getValueType() == MVT::i32) {
+    Src = Src.getOperand(0);
+  }
+  return true;
 }
 
 bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
                                                   SDValue &SrcMods) const {
   unsigned Mods = 0;
-  if (!SelectVOP3PMadMixModsImpl(In, Src, Mods))
+  if (!SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16))
     return false;
   SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
   return true;
@@ -3921,7 +3977,24 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
 bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src,
                                                SDValue &SrcMods) const {
   unsigned Mods = 0;
-  SelectVOP3PMadMixModsImpl(In, Src, Mods);
+  SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16);
+  SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
+bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16ModsExt(SDValue In, SDValue &Src,
+                                                      SDValue &SrcMods) const {
+  unsigned Mods = 0;
+  bool Folded = SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16);
+  if (Folded)
+    SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return Folded;
+}
+
+bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16Mods(SDValue In, SDValue &Src,
+                                                   SDValue &SrcMods) const {
+  unsigned Mods = 0;
+  SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16);
   SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
   return true;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
index 6123d75d7b616..7ecba1e24ff51 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -254,11 +254,15 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
   bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
   bool SelectVOP3OpSelMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
-  bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
-                                 unsigned &Mods) const;
+  bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src, unsigned &Mods,
+                                 MVT VT) const;
   bool SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
                                 SDValue &SrcMods) const;
   bool SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
+  bool SelectVOP3PMadMixBF16ModsExt(SDValue In, SDValue &Src,
+                                    SDValue &SrcMods) const;
+  bool SelectVOP3PMadMixBF16Mods(SDValue In, SDValue &Src,
+                                 SDValue &SrcMods) const;
 
   bool SelectBITOP3(SDValue In, SDValue &Src0, SDValue &Src1, SDValue &Src2,
                    SDValue &Tbl) const;
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index 0435e7f9e51d2..0683f02955594 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -123,6 +123,7 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
   bool HasSMemRealTime = false;
   bool HasIntClamp = false;
   bool HasFmaMixInsts = false;
+  bool HasFmaMixBF16Insts = false;
   bool HasMovrel = false;
   bool HasVGPRIndexMode = false;
   bool HasScalarDwordx3Loads = false;
@@ -462,6 +463,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
     return HasFmaMixInsts;
   }
 
+  bool hasFmaMixBF16Insts() const { return HasFmaMixBF16Insts; }
+
   bool hasCARRY() const {
     return true;
   }
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f1a8ee118356e..6d963b77850f4 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -1061,10 +1061,12 @@ ArrayRef<MCPhysReg> SITargetLowering::getRoundingControlRegisters() const {
 // where this is OK to use.
 bool SITargetLowering::isFPExtFoldable(const SelectionDAG &DAG, unsigned Opcode,
                                        EVT DestVT, EVT SrcVT) const {
-  return ((Opcode == ISD::FMAD && Subtarget->hasMadMixInsts()) ||
-          (Opcode == ISD::FMA && Subtarget->hasFmaMixInsts())) &&
-         DestVT.getScalarType() == MVT::f32 &&
-         SrcVT.getScalarType() == MVT::f16 &&
+  return DestVT.getScalarType() == MVT::f32 &&
+         ((((Opcode == ISD::FMAD && Subtarget->hasMadMixInsts()) ||
+            (Opcode == ISD::FMA && Subtarget->hasFmaMixInsts())) &&
+           SrcVT.getScalarType() == MVT::f16) ||
+          (Opcode == ISD::FMA && Subtarget->hasFmaMixBF16Insts() &&
+           SrcVT.getScalarType() == MVT::bf16)) &&
          // TODO: This probably only requires no input flushing?
          denormalModeIsFlushAllF32(DAG.getMachineFunction());
 }
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 485ca78db93a7..b0be3f864b94d 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1662,6 +1662,8 @@ def VOP3OpSelMods  : ComplexPattern<untyped, 2, "SelectVOP3OpSelMods">;
 
 def VOP3PMadMixModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixModsExt">;
 def VOP3PMadMixMods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
+def VOP3PMadMixBF16ModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixBF16ModsExt">;
+def VOP3PMadMixBF16Mods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixBF16Mods">;
 
 def VINTERPMods  : ComplexPattern<untyped, 2, "SelectVINTERPMods">;
 def VINTERPModsHi  : ComplexPattern<untyped, 2, "SelectVINTERPModsHi">;
@@ -2866,6 +2868,7 @@ def VOP_I16_I16_I16_ARITH : VOPProfile <[i16, i16, i16, untyped], /*EnableClamp=
 
 def VOP_I16_I16_I16_I16 : VOPProfile <[i16, i16, i16, i16, untyped]>;
 def VOP_F16_F16_F16_F16 : VOPProfile <[f16, f16, f16, f16, untyped]>;
+def VOP_BF16_BF16_BF16_BF16 : VOPProfile <[bf16, bf16, bf16, bf16, untyped]>;
 
 def VOP_I32_I16_I16_I32 : VOPProfile <[i32, i16, i16, i32, untyped]>;
 def VOP_I32_I16 : VOPProfile <[i32, i16, untyped, untyped]>;
@@ -2917,6 +2920,7 @@ def VOP_I32_I32_I32_ARITH : VOPProfile <[i32, i32, i32, untyped], /*EnableClamp=
 def VOP_I64_I64_I64_ARITH : VOPProfile <[i64, i64, i64, untyped], /*EnableClamp=*/1>;
 def VOP_V2F16_F32_F32 : VOPProfile <[v2f16, f32, f32, untyped]>;
 def VOP_F32_F16_F16_F16 : VOPProfile <[f32, f16, f16, f16]>;
+def VOP_F32_BF16_BF16_BF16 : VOPProfile <[f32, bf16, bf16, bf16]>;
 def VOP_V2BF16_F32_F32 : VOPProfile <[v2bf16, f32, f32, untyped]>;
 def VOP_V32F32_V6I32_F32 : VOPProfile <[v32f32, v6i32, f32, untyped]>;
 def VOP_V32F16_V6I32_F32 : VOPProfile <[v32f16, v6i32, f32, untyped]>;
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index ea14c77cdff0b..7017da9dc3521 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -35,14 +35,18 @@ class VOP3P_Mix_Profile<VOPProfile P, VOP3Features Features = VOP3_REGULAR,
                     bit useTiedOutput = 0> : VOP3P_Profile<P, Features, 1> {
     bit UseTiedOutput = useTiedOutput;
 
+    defvar Src0RC = getVCSrcForVT<P.Src0VT>.ret;
+    defvar Src1RC = getVCSrcForVT<P.Src1VT>.ret;
+    defvar Src2RC = getVCSrcForVT<P.Src2VT>.ret;
+
     dag srcs =
-          (ins FP16InputMods:$src0_modifiers, VCSrc_f16:$src0,
-               FP16InputMods:$src1_modifiers, VCSrc_f16:$src1,
-               FP16InputMods:$src2_modifiers, VCSrc_f16:$src2);
+          (ins FP16InputMods:$src0_modifiers, Src0RC:$src0,
+               FP16InputMods:$src1_modifiers, Src1RC:$src1,
+               FP16InputMods:$src2_modifiers, Src2RC:$src2);
     dag dpp_srcs =
           (ins FPVRegInputMods:$src0_modifiers, VGPRSrc_32:$src0,
                FPVRegInputMods:$src1_modifiers, VRegSrc_32:$src1,
-               FP16InputMods:$src2_modifiers, VCSrc_f16:$src2);
+               FP16InputMods:$src2_modifiers, Src2RC:$src2);
 
            // FIXME: Clamp0 misbehaves with the non-default vdst_in
            // following it. For now workaround this by requiring clamp
@@ -161,38 +165,42 @@ defm V_PK_MAXIMUM3_F16 : VOP3PInst<"v_pk_maximum3_f16", VOP3P_Profile<VOP_V2F16_
 multiclass MadFmaMixPats<SDPatternOperator fma_like,
                          Instruction mix_inst,
                          Instruction mixlo_inst,
-                         Instruction mixhi_inst> {
+                         Instruction mixhi_inst,
+                         ValueType VT = f16,
+                         ValueType vecVT = v2f16> {
+  defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
+  defvar VOP3PMadMixModsExtPat = !if (!eq(VT, bf16), VOP3PMadMixBF16ModsExt, VOP3PMadMixModsExt);
   // At least one of the operands needs to be an fpextend of an f16
   // for this to be worthwhile, so we need three patterns here.
   // TODO: Could we use a predicate to inspect src1/2/3 instead?
   def : GCNPat <
-    (f32 (fma_like (f32 (VOP3PMadMixModsExt f16:$src0, i32:$src0_mods)),
-                   (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_mods)),
-                   (f32 (VOP3PMadMixMods f16:$src2, i32:$src2_mods)))),
+    (f32 (fma_like (f32 (VOP3PMadMixModsExtPat VT:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_mods)))),
     (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
               DSTCLAMP.NONE)>;
   def : GCNPat <
-    (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
-                   (f32 (VOP3PMadMixModsExt f16:$src1, i32:$src1_mods)),
-                   (f32 (VOP3PMadMixMods f32:$src2, i32:$src2_mods)))),
+    (f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixModsExtPat VT:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixModsPat f32:$src2, i32:$src2_mods)))),
     (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
               DSTCLAMP.NONE)>;
   def : GCNPat <
-    (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
-                   (f32 (VOP3PMadMixMods f32:$src1, i32:$src1_mods)),
-                   (f32 (VOP3PMadMixModsExt f16:$src2, i32:$src2_mods)))),
+    (f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixModsExtPat VT:$src2, i32:$src2_mods)))),
     (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
               DSTCLAMP.NONE)>;
 
   def : GCNPat <
     (AMDGPUclamp (build_vector
-      (f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$lo_src0, i32:$lo_src0_modifiers)),
-                         (f32 (VOP3PMadMixMods f16:$lo_src1, i32:$lo_src1_modifiers)),
-                         (f32 (VOP3PMadMixMods f16:$lo_src2, i32:$lo_src2_modifiers))))),
-      (f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$hi_src0, i32:$hi_src0_modifiers)),
-                         (f32 (VOP3PMadMixMods f16:$hi_src1, i32:$hi_src1_modifiers)),
-                         (f32 (VOP3PMadMixMods f16:$hi_src2, i32:$hi_src2_modifiers))))))),
-    (v2f16 (mixhi_inst $hi_src0_modifiers, $hi_src0,
+      (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$lo_src0, i32:$lo_src0_modifiers)),
+                        (f32 (VOP3PMadMixModsPat VT:$lo_src1, i32:$lo_src1_modifiers)),
+                        (f32 (VOP3PMadMixModsPat VT:$lo_src2, i32:$lo_src2_modifiers))))),
+      (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$hi_src0, i32:$hi_src0_modifiers)),
+                        (f32 (VOP3PMadMixModsPat VT:$hi_src1, i32:$hi_src1_modifiers)),
+                        (f32 (VOP3PMadMixModsPat VT:$hi_src2, i32:$hi_src2_modifiers))))))),
+    (vecVT (mixhi_inst $hi_src0_modifiers, $hi_src0,
                        $hi_src1_modifiers, $hi_src1,
                        $hi_src2_modifiers, $hi_src2,
                        DSTCLAMP.ENABLE,
@@ -204,8 +212,8 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
   >;
 
   def : GCNPat <
-    (f16 (fpround (fmul (f32 (VOP3PMadMixMods f32:$src0, i32:$src0_modifiers)),
-                        (f32 (VOP3PMadMixMods f32:$src1, i32:$src1_modifiers))))),
+    (VT (fpround (fmul (f32 (VOP3PMadMixModsPat f32:$src0, i32:$src0_modifiers)),
+                       (f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_modifiers))))),
     (mixlo_inst $src0_modifiers, $src0,
                 $src1_modifiers, $src1,
                 (i32 0), (i32 0),
@@ -214,9 +222,9 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
   >;
 
   def : GCNPat <
-    (build_vector f16:$elt0, (f16 (fpround (fmul (f32 (VOP3PMadMixMods f32:$src0, i32:$src0_modifiers)),
-                                            (f32 (VOP3PMadMixMods f32:$src1, i32:$src1_modifiers)))))),
-    (v2f16 (mixhi_inst $src0_modifiers, $src0,
+    (build_vector VT:$elt0, (VT (fpround (fmul (f32 (VOP3PMadMixModsPat f32:$src0, i32:$src0_modifiers)),
+                                          (f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_modifiers)))))),
+    (vecVT (mixhi_inst $src0_modifiers, $src0,
                        $src1_modifiers, $src1,
                        (i32 0), (i32 0),
                        DSTCLAMP.NONE,
@@ -224,9 +232,9 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
   >;
 
   def : GCNPat <
-    (f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_modifiers)),
-                            (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_modifiers)),
-                            (f32 (VOP3PMadMixMods f16:$src2, i32:$src2_modifiers))))),
+    (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
+                           (f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
+                           (f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))),
     (mixlo_inst $src0_modifiers, $src0,
                 $src1_modifiers, $src1,
                 $src2_modifiers, $src2,
@@ -241,10 +249,10 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
   let True16Predicate = p in {
 
   def : GCNP...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/150381


More information about the llvm-commits mailing list