[llvm] 98f59b2 - [RISCV] Teach doPeepholeMaskedRVV to handle FMA instructions.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri May 12 23:36:36 PDT 2023


Author: Craig Topper
Date: 2023-05-12T23:36:27-07:00
New Revision: 98f59b2f5bc0654ba2936931f7ff837ae1d07acc

URL: https://github.com/llvm/llvm-project/commit/98f59b2f5bc0654ba2936931f7ff837ae1d07acc
DIFF: https://github.com/llvm/llvm-project/commit/98f59b2f5bc0654ba2936931f7ff837ae1d07acc.diff

LOG: [RISCV] Teach doPeepholeMaskedRVV to handle FMA instructions.

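This lets us remove the now-redundant true_mask (unmasked) isel patterns for the VL FMA nodes (VPatFPMulAddVL_VV_VF and VPatFPMulAccVL_VV_VF), since the peephole now rewrites the masked FMA pseudos to their unmasked forms.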

Reviewed By: fakepaper56

Differential Revision: https://reviews.llvm.org/D150463

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 76c4a596e911..e0e6cbe288f9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3157,37 +3157,42 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
   const RISCVInstrInfo &TII = *Subtarget->getInstrInfo();
   const MCInstrDesc &MaskedMCID = TII.get(N->getMachineOpcode());
 
-  bool IsTA = true;
+  bool UseTUPseudo = false;
   if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
-    TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
-    if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
-          RISCVII::TAIL_AGNOSTIC)) {
-      // Keep the true-masked instruction when there is no unmasked TU
-      // instruction
-      if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef())
-        return false;
-      // We can't use TA if the tie-operand is not IMPLICIT_DEF
-      if (!N->getOperand(0).isUndef())
-        IsTA = false;
+    // Some operations are their own TU.
+    if (I->UnmaskedTUPseudo == I->UnmaskedPseudo) {
+      UseTUPseudo = true;
+    } else {
+      TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
+      if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
+            RISCVII::TAIL_AGNOSTIC)) {
+        // We can't use TA if the tie-operand is not IMPLICIT_DEF
+        if (!N->getOperand(0).isUndef()) {
+          // Keep the true-masked instruction when there is no unmasked TU
+          // instruction
+          if (I->UnmaskedTUPseudo == I->MaskedPseudo)
+            return false;
+          UseTUPseudo = true;
+        }
+      }
     }
   }
 
-  unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo;
+  unsigned Opc = UseTUPseudo ? I->UnmaskedTUPseudo : I->UnmaskedPseudo;
 
   // Check that we're dropping the mask operand and any policy operand
   // when we transform to this unmasked pseudo. Additionally, if this
   // instruction is tail agnostic, the unmasked instruction should not have a
   // merge op.
   uint64_t TSFlags = TII.get(Opc).TSFlags;
-  assert((IsTA != RISCVII::hasMergeOp(TSFlags)) &&
+  assert((UseTUPseudo == RISCVII::hasMergeOp(TSFlags)) &&
          RISCVII::hasDummyMaskOp(TSFlags) &&
-         !RISCVII::hasVecPolicyOp(TSFlags) &&
          "Unexpected pseudo to transform to");
   (void)TSFlags;
 
   SmallVector<SDValue, 8> Ops;
-  // Skip the merge operand at index 0 if IsTA
-  for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) {
+  // Skip the merge operand at index 0 if !UseTUPseudo.
+  for (unsigned I = !UseTUPseudo, E = N->getNumOperands(); I != E; I++) {
     // Skip the mask, the policy, and the Glue.
     SDValue Op = N->getOperand(I);
     if (I == MaskOpIdx || I == TailPolicyOpIdx ||

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index da8be7342cfb..eb7dc8bfa18a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -472,10 +472,12 @@ def RISCVVIntrinsicsTable : GenericTable {
   let PrimaryKeyName = "getRISCVVIntrinsicInfo";
 }
 
-class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true> {
+class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true, bit IsTernary = false> {
   Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
   Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
-  Pseudo UnmaskedTUPseudo = !if(HasTU, !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo);
+  Pseudo UnmaskedTUPseudo = !cond(HasTU : !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")),
+                                  IsTernary : UnmaskedPseudo,
+                                  true  :  MaskedPseudo);
   bits<4> MaskOpIdx = MaskIdx;
 }
 
@@ -3192,7 +3194,8 @@ multiclass VPseudoTernaryWithPolicy<VReg RetClass,
   let VLMul = MInfo.value in {
     let isCommutable = Commutable in
     def "_" # MInfo.MX : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+    def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
+                                   RISCVMaskedPseudo</*MaskOpIdx*/ 3, /*HasTU*/ false, /*IsTernary*/true>;
   }
 }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 56478c078ac8..fc59d1f049ed 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1459,12 +1459,6 @@ multiclass VPatNarrowShiftSplat_WX_WI<SDNode op, string instruction_name> {
 multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
-  def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
-                             vti.RegClass:$rs2, (vti.Mask true_mask),
-                             VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
                              vti.RegClass:$rs2, (vti.Mask V0),
                              VLOpFrag)),
@@ -1472,13 +1466,6 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
                  vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
-  def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
-                             vti.RegClass:$rd, vti.RegClass:$rs2,
-                             (vti.Mask true_mask),
-                             VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
                              vti.RegClass:$rd, vti.RegClass:$rs2,
                              (vti.Mask V0),
@@ -1492,13 +1479,6 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
 multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
-  def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
-                         (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
-                          vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                          vti.RegClass:$rd, VLOpFrag),
-            (!cast<Instruction>(instruction_name#"_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
   def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
                          (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
                           vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
@@ -1506,13 +1486,6 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
             (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                  vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
-  def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
-                         (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
-                          vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                          vti.RegClass:$rd, VLOpFrag),
-            (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
   def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
                          (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
                           vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),


        


More information about the llvm-commits mailing list