[llvm] [RISCV] Allow folding vmerge into masked ops when mask is the same (PR #97989)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 7 23:21:58 PDT 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/97989

We currently only fold a vmerge into a masked true operand if the vmerge has an all-ones mask, since we end up keeping the mask from the true operand.

But if the masks are the same then we can still fold, because the vmerge and the true operand have the same passthru. If an element was masked off in the original vmerge, it will also be masked off in the resulting true operand, and will end up with the same passthru value.

The motivation for this is to lower masked VP loads and stores with passthrus to masked RVV instructions. Normally you can express a masked RVV instruction with a mask-undisturbed passthru via a combination of a VP op with an all-ones mask and a vp.merge. But for loads and stores you need the same mask on the VP op as well as on the vp.merge.
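
As a concrete illustration (this just restates the vpload_nxv1i8_passthru test added below, so the particular names and types are only an example): a vp.load and a vp.merge that share the same mask %m

  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)

can now be selected to a single tail-undisturbed, mask-undisturbed vle8.v with the passthru already in the destination register, instead of a vle8.v followed by a vmerge.vvm.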


>From b9f014c188c308432f5f968130e8d9c415865948 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 5 Jul 2024 17:32:16 +0800
Subject: [PATCH 1/2] Precommit tests

---
 .../rvv/rvv-peephole-vmerge-masked-vops.ll    | 14 +++++++++
 llvm/test/CodeGen/RISCV/rvv/vpload.ll         | 29 ++++++++++++++-----
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
index 033a1d7e297f7..4acd241a652f4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -240,3 +240,17 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
   %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %1)
   ret <vscale x 2 x i32> %b
 }
+
+define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vadd_same_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v11, v8
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
+; CHECK-NEXT:    vadd.vv v11, v9, v10, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
+; CHECK-NEXT:    vmerge.vvm v8, v8, v11, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vpload.ll b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
index 1b1e9153a2fd5..54d194d3b73d2 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
@@ -26,6 +26,19 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
   ret <vscale x 1 x i8> %load
 }
 
+define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
+; CHECK-LABEL: vpload_nxv1i8_passthru:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v9, (a0), v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, tu, ma
+; CHECK-NEXT:    vmerge.vvm v8, v8, v9, v0
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
+  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
+  ret <vscale x 1 x i8> %merge
+}
+
 declare <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr, <vscale x 2 x i1>, i32)
 
 define <vscale x 2 x i8> @vpload_nxv2i8(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
@@ -450,10 +463,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
 ; CHECK-NEXT:    add a4, a0, a4
 ; CHECK-NEXT:    vsetvli zero, a3, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v16, (a4), v0.t
-; CHECK-NEXT:    bltu a1, a2, .LBB37_2
+; CHECK-NEXT:    bltu a1, a2, .LBB38_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a1, a2
-; CHECK-NEXT:  .LBB37_2:
+; CHECK-NEXT:  .LBB38_2:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a1, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t
@@ -480,10 +493,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    slli a5, a3, 1
 ; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    mv a4, a2
-; CHECK-NEXT:    bltu a2, a5, .LBB38_2
+; CHECK-NEXT:    bltu a2, a5, .LBB39_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a4, a5
-; CHECK-NEXT:  .LBB38_2:
+; CHECK-NEXT:  .LBB39_2:
 ; CHECK-NEXT:    sub a6, a4, a3
 ; CHECK-NEXT:    sltu a7, a4, a6
 ; CHECK-NEXT:    addi a7, a7, -1
@@ -499,10 +512,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    sltu a2, a2, a5
 ; CHECK-NEXT:    addi a2, a2, -1
 ; CHECK-NEXT:    and a2, a2, a5
-; CHECK-NEXT:    bltu a2, a3, .LBB38_4
+; CHECK-NEXT:    bltu a2, a3, .LBB39_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    mv a2, a3
-; CHECK-NEXT:  .LBB38_4:
+; CHECK-NEXT:  .LBB39_4:
 ; CHECK-NEXT:    slli a5, a3, 4
 ; CHECK-NEXT:    srli a6, a3, 2
 ; CHECK-NEXT:    vsetvli a7, zero, e8, mf2, ta, ma
@@ -510,10 +523,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    add a5, a0, a5
 ; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v24, (a5), v0.t
-; CHECK-NEXT:    bltu a4, a3, .LBB38_6
+; CHECK-NEXT:    bltu a4, a3, .LBB39_6
 ; CHECK-NEXT:  # %bb.5:
 ; CHECK-NEXT:    mv a4, a3
-; CHECK-NEXT:  .LBB38_6:
+; CHECK-NEXT:  .LBB39_6:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a4, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t

>From 2b4e480a233c726620a4956f694134cd86aaf0a7 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Sat, 6 Jul 2024 12:57:14 +0800
Subject: [PATCH 2/2] [RISCV] Allow folding vmerge into masked ops when mask is
 the same

We currently fold a vmerge into a masked true operand only if the vmerge has an all-ones mask, since we end up keeping the mask from the true operand.

But if the masks are the same then we can still fold, because the vmerge and the true operand have the same passthru. If an element was masked off in the original vmerge, it will also be masked off in the resulting true operand, and will end up with the same passthru value.

The motivation for this is to lower masked VP loads and stores with passthrus to masked RVV instructions. Normally you can express a masked RVV instruction with a mask-undisturbed passthru via a combination of a VP op with an all-ones mask and a vp.merge. But for loads and stores you need the same mask on the VP op as well as on the vp.merge.
---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   | 31 ++++++++++++++-----
 .../rvv/rvv-peephole-vmerge-masked-vops.ll    |  7 ++---
 llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll     |  7 ++---
 llvm/test/CodeGen/RISCV/rvv/vpload.ll         |  6 ++--
 4 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index ce6a396e9ced9..ce1e55a80fe62 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3523,24 +3523,26 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   return false;
 }
 
-static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+// After ISel, a vector pseudo's mask will be copied to V0, and the CopyToReg
+// will be glued to the pseudo. This tries to look up the value that was
+// copied to V0.
+static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
   // Check that we're using V0 as a mask register.
   if (!isa<RegisterSDNode>(MaskOp) ||
       cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
   // The glued user defines V0.
   const auto *Glued = GlueOp.getNode();
 
   if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
-    return false;
+    return SDValue();
 
   // Check that we're defining V0 as a mask register.
   if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
       cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
-  // Check the instruction defining V0; it needs to be a VMSET pseudo.
   SDValue MaskSetter = Glued->getOperand(2);
 
   // Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
@@ -3549,6 +3551,15 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
       MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
     MaskSetter = MaskSetter->getOperand(0);
 
+  return MaskSetter;
+}
+
+static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
+  if (!MaskSetter)
+    return false;
+
   const auto IsVMSet = [](unsigned Opc) {
     return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
            Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
@@ -3755,12 +3766,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
       return false;
   }
 
-  // If True is masked then the vmerge must have an all 1s mask, since we're
-  // going to keep the mask from True.
+  // If True is masked then the vmerge must have either the same mask or an all
+  // 1s mask, since we're going to keep the mask from True.
   if (IsMasked && Mask) {
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (!usesAllOnesMask(Mask, Glue))
+    SDValue TrueMask =
+        getMaskSetter(True->getOperand(Info->MaskOpIdx),
+                      True->getOperand(True->getNumOperands() - 1));
+    assert(TrueMask);
+    if (!usesAllOnesMask(Mask, Glue) && getMaskSetter(Mask, Glue) != TrueMask)
       return false;
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
index 4acd241a652f4..d26fd0ca26c72 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -244,11 +244,8 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
 define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
 ; CHECK-LABEL: vpmerge_vadd_same_mask:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vmv1r.v v11, v8
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
-; CHECK-NEXT:    vadd.vv v11, v9, v10, v0.t
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, tu, ma
-; CHECK-NEXT:    vmerge.vvm v8, v8, v11, v0
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
 ; CHECK-NEXT:    ret
   %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
   %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
index f9d992a40299c..f635b61fc3e5b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -85,11 +85,8 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscal
 define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; ZVFH-LABEL: vfmacc_vv_nxv1f32_masked__tu:
 ; ZVFH:       # %bb.0:
-; ZVFH-NEXT:    vmv1r.v v11, v10
-; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; ZVFH-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
-; ZVFH-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
-; ZVFH-NEXT:    vmerge.vvm v10, v10, v11, v0
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; ZVFH-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
 ; ZVFH-NEXT:    vmv1r.v v8, v10
 ; ZVFH-NEXT:    ret
 ;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vpload.ll b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
index 54d194d3b73d2..c0a210e680c79 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
@@ -29,10 +29,8 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
 define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
 ; CHECK-LABEL: vpload_nxv1i8_passthru:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v9, (a0), v0.t
-; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, tu, ma
-; CHECK-NEXT:    vmerge.vvm v8, v8, v9, v0
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, tu, mu
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
   %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)


