[llvm] [RISCV] Reverse (add x, (zext c)) back to (select c, (add x, 1), x) (PR #87236)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 1 05:51:05 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Author: Luke Lau (lukel97)
A common induction-variable pattern, apparently emitted by the loop vectorizer, is a vadd.vv where one of the operands is a vmerge.vim of zeroes and ones:
```asm
vmv.v.i v12, 0
.LBB9_109: # %vector.body
vsetvli zero, zero, e8, mf4, ta, ma
vluxei64.v v9, (t1), v10
vmseq.vi v0, v9, 0
vsetvli zero, zero, e32, m1, ta, ma
vmerge.vim v9, v12, 1, v0
vadd.vv v8, v8, v9
```
I'm not sure whether this is what the vectorizer generates directly, but in any case InstCombine will transform `select c, (add x, 1), x` into `add x, (zext c)`.
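For illustration, the canonicalized IR looks roughly like this (a hand-written, fixed-width sketch; the function name is made up, and the tests in this PR use scalable vectors):
```llvm
; What InstCombine leaves behind: select c, (add x, 1), x
; has been folded into add x, (zext c).
define <4 x i32> @canonical_form(<4 x i1> %c, <4 x i32> %x) {
  %zext = zext <4 x i1> %c to <4 x i32>
  %add = add <4 x i32> %x, %zext
  ret <4 x i32> %add
}
```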
On RISC-V, though, there is no native instruction for zero-extending i1 elements, so the zext gets lowered as a vmv.v.i plus a vmerge.vim instead.
We can reverse this transform, pulling the select back outside of the binary op, which allows it to be folded into a masked op:
```asm
vadd.vi v8, v8, 1, v0.t
```
Specifically, we can do this transform for any binary op whose identity element is zero.
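As a sketch of what the pass produces, the IR is rewritten back into the select form (again a simplified fixed-width example; the `i1_zext_add` test in the diff below shows the real scalable-vector output):
```llvm
; After RISCVCodeGenPrepare: the zext is gone and the select is
; back outside the add, where it can become a masked vadd.vi.
define <4 x i32> @select_form(<4 x i1> %c, <4 x i32> %x) {
  %add = add <4 x i32> %x, <i32 1, i32 1, i32 1, i32 1>
  %sel = select <4 x i1> %c, <4 x i32> %add, <4 x i32> %x
  ret <4 x i32> %sel
}
```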
This pattern doesn't show up in the in-tree tests, but it appears frequently in llvm-test-suite and SPEC CPU 2017:
[llvm-test-suite.diff.txt](https://github.com/llvm/llvm-project/files/14823564/zext-binop.diff.txt)
Alive2 proof: https://alive2.llvm.org/ce/z/VKFegj
---
Full diff: https://github.com/llvm/llvm-project/pull/87236.diff
5 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp (+47)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll (+16-12)
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll (+11)
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll (+13)
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+10-17)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index 53fcc527e615dd..d27649b31a985b 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -52,12 +52,59 @@ class RISCVCodeGenPrepare : public FunctionPass,
}
bool visitInstruction(Instruction &I) { return false; }
+ bool visitAdd(BinaryOperator &BO);
bool visitAnd(BinaryOperator &BO);
bool visitIntrinsicInst(IntrinsicInst &I);
};
} // end anonymous namespace
+/// InstCombine will canonicalize selects of binary ops where the identity is
+/// zero to zexts:
+///
+/// select c, (add x, 1), x -> add x, (zext c)
+///
+/// On RISC-V though, a zext of an i1 vector will be lowered as a vmv.v.i and a
+/// vmerge.vim:
+///
+/// vmv.v.i v12, 0
+/// vmerge.vim v9, v12, 1, v0
+/// vadd.vv v8, v8, v9
+///
+/// Reverse this transform so that we pull the select outside of the binary op,
+/// which allows us to fold it into a masked op:
+///
+/// vadd.vi v8, v8, 1, v0.t
+bool RISCVCodeGenPrepare::visitAdd(BinaryOperator &BO) {
+ VectorType *Ty = dyn_cast<VectorType>(BO.getType());
+ if (!Ty)
+ return false;
+
+ Constant *Identity = ConstantExpr::getIdentity(&BO, BO.getType());
+ if (!Identity->isZeroValue())
+ return false;
+
+ using namespace PatternMatch;
+
+ Value *Mask, *RHS;
+ if (!match(&BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(Mask))), m_Value(RHS))))
+ return false;
+
+ if (!cast<VectorType>(Mask->getType())->getElementType()->isIntegerTy(1))
+ return false;
+
+ IRBuilder<> Builder(&BO);
+ Value *Splat = Builder.CreateVectorSplat(
+ Ty->getElementCount(), ConstantInt::get(Ty->getElementType(), 1));
+ Value *Add = Builder.CreateAdd(RHS, Splat);
+ Value *Select = Builder.CreateSelect(Mask, Add, RHS);
+
+ BO.replaceAllUsesWith(Select);
+ BO.eraseFromParent();
+
+ return true;
+}
+
// Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
// but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
// the upper 32 bits with ones.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
index 954edf872aff8d..43479ed184039c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
@@ -168,13 +168,14 @@ define <8 x i64> @vaaddu_vv_v8i64_floor(<8 x i64> %x, <8 x i64> %y) {
define <8 x i1> @vaaddu_vv_v8i1_floor(<8 x i1> %x, <8 x i1> %y) {
; CHECK-LABEL: vaaddu_vv_v8i1_floor:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vmv.v.i v9, 0
-; CHECK-NEXT: vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT: vmv1r.v v9, v0
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT: vmv.v.i v10, 0
; CHECK-NEXT: vmv1r.v v0, v8
-; CHECK-NEXT: vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT: csrwi vxrm, 2
-; CHECK-NEXT: vaaddu.vv v8, v10, v8
+; CHECK-NEXT: vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT: vmv1r.v v0, v9
+; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT: vsrl.vi v8, v8, 1
; CHECK-NEXT: vand.vi v8, v8, 1
; CHECK-NEXT: vmsne.vi v0, v8, 0
; CHECK-NEXT: ret
@@ -421,13 +422,16 @@ define <8 x i64> @vaaddu_vv_v8i64_ceil(<8 x i64> %x, <8 x i64> %y) {
define <8 x i1> @vaaddu_vv_v8i1_ceil(<8 x i1> %x, <8 x i1> %y) {
; CHECK-LABEL: vaaddu_vv_v8i1_ceil:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vmv.v.i v9, 0
-; CHECK-NEXT: vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT: vmv1r.v v9, v0
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT: vmv.v.i v10, 0
; CHECK-NEXT: vmv1r.v v0, v8
-; CHECK-NEXT: vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT: csrwi vxrm, 0
-; CHECK-NEXT: vaaddu.vv v8, v10, v8
+; CHECK-NEXT: vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT: vmv1r.v v0, v9
+; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT: li a0, 1
+; CHECK-NEXT: csrwi vxrm, 2
+; CHECK-NEXT: vaaddu.vx v8, v8, a0
; CHECK-NEXT: vand.vi v8, v8, 1
; CHECK-NEXT: vmsne.vi v0, v8, 0
; CHECK-NEXT: ret
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 4c5835afd49e64..14d0924f791793 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -42,3 +42,14 @@ vector.body:
exit:
ret float %acc
}
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_add:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT: ret
+ %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+ %add = add <vscale x 2 x i32> %b, %zext
+ ret <vscale x 2 x i32> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 006fc269050b0a..97f5f81123bfbb 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -44,3 +44,16 @@ vector.body:
exit:
ret float %acc
}
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: define <vscale x 2 x i32> @i1_zext_add(
+; CHECK-SAME: <vscale x 2 x i1> [[A:%.*]], <vscale x 2 x i32> [[B:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[ZEXT:%.*]] = zext <vscale x 2 x i1> [[A]] to <vscale x 2 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = add <vscale x 2 x i32> [[B]], shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i64 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT: [[TMP2:%.*]] = select <vscale x 2 x i1> [[A]], <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> [[B]]
+; CHECK-NEXT: ret <vscale x 2 x i32> [[TMP2]]
+;
+ %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+ %add = add <vscale x 2 x i32> %b, %zext
+ ret <vscale x 2 x i32> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..f97481abe4b7c5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1365,23 +1365,13 @@ define <vscale x 8 x i64> @vwaddu_wx_nxv8i64_nxv8i8(<vscale x 8 x i64> %va, i8 %
; Make sure that we don't introduce any V{S,Z}EXT_VL nodes with i1 types from
; combineBinOp_VLToVWBinOp_VL, since they can't be selected.
define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb, ptr %p) {
-; RV32-LABEL: i1_zext:
-; RV32: # %bb.0:
-; RV32-NEXT: vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT: vmv.v.i v9, 0
-; RV32-NEXT: vmerge.vim v9, v9, 1, v0
-; RV32-NEXT: vadd.vv v8, v9, v8
-; RV32-NEXT: li a1, 42
-; RV32-NEXT: sh a1, 0(a0)
-; RV32-NEXT: ret
-;
-; RV64-LABEL: i1_zext:
-; RV64: # %bb.0:
-; RV64-NEXT: vsetvli a1, zero, e64, m1, ta, mu
-; RV64-NEXT: vadd.vi v8, v8, 1, v0.t
-; RV64-NEXT: li a1, 42
-; RV64-NEXT: sh a1, 0(a0)
-; RV64-NEXT: ret
+; CHECK-LABEL: i1_zext:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a1, zero, e64, m1, ta, mu
+; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT: li a1, 42
+; CHECK-NEXT: sh a1, 0(a0)
+; CHECK-NEXT: ret
%vc = zext <vscale x 1 x i1> %va to <vscale x 1 x i64>
%vd = add <vscale x 1 x i64> %vc, %vb
@@ -1466,3 +1456,6 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
ret <vscale x 2 x i32> %or
}
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}
``````````
https://github.com/llvm/llvm-project/pull/87236