[llvm] d34a2c2 - [RISCV] Make more vector pseudos commutable

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 00:55:18 PDT 2024


Author: Pengcheng Wang
Date: 2024-04-16T15:55:14+08:00
New Revision: d34a2c2adb2a4f1dc262c5756d3725caa4ea2571

URL: https://github.com/llvm/llvm-project/commit/d34a2c2adb2a4f1dc262c5756d3725caa4ea2571
DIFF: https://github.com/llvm/llvm-project/commit/d34a2c2adb2a4f1dc262c5756d3725caa4ea2571.diff

LOG: [RISCV] Make more vector pseudos commutable

This PR includes:
* vadd.vv/vand.vv/vor.vv/vxor.vv
* vmseq.vv/vmsne.vv
* vmin.vv/vminu.vv/vmax.vv/vmaxu.vv
* vmul.vv/vmulh.vv/vmulhu.vv
* vwadd.vv/vwaddu.vv
* vwmul.vv/vwmulu.vv
* vwmacc.vv/vwmaccu.vv
* vadc.vvm

There is no test change; I may add tests later.

Fixes part of #64422

Reviewers: michaelmaitland, preames, lukel97, topperc, asb

Reviewed By: topperc, lukel97

Pull Request: https://github.com/llvm/llvm-project/pull/88379

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index b0fda040519a57..668062c8d33f6f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2718,6 +2718,50 @@ std::string RISCVInstrInfo::createMIROperandComment(
   return Comment;
 }
 
+// clang-format off
+#define CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL)                                 \
+  RISCV::Pseudo##OP##_##LMUL
+
+#define CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)                                   \
+  RISCV::Pseudo##OP##_##LMUL##_MASK
+
+#define CASE_RVV_OPCODE_LMUL(OP, LMUL)                                        \
+  CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL):                                      \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)
+
+#define CASE_RVV_OPCODE_UNMASK_WIDEN(OP)                                      \
+  CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF8):                                       \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF4):                                  \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF2):                                  \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M1):                                   \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M2):                                   \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M4)
+
+#define CASE_RVV_OPCODE_UNMASK(OP)                                            \
+  CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
+  case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M8)
+
+#define CASE_RVV_OPCODE_MASK_WIDEN(OP)                                        \
+  CASE_RVV_OPCODE_MASK_LMUL(OP, MF8):                                         \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, MF4):                                    \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, MF2):                                    \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M1):                                     \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M2):                                     \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M4)
+
+#define CASE_RVV_OPCODE_MASK(OP)                                              \
+  CASE_RVV_OPCODE_MASK_WIDEN(OP):                                             \
+  case CASE_RVV_OPCODE_MASK_LMUL(OP, M8)
+
+#define CASE_RVV_OPCODE_WIDEN(OP)                                             \
+  CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
+  case CASE_RVV_OPCODE_MASK_WIDEN(OP)
+
+#define CASE_RVV_OPCODE(OP)                                                   \
+  CASE_RVV_OPCODE_UNMASK(OP):                                                 \
+  case CASE_RVV_OPCODE_MASK(OP)
+// clang-format on
+
 // clang-format off
 #define CASE_VMA_OPCODE_COMMON(OP, TYPE, LMUL)                                 \
   RISCV::PseudoV##OP##_##TYPE##_##LMUL
@@ -2798,6 +2842,28 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
   case RISCV::PseudoCCMOVGPR:
     // Operands 4 and 5 are commutable.
     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 4, 5);
