[llvm] [LV][EVL] Replace VPInstruction::Select with vp.merge for predicated div/rem (PR #154072)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 18 01:07:48 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-risc-v
Author: Mel Chen (Mel-Chen)
<details>
<summary>Changes</summary>
Since div/rem operations don’t support a mask operand, the masked-out lanes of the divisor are currently replaced with 1 (a safe divisor) using VPInstruction::Select before the predicated div/rem operation is executed.
This patch replaces
```
VPInstruction::Select(logical_and(header_mask, conditional_mask), LHS, RHS)
```
with
```
vp.merge(conditional_mask, LHS, RHS, EVL)
```
so that, when tail folding with EVL, the header mask in this pattern can be replaced by EVL.
---
Full diff: https://github.com/llvm/llvm-project/pull/154072.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+10-8)
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll (+3-3)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 05c12b7a1adcc..d015a1ccf9c2a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2160,18 +2160,20 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
return new VPReductionEVLRecipe(*Red, EVL, NewMask);
})
.Case<VPInstruction>([&](VPInstruction *VPI) -> VPRecipeBase * {
- VPValue *LHS, *RHS;
+ VPValue *Cond, *LHS, *RHS;
// Transform select with a header mask condition
- // select(header_mask, LHS, RHS)
+ // select(mask_w/_header_mask, LHS, RHS)
// into vector predication merge.
- // vp.merge(all-true, LHS, RHS, EVL)
- if (!match(VPI, m_Select(m_Specific(HeaderMask), m_VPValue(LHS),
- m_VPValue(RHS))))
+ // vp.merge(mask_w/o_header_mask, LHS, RHS, EVL)
+ if (!match(VPI,
+ m_Select(m_VPValue(Cond), m_VPValue(LHS), m_VPValue(RHS))))
return nullptr;
- // Use all true as the condition because this transformation is
- // limited to selects whose condition is a header mask.
+
+ VPValue *NewMask = GetNewMask(Cond);
+ if (!NewMask)
+ NewMask = &AllOneMask;
return new VPWidenIntrinsicRecipe(
- Intrinsic::vp_merge, {&AllOneMask, LHS, RHS, &EVL},
+ Intrinsic::vp_merge, {NewMask, LHS, RHS, &EVL},
TypeInfo.inferScalarType(LHS), VPI->getDebugLoc());
})
.Default([&](VPRecipeBase *R) { return nullptr; });
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll b/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
index 3af328fb6568e..7efaf2080810b 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
@@ -371,7 +371,7 @@ define void @predicated_udiv(ptr noalias nocapture %a, i64 %v, i64 %n) {
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[A:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
; CHECK-NEXT: [[TMP16:%.*]] = select <vscale x 2 x i1> [[TMP15]], <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i1> zeroinitializer
-; CHECK-NEXT: [[TMP10:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1)
+; CHECK-NEXT: [[TMP10:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP12]])
; CHECK-NEXT: [[TMP11:%.*]] = udiv <vscale x 2 x i64> [[WIDE_LOAD]], [[TMP10]]
; CHECK-NEXT: [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[WIDE_LOAD]]
; CHECK-NEXT: call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
@@ -486,7 +486,7 @@ define void @predicated_sdiv(ptr noalias nocapture %a, i64 %v, i64 %n) {
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[A:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
; CHECK-NEXT: [[TMP16:%.*]] = select <vscale x 2 x i1> [[TMP15]], <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i1> zeroinitializer
-; CHECK-NEXT: [[TMP10:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1)
+; CHECK-NEXT: [[TMP10:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP12]])
; CHECK-NEXT: [[TMP11:%.*]] = sdiv <vscale x 2 x i64> [[WIDE_LOAD]], [[TMP10]]
; CHECK-NEXT: [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[WIDE_LOAD]]
; CHECK-NEXT: call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
@@ -817,7 +817,7 @@ define void @predicated_sdiv_by_minus_one(ptr noalias nocapture %a, i64 %n) {
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.vp.load.nxv16i8.p0(ptr align 1 [[TMP7]], <vscale x 16 x i1> splat (i1 true), i32 [[TMP12]])
; CHECK-NEXT: [[TMP9:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD]], splat (i8 -128)
; CHECK-NEXT: [[TMP16:%.*]] = select <vscale x 16 x i1> [[TMP15]], <vscale x 16 x i1> [[TMP9]], <vscale x 16 x i1> zeroinitializer
-; CHECK-NEXT: [[TMP10:%.*]] = select <vscale x 16 x i1> [[TMP16]], <vscale x 16 x i8> splat (i8 -1), <vscale x 16 x i8> splat (i8 1)
+; CHECK-NEXT: [[TMP10:%.*]] = call <vscale x 16 x i8> @llvm.vp.merge.nxv16i8(<vscale x 16 x i1> [[TMP9]], <vscale x 16 x i8> splat (i8 -1), <vscale x 16 x i8> splat (i8 1), i32 [[TMP12]])
; CHECK-NEXT: [[TMP11:%.*]] = sdiv <vscale x 16 x i8> [[WIDE_LOAD]], [[TMP10]]
; CHECK-NEXT: [[PREDPHI:%.*]] = select <vscale x 16 x i1> [[TMP16]], <vscale x 16 x i8> [[TMP11]], <vscale x 16 x i8> [[WIDE_LOAD]]
; CHECK-NEXT: call void @llvm.vp.store.nxv16i8.p0(<vscale x 16 x i8> [[PREDPHI]], ptr align 1 [[TMP7]], <vscale x 16 x i1> splat (i1 true), i32 [[TMP12]])
``````````
</details>
https://github.com/llvm/llvm-project/pull/154072
More information about the llvm-commits mailing list