[llvm] [RISCV] Make more vector pseudos commutable (PR #88379)

Pengcheng Wang via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 00:47:43 PDT 2024


https://github.com/wangpc-pp updated https://github.com/llvm/llvm-project/pull/88379

>From b17f878f2741acfd6a02537ab374650b53312e50 Mon Sep 17 00:00:00 2001
From: Wang Pengcheng <wangpengcheng.pp at bytedance.com>
Date: Thu, 11 Apr 2024 18:52:00 +0800
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.6-beta.1
---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 66 +++++++++++++
 .../Target/RISCV/RISCVInstrInfoVPseudos.td    | 92 ++++++++++---------
 2 files changed, 116 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index d78f5bd9dedf3d..0b326914f7fa27 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2711,6 +2711,50 @@ std::string RISCVInstrInfo::createMIROperandComment(
   return Comment;
 }
 
+// clang-format off
+#define CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL)                                 \
+  RISCV::Pseudo##OP##_##LMUL
+
+#define CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)                                   \
+  RISCV::Pseudo##OP##_##LMUL##_MASK
+// These expand to "X: case Y: ..." so a use reads: case CASE_RVV_OPCODE(OP):
+#define CASE_RVV_OPCODE_LMUL(OP, LMUL)                                        \
+  CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL):                                      \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)
+// *_WIDEN variants list MF8..M4 only; the non-WIDEN variants also add M8.
+#define CASE_RVV_OPCODE_UNMASK_WIDEN(OP)                                      \
+  CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF8):                                       \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF4):                                  \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF2):                                  \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M1):                                   \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M2):                                   \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M4)
+
+#define CASE_RVV_OPCODE_UNMASK(OP)                                            \
+  CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M8)
+
+#define CASE_RVV_OPCODE_MASK_WIDEN(OP)                                        \
+  CASE_RVV_OPCODE_MASK_LMUL(OP, MF8):                                         \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, MF4):                                    \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, MF2):                                    \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M1):                                     \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M2):                                     \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M4)
+
+#define CASE_RVV_OPCODE_MASK(OP)                                              \
+  CASE_RVV_OPCODE_MASK_WIDEN(OP):                                             \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M8)
+
+#define CASE_RVV_OPCODE_WIDEN(OP)                                             \
+  CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
+  case CASE_RVV_OPCODE_MASK_WIDEN(OP)
+
+#define CASE_RVV_OPCODE(OP)                                                   \
+  CASE_RVV_OPCODE_UNMASK(OP):                                                 \
+  case CASE_RVV_OPCODE_MASK(OP)
+// clang-format on
+
 // clang-format off
 #define CASE_VFMA_OPCODE_COMMON(OP, TYPE, LMUL)                                \
   RISCV::PseudoV##OP##_##TYPE##_##LMUL
@@ -2768,6 +2812,28 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
   case RISCV::PseudoCCMOVGPR:
     // Operands 4 and 5 are commutable.
     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 4, 5);
+  case CASE_RVV_OPCODE(VADD_VV):
+  case CASE_RVV_OPCODE(VAND_VV):
+  case CASE_RVV_OPCODE(VOR_VV):
+  case CASE_RVV_OPCODE(VXOR_VV):
+  case CASE_RVV_OPCODE_MASK(VMSEQ_VV):
+  case CASE_RVV_OPCODE_MASK(VMSNE_VV):
+  case CASE_RVV_OPCODE(VMIN_VV):
+  case CASE_RVV_OPCODE(VMINU_VV):
+  case CASE_RVV_OPCODE(VMAX_VV):
+  case CASE_RVV_OPCODE(VMAXU_VV):
+  case CASE_RVV_OPCODE(VMUL_VV):
+  case CASE_RVV_OPCODE(VMULH_VV):
+  case CASE_RVV_OPCODE(VMULHU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWADD_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWADDU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMUL_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMULU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMACC_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMACCU_VV):
+  case CASE_RVV_OPCODE_UNMASK(VADC_VVM):
+    // Operands 2 and 3 commute. Only _MASK VMSEQ/VMSNE and unmasked VADC.
+    return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
   case CASE_VFMA_SPLATS(FMADD):
   case CASE_VFMA_SPLATS(FMSUB):
   case CASE_VFMA_SPLATS(FMACC):
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index cf9a31c23a06e0..7fb958ec5fcd50 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -2127,8 +2127,9 @@ multiclass VPseudoBinary<VReg RetClass,
                          LMULInfo MInfo,
                          string Constraint = "",
                          int sew = 0,
