[llvm] 6449bea - [RISCV] Select unmasked RVV pseudos in a DAG post-process

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 9 00:00:56 PST 2022


Author: Fraser Cormack
Date: 2022-02-09T07:50:15Z
New Revision: 6449bea508f1d1e193497697f185953769ad65e2

URL: https://github.com/llvm/llvm-project/commit/6449bea508f1d1e193497697f185953769ad65e2
DIFF: https://github.com/llvm/llvm-project/commit/6449bea508f1d1e193497697f185953769ad65e2.diff

LOG: [RISCV] Select unmasked RVV pseudos in a DAG post-process

This patch drops TableGen patterns matching all-ones masked RVV pseudos
in the case where there are fallback patterns matching the generic
masked forms to "_MASK" pseudos. This optimization is now performed with
a SelectionDAG post-processing step which peephole-optimizes these same
pseudos with all-ones masks and swaps them out to their unmasked
pseudos.

This cuts our generated ISel table down by around 5% (~110kB), in exchange
for a far smaller auto-generated table to help with the peephole.

This only targets our custom RISCVISD::*_VL binary operator nodes, which
use a single form for both masked and unmasked variants. A similar
approach could be used for our intrinsics but we'd need to do some work,
e.g., to represent unmasked intrinsics as true-masked intrinsics at the
IR or ISel level. At a rough estimate, this could save us a further 9%
on the size of our ISel table for the binary intrinsic patterns alone.

There is no observable impact on our tests.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D118810

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index bc0fde3f66632..b0c5fcd53c41e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -37,6 +37,7 @@ namespace RISCV {
 #define GET_RISCVVSETable_IMPL
 #define GET_RISCVVLXTable_IMPL
 #define GET_RISCVVSXTable_IMPL
+#define GET_RISCVMaskedPseudosTable_IMPL
 #include "RISCVGenSearchableTables.inc"
 } // namespace RISCV
 } // namespace llvm
@@ -123,6 +124,7 @@ void RISCVDAGToDAGISel::PostprocessISelDAG() {
 
     MadeChange |= doPeepholeSExtW(N);
     MadeChange |= doPeepholeLoadStoreADDI(N);
+    MadeChange |= doPeepholeMaskedRVV(N);
   }
 
   if (MadeChange)
@@ -2133,6 +2135,102 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   return false;
 }
 
