[llvm] 38706dd - [RISCV][NFC] Refactor patterns for Multiply Add instructions

Lian Wang via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 14 01:15:33 PDT 2022


Author: Lian Wang
Date: 2022-04-14T08:00:00Z
New Revision: 38706dd9401407fa9de649f6acbb2f2bec756279

URL: https://github.com/llvm/llvm-project/commit/38706dd9401407fa9de649f6acbb2f2bec756279
DIFF: https://github.com/llvm/llvm-project/commit/38706dd9401407fa9de649f6acbb2f2bec756279.diff

LOG: [RISCV][NFC] Refactor patterns for Multiply Add instructions

Reviewed By: craig.topper, frasercrmck

Differential Revision: https://reviews.llvm.org/D123355

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index ed81976732c02..ba2000e5a3128 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -537,6 +537,26 @@ multiclass VPatWidenFPNegMulSacSDNode_VV_VF<string instruction_name> {
   }
 }
 
+multiclass VPatMultiplyAddSDNode_VV_VX<SDNode op, string instruction_name> {
+  foreach vti = AllIntegerVectors in {
+    defvar suffix = vti.LMul.MX;
+    // op($rs2, $rs1 * $rd) -> VMADD-style pseudo. VMADD/VNMSUB are chosen
+    // as they have the most commuting freedom for TwoAddressInstructionPass.
+    def : Pat<(vti.Vector (op vti.RegClass:$rs2,
+                              (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))),
+              (!cast<Instruction>(instruction_name#"_VV_"# suffix)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+    // Same, with a splatted scalar multiplier. Here the VMADD-style choice is
+    // arbitrary; e.g. vmadd.vx and vmacc.vx are equally commutable.
+    def : Pat<(vti.Vector (op vti.RegClass:$rs2,
+                              (mul_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rd))),
+              (!cast<Instruction>(instruction_name#"_VX_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns.
 //===----------------------------------------------------------------------===//
@@ -682,36 +702,8 @@ defm : VPatWidenBinarySDNode_VV_VX<mul, sext_oneuse, anyext_oneuse,
                                    "PseudoVWMULSU">;
 
 // 12.13 Vector Single-Width Integer Multiply-Add Instructions.
-foreach vti = AllIntegerVectors in {
-  // NOTE: We choose VMADD because it has the most commuting freedom. So it
-  // works best with how TwoAddressInstructionPass tries commuting.
-  defvar suffix = vti.LMul.MX;
-  def : Pat<(vti.Vector (add vti.RegClass:$rs2,
-                              (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))),
-            (!cast<Instruction>("PseudoVMADD_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(vti.Vector (sub vti.RegClass:$rs2,
-                              (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))),
-            (!cast<Instruction>("PseudoVNMSUB_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
-
-  // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally
-  // commutable.
-  def : Pat<(vti.Vector (add vti.RegClass:$rs2,
-                              (mul_oneuse (SplatPat XLenVT:$rs1),
-                                          vti.RegClass:$rd))),
-            (!cast<Instruction>("PseudoVMADD_VX_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(vti.Vector (sub vti.RegClass:$rs2,
-                              (mul_oneuse (SplatPat XLenVT:$rs1),
-                                          vti.RegClass:$rd))),
-            (!cast<Instruction>("PseudoVNMSUB_VX_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
-}
+defm : VPatMultiplyAddSDNode_VV_VX<add, "PseudoVMADD">;  // rs2 + rs1*rd
+defm : VPatMultiplyAddSDNode_VV_VX<sub, "PseudoVNMSUB">; // rs2 - rs1*rd
 
 // 12.14 Vector Widening Integer Multiply-Add Instructions
 defm : VPatWidenMulAddSDNode_VV<sext_oneuse, sext_oneuse, "PseudoVWMACC">;

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 521b387e33ac1..f22dfa3f6f085 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -829,6 +829,59 @@ multiclass VPatNarrowShiftSplatExt_WX<SDNode op, PatFrags extop, string instruct
   }
 }
 
+multiclass VPatMultiplyAddVL_VV_VX<SDNode op, string instruction_name> {
+  foreach vti = AllIntegerVectors in {
+    defvar suffix = vti.LMul.MX;
+    // op($rs2, $rs1 * $rd) under true_mask -> VMADD-style pseudo. VMADD and
+    // VNMSUB have the most commuting freedom for TwoAddressInstructionPass.
+    def : Pat<(vti.Vector
+             (op vti.RegClass:$rs2,
+                 (riscv_mul_vl_oneuse vti.RegClass:$rs1,
+                                      vti.RegClass:$rd,
+                                      (vti.Mask true_mask), VLOpFrag),
+                           (vti.Mask true_mask), VLOpFrag)),
+            (!cast<Instruction>(instruction_name#"_VV_"# suffix)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+    // Same, with a splatted scalar multiplier. Here the VMADD-style choice is
+    // arbitrary; e.g. vmadd.vx and vmacc.vx are equally commutable.
+    def : Pat<(vti.Vector
+             (op vti.RegClass:$rs2,
+                 (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1),
+                                       vti.RegClass:$rd,
+                                       (vti.Mask true_mask), VLOpFrag),
+                           (vti.Mask true_mask), VLOpFrag)),
+            (!cast<Instruction>(instruction_name#"_VX_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
+multiclass VPatWidenMultiplyAddVL_VV_VX<PatFrag op1, string instruction_name> {
+  foreach vtiTowti = AllWidenableIntVectors in {
+    defvar vti = vtiTowti.Vti; // type of the $rs1/$rs2 multiply operands
+    defvar wti = vtiTowti.Wti; // type of the $rd accumulator and result
+    def : Pat<(wti.Vector  // .vv form: $rd + op1($rs1, $rs2)
+             (riscv_add_vl wti.RegClass:$rd,
+                           (op1 vti.RegClass:$rs1,
+                                (vti.Vector vti.RegClass:$rs2),
+                                (vti.Mask true_mask), VLOpFrag),
+                           (vti.Mask true_mask), VLOpFrag)),
+            (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
+                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+    def : Pat<(wti.Vector  // .vx form: $rd + op1(splat($rs1), $rs2)
+             (riscv_add_vl wti.RegClass:$rd,
+                           (op1 (SplatPat XLenVT:$rs1),
+                                (vti.Vector vti.RegClass:$rs2),
+                                (vti.Mask true_mask), VLOpFrag),
+                           (vti.Mask true_mask), VLOpFrag)),
+            (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
+                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns.
 //===----------------------------------------------------------------------===//
@@ -1008,111 +1061,16 @@ defm : VPatBinaryWVL_VV_VX<riscv_vwmulu_vl, "PseudoVWMULU">;
 defm : VPatBinaryWVL_VV_VX<riscv_vwmulsu_vl, "PseudoVWMULSU">;
 
 // 12.13 Vector Single-Width Integer Multiply-Add Instructions
-foreach vti = AllIntegerVectors in {
-  // NOTE: We choose VMADD because it has the most commuting freedom. So it
-  // works best with how TwoAddressInstructionPass tries commuting.
-  defvar suffix = vti.LMul.MX;
-  def : Pat<(vti.Vector
-             (riscv_add_vl vti.RegClass:$rs2,
-                           (riscv_mul_vl_oneuse vti.RegClass:$rs1,
-                                                vti.RegClass:$rd,
-                                                (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVMADD_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(vti.Vector
-             (riscv_sub_vl vti.RegClass:$rs2,
-                           (riscv_mul_vl_oneuse vti.RegClass:$rs1,
-                                                vti.RegClass:$rd,
-                                                (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVNMSUB_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-
-  // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally
-  // commutable.
-  def : Pat<(vti.Vector
-             (riscv_add_vl vti.RegClass:$rs2,
-                           (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1),
-                                                vti.RegClass:$rd,
-                                                (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVMADD_VX_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(vti.Vector
-             (riscv_sub_vl vti.RegClass:$rs2,
-                           (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1),
-                                                vti.RegClass:$rd,
-                                                (vti.Mask true_mask),
-                                                VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVNMSUB_VX_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-}
+defm : VPatMultiplyAddVL_VV_VX<riscv_add_vl, "PseudoVMADD">;  // rs2 + rs1*rd
+defm : VPatMultiplyAddVL_VV_VX<riscv_sub_vl, "PseudoVNMSUB">; // rs2 - rs1*rd
 
 // 12.14. Vector Widening Integer Multiply-Add Instructions
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmul_vl_oneuse, "PseudoVWMACC">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulu_vl_oneuse, "PseudoVWMACCU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulsu_vl_oneuse, "PseudoVWMACCSU">;
 foreach vtiTowti = AllWidenableIntVectors in {
   defvar vti = vtiTowti.Vti;
   defvar wti = vtiTowti.Wti;
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmul_vl_oneuse vti.RegClass:$rs1,
-                                                  (vti.Vector vti.RegClass:$rs2),
-                                                  (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACC_VV_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulu_vl_oneuse vti.RegClass:$rs1,
-                                                   (vti.Vector vti.RegClass:$rs2),
-                                                   (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCU_VV_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulsu_vl_oneuse vti.RegClass:$rs1,
-                                                    (vti.Vector vti.RegClass:$rs2),
-                                                    (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCSU_VV_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmul_vl_oneuse (SplatPat XLenVT:$rs1),
-                                                  (vti.Vector vti.RegClass:$rs2),
-                                                  (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACC_VX_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulu_vl_oneuse (SplatPat XLenVT:$rs1),
-                                                   (vti.Vector vti.RegClass:$rs2),
-                                                   (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCU_VX_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulsu_vl_oneuse (SplatPat XLenVT:$rs1),
-                                                    (vti.Vector vti.RegClass:$rs2),
-                                                    (vti.Mask true_mask), VLOpFrag),
-                           (vti.Mask true_mask), VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCSU_VX_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(wti.Vector
              (riscv_add_vl wti.RegClass:$rd,
                            (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),


        


More information about the llvm-commits mailing list