+  case CASE_RVV_OPCODE(VADD_VV):
+  case CASE_RVV_OPCODE(VAND_VV):
+  case CASE_RVV_OPCODE(VOR_VV):
+  case CASE_RVV_OPCODE(VXOR_VV):
+  case CASE_RVV_OPCODE_MASK(VMSEQ_VV):
+  case CASE_RVV_OPCODE_MASK(VMSNE_VV):
+  case CASE_RVV_OPCODE(VMIN_VV):
+  case CASE_RVV_OPCODE(VMINU_VV):
+  case CASE_RVV_OPCODE(VMAX_VV):
+  case CASE_RVV_OPCODE(VMAXU_VV):
+  case CASE_RVV_OPCODE(VMUL_VV):
+  case CASE_RVV_OPCODE(VMULH_VV):
+  case CASE_RVV_OPCODE(VMULHU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWADD_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWADDU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMUL_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMULU_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMACC_VV):
+  case CASE_RVV_OPCODE_WIDEN(VWMACCU_VV):
+  case CASE_RVV_OPCODE_UNMASK(VADC_VVM):
+    // Operands 2 and 3 are commutable.
+    return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
   case CASE_VFMA_SPLATS(FMADD):
   case CASE_VFMA_SPLATS(FMSUB):
   case CASE_VFMA_SPLATS(FMACC):

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index ad1821d57256bc..435cd7f84c6122 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -2127,8 +2127,9 @@ multiclass VPseudoBinary<VReg RetClass,
                          LMULInfo MInfo,
                          string Constraint = "",
                          int sew = 0,
