[llvm] 3855757 - [RISCV][ISel] Remove redundant vmerge for the vwadd. (#78403)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 27 03:03:36 PST 2024


Author: Chia
Date: 2024-01-27T20:03:32+09:00
New Revision: 3855757f98935208d7d0481af01a474139de9519

URL: https://github.com/llvm/llvm-project/commit/3855757f98935208d7d0481af01a474139de9519
DIFF: https://github.com/llvm/llvm-project/commit/3855757f98935208d7d0481af01a474139de9519.diff

LOG: [RISCV][ISel] Remove redundant vmerge for the vwadd. (#78403)

This patch resolves the missed-optimization case shown below.

### Code
```
define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
    %sa = sext <8 x i32> %a to <8 x i64>
    %ret = add <8 x i64> %sa, %y
    ret <8 x i64> %ret
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/cd1bKTrx6)
```
vwadd_mask_v8i32:
        li      a0, 42
        vsetivli        zero, 8, e32, m2, ta, ma
        vmslt.vx        v0, v8, a0
        vmv.v.i v10, 0
        vmerge.vvm      v16, v10, v8, v0
        vwadd.wv        v8, v12, v16
        ret
```

### After this patch
```
vwadd_mask_v8i32:
        li a0, 42
        vsetivli zero, 8, e32, m2, ta, ma
        vmslt.vx v0, v8, a0
        vsetvli zero, zero, e32, m2, tu, mu
        vwadd.wv v12, v12, v8, v0.t
        vmv4r.v v8, v12
        ret
```
This pattern can be found in a reduction with a widening destination.

Specifically, we first fold `(vwadd.wv y, (vmerge cond, x, 0))` into
`(vwadd.wv y, x, y, cond)`, where `y` becomes the passthru and `cond` the mask, and then pattern-match the resulting masked node.

Added: 
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5b1a246a19c8e8f..47617642a5e9aed 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13709,6 +13709,57 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
   return InputRootReplacement;
 }
 
+// Fold (vwadd.wv y, (vmerge cond, x, 0)) -> (vwadd.wv y, x, y, cond),
+// i.e. y becomes the passthru and cond becomes the mask of the new node.
+static SDValue combineVWADDWSelect(SDNode *N, SelectionDAG &DAG) {
+  unsigned Opc = N->getOpcode();
+  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
+
+  SDValue Y = N->getOperand(0);
+  SDValue MergeOp = N->getOperand(1);
+  if (MergeOp.getOpcode() != RISCVISD::VMERGE_VL)
+    return SDValue();
+  SDValue X = MergeOp->getOperand(1);
+
+  if (!MergeOp.hasOneUse())
+    return SDValue();
+
+  // The original passthru must be undef; the fold installs y in its place.
+  SDValue Passthru = N->getOperand(2);
+  if (!Passthru.isUndef())
+    return SDValue();
+
+  // The original mask must be all ones (VMSET_VL); cond replaces it.
+  SDValue Mask = N->getOperand(3);
+  if (Mask.getOpcode() != RISCVISD::VMSET_VL)
+    return SDValue();
+
+  // MergeOp's false value must be a zero vector inserted into a zero/undef base.
+  SDValue Z = MergeOp->getOperand(2);
+  if (Z.getOpcode() != ISD::INSERT_SUBVECTOR)
+    return SDValue();
+  if (!ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
+    return SDValue();
+  if (!isNullOrNullSplat(Z.getOperand(0)) && !Z.getOperand(0).isUndef())
+    return SDValue();
+
+  return DAG.getNode(Opc, SDLoc(N), N->getValueType(0),
+                     {Y, X, Y, MergeOp->getOperand(0), N->getOperand(4)},
+                     N->getFlags());
+}
+
+static SDValue performVWADDW_VLCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const RISCVSubtarget &Subtarget) {
+  unsigned Opc = N->getOpcode();
+  assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
+
+  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    return V;
+
+  return combineVWADDWSelect(N, DCI.DAG);
+}
+
 // Helper function for performMemPairCombine.
 // Try to combine the memory loads/stores LSNode1 and LSNode2
 // into a single memory pair operation.
@@ -15777,9 +15828,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
-  case RISCVISD::SUB_VL:
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
+    return performVWADDW_VLCombine(N, DCI, Subtarget);
+  case RISCVISD::SUB_VL:
   case RISCVISD::VWSUB_W_VL:
   case RISCVISD::VWSUBU_W_VL:
   case RISCVISD::MUL_VL:

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 6e7be2647e8f838..cc44092700c66eb 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,27 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
                      GPR:$vl, sew, TU_MU)>;
 }
 
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+                             string instruction_name,
+                             string suffix,
+                             ValueType result_type,
+                             ValueType op2_type,
+                             ValueType mask_type,
+                             int sew,
+                             LMULInfo vlmul,
+                             VReg result_reg_class,
+                             VReg op2_reg_class> :
+  Pat<(result_type (vop
+                   (result_type result_reg_class:$rs1),
+                   (op2_type op2_reg_class:$rs2),
+                   (result_type result_reg_class:$rs1),
+                   (mask_type V0),
+                   VLOpFrag)),
+      (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
+                   result_reg_class:$rs1,
+                   op2_reg_class:$rs2,
+                   (mask_type V0), GPR:$vl, sew, TU_MU)>;
+
 multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
                                        string instruction_name,
                                        string suffix,
@@ -819,6 +840,10 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
       defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
                                       wti.Vector, vti.Vector, vti.Log2SEW,
                                       vti.LMul, wti.RegClass, vti.RegClass>;
+      def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
+                                   wti.Vector, vti.Vector, wti.Mask,
+                                   vti.Log2SEW, vti.LMul, wti.RegClass,
+                                   vti.RegClass>;
       def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
                            wti.Vector, wti.Vector, vti.Vector, vti.Mask,
                            vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
new file mode 100644
index 000000000000000..d241b78e41391bf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+
+define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+    %sa = sext <8 x i32> %a to <8 x i64>
+    %ret = add <8 x i64> %sa, %y
+    ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwaddu_wv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwaddu.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+    %sa = zext <8 x i32> %a to <8 x i64>
+    %ret = add <8 x i64> %sa, %y
+    ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
+; CHECK-LABEL: vwaddu_vv_mask_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vmv.v.i v12, 0
+; CHECK-NEXT:    vmerge.vvm v8, v12, v8, v0
+; CHECK-NEXT:    vwaddu.vv v12, v8, v10
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+    %sa = zext <8 x i32> %a to <8 x i64>
+    %sy = zext <8 x i32> %y to <8 x i64>
+    %ret = add <8 x i64> %sa, %sy
+    ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT:    vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT:    vmv4r.v v8, v12
+; CHECK-NEXT:    ret
+    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+    %sa = sext <8 x i32> %a to <8 x i64>
+    %ret = add <8 x i64> %y, %sa
+    ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_wv_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_nonzero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 42
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmslt.vx v0, v8, a0
+; CHECK-NEXT:    vmv.v.i v10, 1
+; CHECK-NEXT:    vmerge.vvm v16, v10, v8, v0
+; CHECK-NEXT:    vwadd.wv v8, v12, v16
+; CHECK-NEXT:    ret
+    %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+    %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+    %sa = sext <8 x i32> %a to <8 x i64>
+    %ret = add <8 x i64> %y, %sa
+    ret <8 x i64> %ret
+}


        


More information about the llvm-commits mailing list