-                         int TargetConstraintType = 1> {
-  let VLMul = MInfo.value, SEW=sew in {
+                         int TargetConstraintType = 1,
+                         bit Commutable = 0> {
+  let VLMul = MInfo.value, SEW=sew, isCommutable = Commutable in {
     defvar suffix = !if(sew, "_" # MInfo.MX # "_E" # sew, "_" # MInfo.MX);
     def suffix : VPseudoBinaryNoMaskTU<RetClass, Op1Class, Op2Class,
                                        Constraint, TargetConstraintType>;
@@ -2167,8 +2168,9 @@ multiclass VPseudoBinaryM<VReg RetClass,
                           DAGOperand Op2Class,
                           LMULInfo MInfo,
                           string Constraint = "",
-                          int TargetConstraintType = 1> {
-  let VLMul = MInfo.value in {
+                          int TargetConstraintType = 1,
+                          bit Commutable = 0> {
+  let VLMul = MInfo.value, isCommutable = Commutable in {
     def "_" # MInfo.MX : VPseudoBinaryMOutNoMask<RetClass, Op1Class, Op2Class,
                                                  Constraint, TargetConstraintType>;
     let ForceTailAgnostic = true in
@@ -2224,8 +2226,8 @@ multiclass VPseudoTiedBinaryRoundingMode<VReg RetClass,
 }
 
 
-multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0> {
-  defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew>;
+multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0, bit Commutable = 0> {
+  defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew, Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = ""> {
@@ -2329,9 +2331,10 @@ multiclass VPseudoVALU_MM<bit Commutable = 0> {
 // * The destination EEW is greater than the source EEW, the source EMUL is
 //   at least 1, and the overlap is in the highest-numbered part of the
 //   destination register group is legal. Otherwise, it is illegal.
-multiclass VPseudoBinaryW_VV<LMULInfo m> {
+multiclass VPseudoBinaryW_VV<LMULInfo m, bit Commutable = 0> {
   defm _VV : VPseudoBinary<m.wvrclass, m.vrclass, m.vrclass, m,
-                           "@earlyclobber $rd", TargetConstraintType=3>;
+                           "@earlyclobber $rd", TargetConstraintType=3,
+                           Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryW_VV_RM<LMULInfo m> {
@@ -2449,7 +2452,9 @@ multiclass VPseudoBinaryV_VM<LMULInfo m, bit CarryOut = 0, bit CarryIn = 1,
                          m.vrclass, m.vrclass, m, CarryIn, Constraint, TargetConstraintType>;
 }
 
-multiclass VPseudoTiedBinaryV_VM<LMULInfo m, int TargetConstraintType = 1> {
+multiclass VPseudoTiedBinaryV_VM<LMULInfo m, int TargetConstraintType = 1,
+                                 bit Commutable = 0> {
+  let isCommutable = Commutable in // Scoped to the _VVM def below only.
   def "_VVM" # "_" # m.MX:
     VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
                              m.vrclass, m.vrclass, m, 1, "",
@@ -2659,9 +2664,11 @@ multiclass PseudoVEXT_VF8 {
 //  lowest-numbered part of the source register group".
 // With LMUL<=1 the source and dest occupy a single register so any overlap
 // is in the lowest-numbered part.
-multiclass VPseudoBinaryM_VV<LMULInfo m, int TargetConstraintType = 1> {
+multiclass VPseudoBinaryM_VV<LMULInfo m, int TargetConstraintType = 1,
+                             bit Commutable = 0> {
   defm _VV : VPseudoBinaryM<VR, m.vrclass, m.vrclass, m,
-                            !if(!ge(m.octuple, 16), "@earlyclobber $rd", ""), TargetConstraintType>;
+                            !if(!ge(m.octuple, 16), "@earlyclobber $rd", ""),
+                            TargetConstraintType, Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryM_VX<LMULInfo m, int TargetConstraintType = 1> {
@@ -2743,10 +2750,11 @@ multiclass VPseudoVSSHT_VV_VX_VI_RM<Operand ImmType = simm5, string Constraint =
   }
 }
 
-multiclass VPseudoVALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
+multiclass VPseudoVALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = "",
+                                bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m, Constraint>,
+    defm "" : VPseudoBinaryV_VV<m, Constraint, Commutable=Commutable>,
             SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", mx,
                         forceMergeOpRead=true>;
     defm "" : VPseudoBinaryV_VX<m, Constraint>,
@@ -2796,17 +2804,17 @@ multiclass VPseudoVAALU_VV_VX_RM {
 multiclass VPseudoVMINMAX_VV_VX {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m>,
+    defm "" : VPseudoBinaryV_VV<m, Commutable=1>, // Hard-coded: min/max commute.
               SchedBinary<"WriteVIMinMaxV", "ReadVIMinMaxV", "ReadVIMinMaxV", mx>;
     defm "" : VPseudoBinaryV_VX<m>,
               SchedBinary<"WriteVIMinMaxX", "ReadVIMinMaxV", "ReadVIMinMaxX", mx>;
   }
 }
 
-multiclass VPseudoVMUL_VV_VX {
+multiclass VPseudoVMUL_VV_VX<bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m>,
+    defm "" : VPseudoBinaryV_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIMulV", "ReadVIMulV", "ReadVIMulV", mx>;
     defm "" : VPseudoBinaryV_VX<m>,
               SchedBinary<"WriteVIMulX", "ReadVIMulV", "ReadVIMulX", mx>;
@@ -2952,10 +2960,10 @@ multiclass VPseudoVALU_VX_VI<Operand ImmType = simm5> {
   }
 }
 
-multiclass VPseudoVWALU_VV_VX {
+multiclass VPseudoVWALU_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryW_VV<m>,
+    defm "" : VPseudoBinaryW_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIWALUV", "ReadVIWALUV", "ReadVIWALUV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoBinaryW_VX<m>, 
@@ -2964,10 +2972,10 @@ multiclass VPseudoVWALU_VV_VX {
   }
 }
 
-multiclass VPseudoVWMUL_VV_VX {
+multiclass VPseudoVWMUL_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryW_VV<m>,
+    defm "" : VPseudoBinaryW_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIWMulV", "ReadVIWMulV", "ReadVIWMulV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoBinaryW_VX<m>,
@@ -3059,7 +3067,7 @@ multiclass VPseudoVMRG_VM_XM_IM {
 multiclass VPseudoVCALU_VM_XM_IM {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoTiedBinaryV_VM<m>,
+    defm "" : VPseudoTiedBinaryV_VM<m, Commutable=1>, // NOTE(review): hard-coded; confirm every user commutes.
               SchedBinary<"WriteVICALUV", "ReadVICALUV", "ReadVICALUV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoTiedBinaryV_XM<m>,
@@ -3269,10 +3277,10 @@ multiclass VPseudoTernaryV_VF_AAXA_RM<LMULInfo m, FPR_Info f, string Constraint
                                                           Commutable=1>;
 }
 
-multiclass VPseudoTernaryW_VV<LMULInfo m> {
+multiclass VPseudoTernaryW_VV<LMULInfo m, bit Commutable = 0> {
   defvar constraint = "@earlyclobber $rd";
   defm _VV : VPseudoTernaryWithPolicy<m.wvrclass, m.vrclass, m.vrclass, m,
-                                      constraint, /*Commutable*/ 0, TargetConstraintType=3>;
+                                      constraint, Commutable=Commutable, TargetConstraintType=3>;
 }
 
 multiclass VPseudoTernaryW_VV_RM<LMULInfo m> {
@@ -3361,10 +3369,10 @@ multiclass VPseudoVSLD_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
   }
 }
 
-multiclass VPseudoVWMAC_VV_VX {
+multiclass VPseudoVWMAC_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoTernaryW_VV<m>,
+    defm "" : VPseudoTernaryW_VV<m, Commutable=Commutable>,
               SchedTernary<"WriteVIWMulAddV", "ReadVIWMulAddV", "ReadVIWMulAddV",
                            "ReadVIWMulAddV", mx>;
     defm "" : VPseudoTernaryW_VX<m>,
@@ -3415,10 +3423,10 @@ multiclass VPseudoVWMAC_VV_VF_BF_RM {
   }
 }
 
-multiclass VPseudoVCMPM_VV_VX_VI {
+multiclass VPseudoVCMPM_VV_VX_VI<bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryM_VV<m, TargetConstraintType=2>,
+    defm "" : VPseudoBinaryM_VV<m, TargetConstraintType=2, Commutable=Commutable>,
               SchedBinary<"WriteVICmpV", "ReadVICmpV", "ReadVICmpV", mx>;
     defm "" : VPseudoBinaryM_VX<m, TargetConstraintType=2>,
               SchedBinary<"WriteVICmpX", "ReadVICmpV", "ReadVICmpX", mx>;
@@ -6159,7 +6167,7 @@ defm PseudoVLSEG : VPseudoUSSegLoadFF;
 //===----------------------------------------------------------------------===//
 // 11.1. Vector Single-Width Integer Add and Subtract
 //===----------------------------------------------------------------------===//
-defm PseudoVADD   : VPseudoVALU_VV_VX_VI;
+defm PseudoVADD   : VPseudoVALU_VV_VX_VI<Commutable=1>;
 defm PseudoVSUB   : VPseudoVALU_VV_VX;
 defm PseudoVRSUB  : VPseudoVALU_VX_VI;
 
@@ -6224,9 +6232,9 @@ foreach vti = AllIntegerVectors in {
 //===----------------------------------------------------------------------===//
 // 11.2. Vector Widening Integer Add/Subtract
 //===----------------------------------------------------------------------===//
-defm PseudoVWADDU : VPseudoVWALU_VV_VX;
+defm PseudoVWADDU : VPseudoVWALU_VV_VX<Commutable=1>;
 defm PseudoVWSUBU : VPseudoVWALU_VV_VX;
-defm PseudoVWADD  : VPseudoVWALU_VV_VX;
+defm PseudoVWADD  : VPseudoVWALU_VV_VX<Commutable=1>;
 defm PseudoVWSUB  : VPseudoVWALU_VV_VX;
 defm PseudoVWADDU : VPseudoVWALU_WV_WX;
 defm PseudoVWSUBU : VPseudoVWALU_WV_WX;
@@ -6257,9 +6265,9 @@ defm PseudoVMSBC : VPseudoVCALUM_V_X<"@earlyclobber $rd">;
 //===----------------------------------------------------------------------===//
 // 11.5. Vector Bitwise Logical Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVAND : VPseudoVALU_VV_VX_VI;
-defm PseudoVOR  : VPseudoVALU_VV_VX_VI;
-defm PseudoVXOR : VPseudoVALU_VV_VX_VI;
+defm PseudoVAND : VPseudoVALU_VV_VX_VI<Commutable=1>;
+defm PseudoVOR  : VPseudoVALU_VV_VX_VI<Commutable=1>;
+defm PseudoVXOR : VPseudoVALU_VV_VX_VI<Commutable=1>;
 
 //===----------------------------------------------------------------------===//
 // 11.6. Vector Single-Width Bit Shift Instructions
@@ -6277,8 +6285,8 @@ defm PseudoVNSRA : VPseudoVNSHT_WV_WX_WI;
 //===----------------------------------------------------------------------===//
 // 11.8. Vector Integer Comparison Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVMSEQ  : VPseudoVCMPM_VV_VX_VI;
-defm PseudoVMSNE  : VPseudoVCMPM_VV_VX_VI;
+defm PseudoVMSEQ  : VPseudoVCMPM_VV_VX_VI<Commutable=1>;
+defm PseudoVMSNE  : VPseudoVCMPM_VV_VX_VI<Commutable=1>;
 defm PseudoVMSLTU : VPseudoVCMPM_VV_VX;
 defm PseudoVMSLT  : VPseudoVCMPM_VV_VX;
 defm PseudoVMSLEU : VPseudoVCMPM_VV_VX_VI;
@@ -6297,9 +6305,9 @@ defm PseudoVMAX  : VPseudoVMINMAX_VV_VX;
 //===----------------------------------------------------------------------===//
 // 11.10. Vector Single-Width Integer Multiply Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVMUL    : VPseudoVMUL_VV_VX;
-defm PseudoVMULH   : VPseudoVMUL_VV_VX;
-defm PseudoVMULHU  : VPseudoVMUL_VV_VX;
+defm PseudoVMUL    : VPseudoVMUL_VV_VX<Commutable=1>;
+defm PseudoVMULH   : VPseudoVMUL_VV_VX<Commutable=1>;
+defm PseudoVMULHU  : VPseudoVMUL_VV_VX<Commutable=1>; // Not VMULHSU: mixed signedness.
 defm PseudoVMULHSU : VPseudoVMUL_VV_VX;
 
 //===----------------------------------------------------------------------===//
@@ -6313,8 +6321,8 @@ defm PseudoVREM  : VPseudoVDIV_VV_VX;
 //===----------------------------------------------------------------------===//
 // 11.12. Vector Widening Integer Multiply Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVWMUL   : VPseudoVWMUL_VV_VX;
-defm PseudoVWMULU  : VPseudoVWMUL_VV_VX;
+defm PseudoVWMUL   : VPseudoVWMUL_VV_VX<Commutable=1>;
+defm PseudoVWMULU  : VPseudoVWMUL_VV_VX<Commutable=1>; // Not VWMULSU: mixed signedness.
 defm PseudoVWMULSU : VPseudoVWMUL_VV_VX;
 
 //===----------------------------------------------------------------------===//
@@ -6328,8 +6336,8 @@ defm PseudoVNMSUB : VPseudoVMAC_VV_VX_AAXA;
 //===----------------------------------------------------------------------===//
 // 11.14. Vector Widening Integer Multiply-Add Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVWMACCU  : VPseudoVWMAC_VV_VX;
-defm PseudoVWMACC   : VPseudoVWMAC_VV_VX;
+defm PseudoVWMACCU  : VPseudoVWMAC_VV_VX<Commutable=1>;
+defm PseudoVWMACC   : VPseudoVWMAC_VV_VX<Commutable=1>; // Not VWMACCSU: mixed signedness.
 defm PseudoVWMACCSU : VPseudoVWMAC_VV_VX;
 defm PseudoVWMACCUS : VPseudoVWMAC_VX;
 



More information about the llvm-commits mailing list