+// Replace a masked RVV pseudo whose mask is known to be all-ones with its
+// "unmasked" pseudo. The mask must be a V0 physical register operand set by a
+// glued CopyToReg of a VMSET pseudo. Returns true if N was replaced, false if
+// N was left untouched.
+bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
+  const RISCV::RISCVMaskedPseudoInfo *I =
+      RISCV::getMaskedPseudoInfo(N->getMachineOpcode());
+  if (!I)
+    return false;
+
+  unsigned MaskOpIdx = I->MaskOpIdx;
+
+  // Check that we're using V0 as a mask register.
+  if (!isa<RegisterSDNode>(N->getOperand(MaskOpIdx)) ||
+      cast<RegisterSDNode>(N->getOperand(MaskOpIdx))->getReg() != RISCV::V0)
+    return false;
+
+  // The node glued to N should be the CopyToReg that defines V0.
+  const auto *Glued = N->getGluedNode();
+
+  if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
+    return false;
+
+  // Check that we're defining V0 as a mask register.
+  if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
+      cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
+    return false;
+
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = Glued->getOperand(2);
+
+  const auto IsVMSet = [](unsigned Opc) {
+    return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
+           Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
+           Opc == RISCV::PseudoVMSET_M_B4 || Opc == RISCV::PseudoVMSET_M_B64 ||
+           Opc == RISCV::PseudoVMSET_M_B8;
+  };
+
+  // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
+  // undefined behaviour if it's the wrong bitwidth, so we could choose to
+  // assume that it's all-ones? Same applies to its VL.
+  if (!MaskSetter->isMachineOpcode() || !IsVMSet(MaskSetter.getMachineOpcode()))
+    return false;
+
+  // Retrieve the tail policy operand index, if any.
+  Optional<unsigned> TailPolicyOpIdx;
+  const RISCVInstrInfo *TII = static_cast<const RISCVInstrInfo *>(
+      CurDAG->getSubtarget().getInstrInfo());
+
+  const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode());
+
+  if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
+    // The policy op is the pseudo's last operand, but on the SDNode it may be
+    // followed by a chain and then a Glue operand; step back over those.
+    TailPolicyOpIdx = N->getNumOperands() - 1;
+    if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue)
+      (*TailPolicyOpIdx)--;
+    if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other)
+      (*TailPolicyOpIdx)--;
+
+    // If the policy isn't TAIL_AGNOSTIC we can't perform this optimization.
+    if (N->getConstantOperandVal(*TailPolicyOpIdx) != RISCVII::TAIL_AGNOSTIC)
+      return false;
+  }
+
+  const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo);
+
+  // Check that we're dropping the merge operand, the mask operand, and any
+  // policy operand when we transform to this unmasked pseudo.
+  assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) &&
+         RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) &&
+         !RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) &&
+         "Unexpected pseudo to transform to");
+
+  SmallVector<SDValue, 8> Ops;
+  // Skip the merge operand at index 0.
+  for (unsigned I = 1, E = N->getNumOperands(); I != E; I++) {
+    // Skip the mask, the policy, and the Glue.
+    SDValue Op = N->getOperand(I);
+    if (I == MaskOpIdx || I == TailPolicyOpIdx ||
+        Op.getValueType() == MVT::Glue)
+      continue;
+    Ops.push_back(Op);
+  }
+
+  // If the V0 copy is itself glued to a node, glue the new node to that one.
+  if (auto *TGlued = Glued->getGluedNode())
+    Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));
+
+  SDNode *Result =
+      CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops);
+  ReplaceUses(N, Result);
+
+  return true;
+}
+
 // This pass converts a legalized DAG into a RISCV-specific DAG, ready
 // for instruction scheduling.
 FunctionPass *llvm::createRISCVISelDag(RISCVTargetMachine &TM) {

diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index c429a9298739f..84acd28b277eb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -117,6 +117,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
 private:
   bool doPeepholeLoadStoreADDI(SDNode *Node);
   bool doPeepholeSExtW(SDNode *Node);
+  bool doPeepholeMaskedRVV(SDNode *Node);
 };
 
 namespace RISCV {
@@ -187,6 +188,12 @@ struct VLX_VSXPseudo {
   uint16_t Pseudo;
 };
 
+struct RISCVMaskedPseudoInfo {
+  uint16_t MaskedPseudo;   // Opcode of the masked ("_MASK") pseudo; lookup key.
+  uint16_t UnmaskedPseudo; // Opcode of the corresponding unmasked pseudo.
+  uint8_t MaskOpIdx;       // Operand index of the mask on the masked pseudo.
+};
+
 #define GET_RISCVVSSEGTable_DECL
 #define GET_RISCVVLSEGTable_DECL
 #define GET_RISCVVLXSEGTable_DECL
@@ -195,6 +202,7 @@ struct VLX_VSXPseudo {
 #define GET_RISCVVSETable_DECL
 #define GET_RISCVVLXTable_DECL
 #define GET_RISCVVSXTable_DECL
+#define GET_RISCVMaskedPseudosTable_DECL
 #include "RISCVGenSearchableTables.inc"
 } // namespace RISCV
 

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 47f11e8510ad6..41fa5087e9303 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -424,6 +424,20 @@ def RISCVVIntrinsicsTable : GenericTable {
   let PrimaryKeyName = "getRISCVVIntrinsicInfo";
 }
 
+class RISCVMaskedPseudo<bits<4> MaskIdx> {
+  Pseudo MaskedPseudo = !cast<Pseudo>(NAME); // The def mixing in this class.
+  Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
+  bits<4> MaskOpIdx = MaskIdx; // Operand index of the mask.
+}
+
+def RISCVMaskedPseudosTable : GenericTable {
+  let FilterClass = "RISCVMaskedPseudo"; // One row per def of this class.
+  let CppTypeName = "RISCVMaskedPseudoInfo";
+  let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"];
+  let PrimaryKey = ["MaskedPseudo"]; // Rows are looked up by masked pseudo.
+  let PrimaryKeyName = "getMaskedPseudoInfo";
+}
+
 class RISCVVLE<bit M, bit TU, bit Str, bit F, bits<3> S, bits<3> L> {
   bits<1> Masked = M;
   bits<1> IsTU = TU;
@@ -1639,7 +1653,8 @@ multiclass VPseudoBinary<VReg RetClass,
     def "_" # MInfo.MX : VPseudoBinaryNoMask<RetClass, Op1Class, Op2Class,
                                              Constraint>;
     def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskTA<RetClass, Op1Class, Op2Class,
-                                                       Constraint>;
+                                                       Constraint>,
+                                   RISCVMaskedPseudo</*MaskOpIdx*/ 3>;
   }
 }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 8f37e7ac246f3..f146242e01de2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -309,15 +309,6 @@ multiclass VPatBinaryVL_V<SDNode vop,
                           LMULInfo vlmul,
                           VReg op1_reg_class,
                           VReg op2_reg_class> {
-  def : Pat<(result_type (vop
-                         (op1_type op1_reg_class:$rs1),
-                         (op2_type op2_reg_class:$rs2),
-                         (mask_type true_mask),
-                         VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX)
-                         op1_reg_class:$rs1,
-                         op2_reg_class:$rs2,
-                         GPR:$vl, sew)>;
   def : Pat<(result_type (vop
                          (op1_type op1_reg_class:$rs1),
                          (op2_type op2_reg_class:$rs2),
@@ -342,15 +333,6 @@ multiclass VPatBinaryVL_XI<SDNode vop,
                            VReg vop_reg_class,
                            ComplexPattern SplatPatKind,
                            DAGOperand xop_kind> {
-  def : Pat<(result_type (vop
-                     (vop1_type vop_reg_class:$rs1),
-                     (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
-                     (mask_type true_mask),
-                     VLOpFrag)),
-        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX)
-                     vop_reg_class:$rs1,
-                     xop_kind:$rs2,
-                     GPR:$vl, sew)>;
   def : Pat<(result_type (vop
                      (vop1_type vop_reg_class:$rs1),
                      (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
@@ -422,14 +404,6 @@ multiclass VPatBinaryVL_VF<SDNode vop,
                            LMULInfo vlmul,
                            VReg vop_reg_class,
                            RegisterClass scalar_reg_class> {
-  def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
-                         (vop_type (SplatFPOp scalar_reg_class:$rs2)),
-                         (mask_type true_mask),
-                         VLOpFrag)),
-        (!cast<Instruction>(instruction_name#"_"#vlmul.MX)
-                     vop_reg_class:$rs1,
-                     scalar_reg_class:$rs2,
-                     GPR:$vl, sew)>;
   def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
                          (vop_type (SplatFPOp scalar_reg_class:$rs2)),
                          (mask_type V0),
@@ -454,13 +428,6 @@ multiclass VPatBinaryFPVL_VV_VF<SDNode vop, string instruction_name> {
 
 multiclass VPatBinaryFPVL_R_VF<SDNode vop, string instruction_name> {
   foreach fvti = AllFloatVectors in {
-    def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
-                                fvti.RegClass:$rs1,
-                                (fvti.Mask true_mask),
-                                VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
-                           fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
-                           GPR:$vl, fvti.Log2SEW)>;
     def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
                                 fvti.RegClass:$rs1,
                                 (fvti.Mask V0),
@@ -747,22 +714,12 @@ defm : VPatBinaryVL_VV_VX<riscv_sub_vl, "PseudoVSUB">;
 // Handle VRSUB specially since it's the only integer binary op with reversed
 // pattern operands
 foreach vti = AllIntegerVectors in {
-  def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
-                          (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
-                          VLOpFrag),
-            (!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX)
-                 vti.RegClass:$rs1, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
                           VLOpFrag),
             (!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX#"_MASK")
                  (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
-                          (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
-                          VLOpFrag),
-            (!cast<Instruction>("PseudoVRSUB_VI_"# vti.LMul.MX)
-                 vti.RegClass:$rs1, simm5:$rs2, GPR:$vl, vti.Log2SEW)>;
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
                           VLOpFrag),


        


More information about the llvm-commits mailing list