-                         int TargetConstraintType = 1> {
-  let VLMul = MInfo.value, SEW=sew in {
+                         int TargetConstraintType = 1,
+                         bit Commutable = 0> {
+  let VLMul = MInfo.value, SEW=sew, isCommutable = Commutable in {
     defvar suffix = !if(sew, "_" # MInfo.MX # "_E" # sew, "_" # MInfo.MX);
     def suffix : VPseudoBinaryNoMaskTU<RetClass, Op1Class, Op2Class,
                                        Constraint, TargetConstraintType>;
@@ -2167,8 +2168,9 @@ multiclass VPseudoBinaryM<VReg RetClass,
                           DAGOperand Op2Class,
                           LMULInfo MInfo,
                           string Constraint = "",
-                          int TargetConstraintType = 1> {
-  let VLMul = MInfo.value in {
+                          int TargetConstraintType = 1,
+                          bit Commutable = 0> {
+  let VLMul = MInfo.value, isCommutable = Commutable in {
     def "_" # MInfo.MX : VPseudoBinaryMOutNoMask<RetClass, Op1Class, Op2Class,
                                                  Constraint, TargetConstraintType>;
     let ForceTailAgnostic = true in
@@ -2226,8 +2228,8 @@ multiclass VPseudoTiedBinaryRoundingMode<VReg RetClass,
 }
 
 
-multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0> {
-  defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew>;
+multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0, bit Commutable = 0> {
+  defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew, Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = ""> {
@@ -2331,9 +2333,10 @@ multiclass VPseudoVALU_MM<bit Commutable = 0> {
 // * The destination EEW is greater than the source EEW, the source EMUL is
 //   at least 1, and the overlap is in the highest-numbered part of the
 //   destination register group is legal. Otherwise, it is illegal.
-multiclass VPseudoBinaryW_VV<LMULInfo m> {
+multiclass VPseudoBinaryW_VV<LMULInfo m, bit Commutable = 0> {
   defm _VV : VPseudoBinary<m.wvrclass, m.vrclass, m.vrclass, m,
-                           "@earlyclobber $rd", TargetConstraintType=3>;
+                           "@earlyclobber $rd", TargetConstraintType=3,
+                           Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryW_VV_RM<LMULInfo m, int sew = 0> {
@@ -2453,7 +2456,9 @@ multiclass VPseudoBinaryV_VM<LMULInfo m, bit CarryOut = 0, bit CarryIn = 1,
                          m.vrclass, m.vrclass, m, CarryIn, Constraint, TargetConstraintType>;
 }
 
-multiclass VPseudoTiedBinaryV_VM<LMULInfo m, int TargetConstraintType = 1> {
+multiclass VPseudoTiedBinaryV_VM<LMULInfo m, int TargetConstraintType = 1,
+                                 bit Commutable = 0> {
+  let isCommutable = Commutable in
   def "_VVM" # "_" # m.MX:
     VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
                              m.vrclass, m.vrclass, m, 1, "",
@@ -2667,9 +2672,11 @@ multiclass PseudoVEXT_VF8 {
 //  lowest-numbered part of the source register group".
 // With LMUL<=1 the source and dest occupy a single register so any overlap
 // is in the lowest-numbered part.
-multiclass VPseudoBinaryM_VV<LMULInfo m, int TargetConstraintType = 1> {
+multiclass VPseudoBinaryM_VV<LMULInfo m, int TargetConstraintType = 1,
+                             bit Commutable = 0> {
   defm _VV : VPseudoBinaryM<VR, m.vrclass, m.vrclass, m,
-                            !if(!ge(m.octuple, 16), "@earlyclobber $rd", ""), TargetConstraintType>;
+                            !if(!ge(m.octuple, 16), "@earlyclobber $rd", ""),
+                            TargetConstraintType, Commutable=Commutable>;
 }
 
 multiclass VPseudoBinaryM_VX<LMULInfo m, int TargetConstraintType = 1> {
@@ -2751,10 +2758,11 @@ multiclass VPseudoVSSHT_VV_VX_VI_RM<Operand ImmType = simm5, string Constraint =
   }
 }
 
-multiclass VPseudoVALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
+multiclass VPseudoVALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = "",
+                                bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m, Constraint>,
+    defm "" : VPseudoBinaryV_VV<m, Constraint, Commutable=Commutable>,
             SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", mx,
                         forceMergeOpRead=true>;
     defm "" : VPseudoBinaryV_VX<m, Constraint>,
@@ -2804,17 +2812,17 @@ multiclass VPseudoVAALU_VV_VX_RM {
 multiclass VPseudoVMINMAX_VV_VX {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m>,
+    defm "" : VPseudoBinaryV_VV<m, Commutable=1>,
               SchedBinary<"WriteVIMinMaxV", "ReadVIMinMaxV", "ReadVIMinMaxV", mx>;
     defm "" : VPseudoBinaryV_VX<m>,
               SchedBinary<"WriteVIMinMaxX", "ReadVIMinMaxV", "ReadVIMinMaxX", mx>;
   }
 }
 
-multiclass VPseudoVMUL_VV_VX {
+multiclass VPseudoVMUL_VV_VX<bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryV_VV<m>,
+    defm "" : VPseudoBinaryV_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIMulV", "ReadVIMulV", "ReadVIMulV", mx>;
     defm "" : VPseudoBinaryV_VX<m>,
               SchedBinary<"WriteVIMulX", "ReadVIMulV", "ReadVIMulX", mx>;
@@ -2964,10 +2972,10 @@ multiclass VPseudoVALU_VX_VI<Operand ImmType = simm5> {
   }
 }
 
-multiclass VPseudoVWALU_VV_VX {
+multiclass VPseudoVWALU_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryW_VV<m>,
+    defm "" : VPseudoBinaryW_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIWALUV", "ReadVIWALUV", "ReadVIWALUV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoBinaryW_VX<m>, 
@@ -2976,10 +2984,10 @@ multiclass VPseudoVWALU_VV_VX {
   }
 }
 
-multiclass VPseudoVWMUL_VV_VX {
+multiclass VPseudoVWMUL_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryW_VV<m>,
+    defm "" : VPseudoBinaryW_VV<m, Commutable=Commutable>,
               SchedBinary<"WriteVIWMulV", "ReadVIWMulV", "ReadVIWMulV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoBinaryW_VX<m>,
@@ -3074,7 +3082,7 @@ multiclass VPseudoVMRG_VM_XM_IM {
 multiclass VPseudoVCALU_VM_XM_IM {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoTiedBinaryV_VM<m>,
+    defm "" : VPseudoTiedBinaryV_VM<m, Commutable=1>,
               SchedBinary<"WriteVICALUV", "ReadVICALUV", "ReadVICALUV", mx,
                           forceMergeOpRead=true>;
     defm "" : VPseudoTiedBinaryV_XM<m>,
@@ -3287,10 +3295,10 @@ multiclass VPseudoTernaryV_VF_AAXA_RM<LMULInfo m, FPR_Info f,
                                                           sew, Commutable=1>;
 }
 
-multiclass VPseudoTernaryW_VV<LMULInfo m> {
+multiclass VPseudoTernaryW_VV<LMULInfo m, bit Commutable = 0> {
   defvar constraint = "@earlyclobber $rd";
   defm _VV : VPseudoTernaryWithPolicy<m.wvrclass, m.vrclass, m.vrclass, m,
-                                      constraint, /*Commutable*/ 0, TargetConstraintType=3>;
+                                      constraint, Commutable=Commutable, TargetConstraintType=3>;
 }
 
 multiclass VPseudoTernaryW_VV_RM<LMULInfo m, int sew = 0> {
@@ -3380,10 +3388,10 @@ multiclass VPseudoVSLD_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
   }
 }
 
-multiclass VPseudoVWMAC_VV_VX {
+multiclass VPseudoVWMAC_VV_VX<bit Commutable = 0> {
   foreach m = MxListW in {
     defvar mx = m.MX;
-    defm "" : VPseudoTernaryW_VV<m>,
+    defm "" : VPseudoTernaryW_VV<m, Commutable=Commutable>,
               SchedTernary<"WriteVIWMulAddV", "ReadVIWMulAddV", "ReadVIWMulAddV",
                            "ReadVIWMulAddV", mx>;
     defm "" : VPseudoTernaryW_VX<m>,
@@ -3436,10 +3444,10 @@ multiclass VPseudoVWMAC_VV_VF_BF_RM {
   }
 }
 
-multiclass VPseudoVCMPM_VV_VX_VI {
+multiclass VPseudoVCMPM_VV_VX_VI<bit Commutable = 0> {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defm "" : VPseudoBinaryM_VV<m, TargetConstraintType=2>,
+    defm "" : VPseudoBinaryM_VV<m, TargetConstraintType=2, Commutable=Commutable>,
               SchedBinary<"WriteVICmpV", "ReadVICmpV", "ReadVICmpV", mx>;
     defm "" : VPseudoBinaryM_VX<m, TargetConstraintType=2>,
               SchedBinary<"WriteVICmpX", "ReadVICmpV", "ReadVICmpX", mx>;
@@ -6248,7 +6256,7 @@ defm PseudoVLSEG : VPseudoUSSegLoadFF;
 //===----------------------------------------------------------------------===//
 // 11.1. Vector Single-Width Integer Add and Subtract
 //===----------------------------------------------------------------------===//
-defm PseudoVADD   : VPseudoVALU_VV_VX_VI;
+defm PseudoVADD   : VPseudoVALU_VV_VX_VI<Commutable=1>;
 defm PseudoVSUB   : VPseudoVALU_VV_VX;
 defm PseudoVRSUB  : VPseudoVALU_VX_VI;
 
@@ -6313,9 +6321,9 @@ foreach vti = AllIntegerVectors in {
 //===----------------------------------------------------------------------===//
 // 11.2. Vector Widening Integer Add/Subtract
 //===----------------------------------------------------------------------===//
-defm PseudoVWADDU : VPseudoVWALU_VV_VX;
+defm PseudoVWADDU : VPseudoVWALU_VV_VX<Commutable=1>;
 defm PseudoVWSUBU : VPseudoVWALU_VV_VX;
-defm PseudoVWADD  : VPseudoVWALU_VV_VX;
+defm PseudoVWADD  : VPseudoVWALU_VV_VX<Commutable=1>;
 defm PseudoVWSUB  : VPseudoVWALU_VV_VX;
 defm PseudoVWADDU : VPseudoVWALU_WV_WX;
 defm PseudoVWSUBU : VPseudoVWALU_WV_WX;
@@ -6346,9 +6354,9 @@ defm PseudoVMSBC : VPseudoVCALUM_V_X<"@earlyclobber $rd">;
 //===----------------------------------------------------------------------===//
 // 11.5. Vector Bitwise Logical Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVAND : VPseudoVALU_VV_VX_VI;
-defm PseudoVOR  : VPseudoVALU_VV_VX_VI;
-defm PseudoVXOR : VPseudoVALU_VV_VX_VI;
+defm PseudoVAND : VPseudoVALU_VV_VX_VI<Commutable=1>;
+defm PseudoVOR  : VPseudoVALU_VV_VX_VI<Commutable=1>;
+defm PseudoVXOR : VPseudoVALU_VV_VX_VI<Commutable=1>;
 
 //===----------------------------------------------------------------------===//
 // 11.6. Vector Single-Width Bit Shift Instructions
@@ -6366,8 +6374,8 @@ defm PseudoVNSRA : VPseudoVNSHT_WV_WX_WI;
 //===----------------------------------------------------------------------===//
 // 11.8. Vector Integer Comparison Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVMSEQ  : VPseudoVCMPM_VV_VX_VI;
-defm PseudoVMSNE  : VPseudoVCMPM_VV_VX_VI;
+defm PseudoVMSEQ  : VPseudoVCMPM_VV_VX_VI<Commutable=1>;
+defm PseudoVMSNE  : VPseudoVCMPM_VV_VX_VI<Commutable=1>;
 defm PseudoVMSLTU : VPseudoVCMPM_VV_VX;
 defm PseudoVMSLT  : VPseudoVCMPM_VV_VX;
 defm PseudoVMSLEU : VPseudoVCMPM_VV_VX_VI;
@@ -6386,9 +6394,9 @@ defm PseudoVMAX  : VPseudoVMINMAX_VV_VX;
 //===----------------------------------------------------------------------===//
 // 11.10. Vector Single-Width Integer Multiply Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVMUL    : VPseudoVMUL_VV_VX;
-defm PseudoVMULH   : VPseudoVMUL_VV_VX;
-defm PseudoVMULHU  : VPseudoVMUL_VV_VX;
+defm PseudoVMUL    : VPseudoVMUL_VV_VX<Commutable=1>;
+defm PseudoVMULH   : VPseudoVMUL_VV_VX<Commutable=1>;
+defm PseudoVMULHU  : VPseudoVMUL_VV_VX<Commutable=1>;
 defm PseudoVMULHSU : VPseudoVMUL_VV_VX;
 
 //===----------------------------------------------------------------------===//
@@ -6402,8 +6410,8 @@ defm PseudoVREM  : VPseudoVDIV_VV_VX;
 //===----------------------------------------------------------------------===//
 // 11.12. Vector Widening Integer Multiply Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVWMUL   : VPseudoVWMUL_VV_VX;
-defm PseudoVWMULU  : VPseudoVWMUL_VV_VX;
+defm PseudoVWMUL   : VPseudoVWMUL_VV_VX<Commutable=1>;
+defm PseudoVWMULU  : VPseudoVWMUL_VV_VX<Commutable=1>;
 defm PseudoVWMULSU : VPseudoVWMUL_VV_VX;
 
 //===----------------------------------------------------------------------===//
@@ -6417,8 +6425,8 @@ defm PseudoVNMSUB : VPseudoVMAC_VV_VX_AAXA;
 //===----------------------------------------------------------------------===//
 // 11.14. Vector Widening Integer Multiply-Add Instructions
 //===----------------------------------------------------------------------===//
-defm PseudoVWMACCU  : VPseudoVWMAC_VV_VX;
-defm PseudoVWMACC   : VPseudoVWMAC_VV_VX;
+defm PseudoVWMACCU  : VPseudoVWMAC_VV_VX<Commutable=1>;
+defm PseudoVWMACC   : VPseudoVWMAC_VV_VX<Commutable=1>;
 defm PseudoVWMACCSU : VPseudoVWMAC_VV_VX;
 defm PseudoVWMACCUS : VPseudoVWMAC_VX;
 


        


More information about the llvm-commits mailing list