[llvm] 3f83a69 - [RISCV] Allow folding vmerge into masked ops when mask is the same (#97989)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon Jul 8 21:12:05 PDT 2024


Author: Luke Lau
Date: 2024-07-09T12:12:02+08:00
New Revision: 3f83a69bcb2c6b5fa3efbc41d1822e6fa69a6620

URL: https://github.com/llvm/llvm-project/commit/3f83a69bcb2c6b5fa3efbc41d1822e6fa69a6620
DIFF: https://github.com/llvm/llvm-project/commit/3f83a69bcb2c6b5fa3efbc41d1822e6fa69a6620.diff

LOG: [RISCV] Allow folding vmerge into masked ops when mask is the same (#97989)

We currently only fold a vmerge into a masked true operand if the vmerge
has an all-ones mask, since we end up keeping the mask from the true
operand.
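
To make the shapes concrete, here is a minimal IR sketch of the case that
folds today, pieced together from the intrinsic signatures used in the test
diff below (the all-ones vmerge mask is written as splat (i1 -1), and the
trailing i64 1 on the masked vadd is its policy operand):

    ; True operand: a vadd already masked by %m.
    %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
    ; vmerge with an all-ones mask and the same %passthru: foldable into a single masked vadd.
    %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %vl)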

But if the masks are the same, we can still fold, because the vmerge and
the true operand have the same passthru. Any element masked off in the
original vmerge is also masked off in the resulting true operand, and it
keeps the same passthru value.
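
In other words, the sketch above still folds if the all-ones mask on the
vmerge is replaced by the true operand's mask %m:

    ; Same mask %m on the vmerge as on the masked vadd, same %passthru: now also foldable.
    %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)

This is the pattern exercised by the new vpmerge_vadd_same_mask test below,
which now selects a single vadd.vv under a tu, mu vtype.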

The motivation for this is to lower masked VP loads and stores with
passthrus to masked RVV instructions. Normally a masked RVV instruction
with a mask-undisturbed passthru can be expressed as a VP op with an
all-ones mask combined with a vp.merge. But for loads and stores the VP op
needs the same mask as the vp.merge, since the masked-off lanes must not
be accessed in memory at all.
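
Concretely, a masked vp.load with a passthru has to use the same mask %m on
both the load and the vp.merge (this is the new vpload_nxv1i8_passthru test
in the diff), and with this patch the pair now folds into a single
tail-undisturbed vle8.v:

    ; Masked-off lanes are not accessed in memory.
    %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
    ; Merge %passthru back in on the masked-off lanes, using the same mask %m.
    %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)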

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
    llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
    llvm/test/CodeGen/RISCV/rvv/vpload.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index ce6a396e9ced9..adde745a5a91b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3523,24 +3523,26 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   return false;
 }
 
-static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+// After ISel, a vector pseudo's mask will be copied to V0 via a CopyToReg
+// that's glued to the pseudo. This tries to look up the value that was copied
+// to V0.
+static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
   // Check that we're using V0 as a mask register.
   if (!isa<RegisterSDNode>(MaskOp) ||
       cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
   // The glued user defines V0.
   const auto *Glued = GlueOp.getNode();
 
   if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
-    return false;
+    return SDValue();
 
   // Check that we're defining V0 as a mask register.
   if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
       cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
-  // Check the instruction defining V0; it needs to be a VMSET pseudo.
   SDValue MaskSetter = Glued->getOperand(2);
 
   // Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
@@ -3549,6 +3551,15 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
       MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
     MaskSetter = MaskSetter->getOperand(0);
 
+  return MaskSetter;
+}
+
+static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
+  if (!MaskSetter)
+    return false;
+
   const auto IsVMSet = [](unsigned Opc) {
     return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
            Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
@@ -3755,12 +3766,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
       return false;
   }
 
-  // If True is masked then the vmerge must have an all 1s mask, since we're
-  // going to keep the mask from True.
+  // If True is masked then the vmerge must have either the same mask or an all
+  // 1s mask, since we're going to keep the mask from True.
   if (IsMasked && Mask) {
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (!usesAllOnesMask(Mask, Glue))
+    SDValue TrueMask =
+        getMaskSetter(True->getOperand(Info->MaskOpIdx),
+                      True->getOperand(True->getNumOperands() - 1));
+    assert(TrueMask);
+    if (!usesAllOnesMask(Mask, Glue) && getMaskSetter(Mask, Glue) != TrueMask)
       return false;
   }
 

diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
index 033a1d7e297f7..d26fd0ca26c72 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -240,3 +240,14 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
   %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %1)
   ret <vscale x 2 x i32> %b
 }
+
+define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vadd_same_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}

diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
index f9d992a40299c..f635b61fc3e5b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -85,11 +85,8 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscal
 define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; ZVFH-LABEL: vfmacc_vv_nxv1f32_masked__tu:
 ; ZVFH:       # %bb.0:
-; ZVFH-NEXT:    vmv1r.v v11, v10
-; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; ZVFH-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
-; ZVFH-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
-; ZVFH-NEXT:    vmerge.vvm v10, v10, v11, v0
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; ZVFH-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
 ; ZVFH-NEXT:    vmv1r.v v8, v10
 ; ZVFH-NEXT:    ret
 ;

diff --git a/llvm/test/CodeGen/RISCV/rvv/vpload.ll b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
index 1b1e9153a2fd5..c0a210e680c79 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
@@ -26,6 +26,17 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
   ret <vscale x 1 x i8> %load
 }
 
+define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
+; CHECK-LABEL: vpload_nxv1i8_passthru:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, tu, mu
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
+  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
+  ret <vscale x 1 x i8> %merge
+}
+
 declare <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr, <vscale x 2 x i1>, i32)
 
 define <vscale x 2 x i8> @vpload_nxv2i8(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
@@ -450,10 +461,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
 ; CHECK-NEXT:    add a4, a0, a4
 ; CHECK-NEXT:    vsetvli zero, a3, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v16, (a4), v0.t
-; CHECK-NEXT:    bltu a1, a2, .LBB37_2
+; CHECK-NEXT:    bltu a1, a2, .LBB38_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a1, a2
-; CHECK-NEXT:  .LBB37_2:
+; CHECK-NEXT:  .LBB38_2:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a1, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t
@@ -480,10 +491,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    slli a5, a3, 1
 ; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    mv a4, a2
-; CHECK-NEXT:    bltu a2, a5, .LBB38_2
+; CHECK-NEXT:    bltu a2, a5, .LBB39_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a4, a5
-; CHECK-NEXT:  .LBB38_2:
+; CHECK-NEXT:  .LBB39_2:
 ; CHECK-NEXT:    sub a6, a4, a3
 ; CHECK-NEXT:    sltu a7, a4, a6
 ; CHECK-NEXT:    addi a7, a7, -1
@@ -499,10 +510,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    sltu a2, a2, a5
 ; CHECK-NEXT:    addi a2, a2, -1
 ; CHECK-NEXT:    and a2, a2, a5
-; CHECK-NEXT:    bltu a2, a3, .LBB38_4
+; CHECK-NEXT:    bltu a2, a3, .LBB39_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    mv a2, a3
-; CHECK-NEXT:  .LBB38_4:
+; CHECK-NEXT:  .LBB39_4:
 ; CHECK-NEXT:    slli a5, a3, 4
 ; CHECK-NEXT:    srli a6, a3, 2
 ; CHECK-NEXT:    vsetvli a7, zero, e8, mf2, ta, ma
@@ -510,10 +521,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    add a5, a0, a5
 ; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v24, (a5), v0.t
-; CHECK-NEXT:    bltu a4, a3, .LBB38_6
+; CHECK-NEXT:    bltu a4, a3, .LBB39_6
 ; CHECK-NEXT:  # %bb.5:
 ; CHECK-NEXT:    mv a4, a3
-; CHECK-NEXT:  .LBB38_6:
+; CHECK-NEXT:  .LBB39_6:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a4, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t


        

