[llvm] [RISCV] Reverse (add x, (zext c)) back to (select c, (add x, 1), x) (PR #87236)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 05:51:05 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

A common induction variable pattern that seems to be emitted by the loop vectorizer is a vadd.vv where one of the operands is a vmerge.vim of zeroes and ones:

```asm
          vmv.v.i v12, 0
  .LBB9_109:                              # %vector.body
          vsetvli zero, zero, e8, mf4, ta, ma
          vluxei64.v      v9, (t1), v10
          vmseq.vi        v0, v9, 0
          vsetvli zero, zero, e32, m1, ta, ma
          vmerge.vim      v9, v12, 1, v0
          vadd.vv v8, v8, v9
```

I'm not sure if this is what the loop vectorizer generates directly, but in any case, InstCombine will transform `select c, (add x, 1), x` to `add x, (zext c)`.

On RISC-V, though, we don't have a native instruction for zero-extending i1 elements, so the zext gets lowered as a vmv.v.i and a vmerge.vim instead.

We can reverse this transform so that we pull the select outside of the binary op, which allows us to fold it into a masked op:

```asm
      vadd.vi v8, v8, 1, v0.t
```

In general, we can do this transform for any binary op where the identity is zero; the diff below implements it for add (`visitAdd`).

This pattern doesn't show up in the in-tree tests, but it does show up frequently in llvm-test-suite/SPEC CPU 2017:
[llvm-test-suite.diff.txt](https://github.com/llvm/llvm-project/files/14823564/zext-binop.diff.txt)

Alive2 proof: https://alive2.llvm.org/ce/z/VKFegj

---
Full diff: https://github.com/llvm/llvm-project/pull/87236.diff


5 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp (+47) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll (+16-12) 
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll (+11) 
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll (+13) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+10-17) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index 53fcc527e615dd..d27649b31a985b 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -52,12 +52,59 @@ class RISCVCodeGenPrepare : public FunctionPass,
   }
 
   bool visitInstruction(Instruction &I) { return false; }
+  bool visitAdd(BinaryOperator &BO);
   bool visitAnd(BinaryOperator &BO);
   bool visitIntrinsicInst(IntrinsicInst &I);
 };
 
 } // end anonymous namespace
 
+/// InstCombine will canonicalize selects of binary ops where the identity is
+/// zero to zexts:
+///
+/// select c, (add x, 1), x -> add x, (zext c)
+///
+/// On RISC-V though, a zext of an i1 vector will be lowered as a vmv.v.i and a
+/// vmerge.vim:
+///
+///       vmv.v.i v12, 0
+///       vmerge.vim      v9, v12, 1, v0
+///       vadd.vv v8, v8, v9
+///
+/// Reverse this transform so that we pull the select outside of the binary op,
+/// which allows us to fold it into a masked op:
+///
+///       vadd.vi v8, v8, 1, v0.t
+bool RISCVCodeGenPrepare::visitAdd(BinaryOperator &BO) {
+  VectorType *Ty = dyn_cast<VectorType>(BO.getType());
+  if (!Ty)
+    return false;
+
+  Constant *Identity = ConstantExpr::getIdentity(&BO, BO.getType());
+  if (!Identity->isZeroValue())
+    return false;
+
+  using namespace PatternMatch;
+
+  Value *Mask, *RHS;
+  if (!match(&BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(Mask))), m_Value(RHS))))
+    return false;
+
+  if (!cast<VectorType>(Mask->getType())->getElementType()->isIntegerTy(1))
+    return false;
+
+  IRBuilder<> Builder(&BO);
+  Value *Splat = Builder.CreateVectorSplat(
+      Ty->getElementCount(), ConstantInt::get(Ty->getElementType(), 1));
+  Value *Add = Builder.CreateAdd(RHS, Splat);
+  Value *Select = Builder.CreateSelect(Mask, Add, RHS);
+
+  BO.replaceAllUsesWith(Select);
+  BO.eraseFromParent();
+
+  return true;
+}
+
 // Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
 // but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
 // the upper 32 bits with ones.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
index 954edf872aff8d..43479ed184039c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
@@ -168,13 +168,14 @@ define <8 x i64> @vaaddu_vv_v8i64_floor(<8 x i64> %x, <8 x i64> %y) {
 define <8 x i1> @vaaddu_vv_v8i1_floor(<8 x i1> %x, <8 x i1> %y) {
 ; CHECK-LABEL: vaaddu_vv_v8i1_floor:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT:    vmv1r.v v9, v0
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vmv.v.i v10, 0
 ; CHECK-NEXT:    vmv1r.v v0, v8
-; CHECK-NEXT:    vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT:    csrwi vxrm, 2
-; CHECK-NEXT:    vaaddu.vv v8, v10, v8
+; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv1r.v v0, v9
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    vsrl.vi v8, v8, 1
 ; CHECK-NEXT:    vand.vi v8, v8, 1
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
 ; CHECK-NEXT:    ret
@@ -421,13 +422,16 @@ define <8 x i64> @vaaddu_vv_v8i64_ceil(<8 x i64> %x, <8 x i64> %y) {
 define <8 x i1> @vaaddu_vv_v8i1_ceil(<8 x i1> %x, <8 x i1> %y) {
 ; CHECK-LABEL: vaaddu_vv_v8i1_ceil:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT:    vmv1r.v v9, v0
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vmv.v.i v10, 0
 ; CHECK-NEXT:    vmv1r.v v0, v8
-; CHECK-NEXT:    vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT:    csrwi vxrm, 0
-; CHECK-NEXT:    vaaddu.vv v8, v10, v8
+; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv1r.v v0, v9
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    li a0, 1
+; CHECK-NEXT:    csrwi vxrm, 2
+; CHECK-NEXT:    vaaddu.vx v8, v8, a0
 ; CHECK-NEXT:    vand.vi v8, v8, 1
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 4c5835afd49e64..14d0924f791793 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -42,3 +42,14 @@ vector.body:
 exit:
   ret float %acc
 }
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 006fc269050b0a..97f5f81123bfbb 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -44,3 +44,16 @@ vector.body:
 exit:
   ret float %acc
 }
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: define <vscale x 2 x i32> @i1_zext_add(
+; CHECK-SAME: <vscale x 2 x i1> [[A:%.*]], <vscale x 2 x i32> [[B:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <vscale x 2 x i1> [[A]] to <vscale x 2 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <vscale x 2 x i32> [[B]], shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i64 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[TMP2:%.*]] = select <vscale x 2 x i1> [[A]], <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> [[B]]
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[TMP2]]
+;
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..f97481abe4b7c5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1365,23 +1365,13 @@ define <vscale x 8 x i64> @vwaddu_wx_nxv8i64_nxv8i8(<vscale x 8 x i64> %va, i8 %
 ; Make sure that we don't introduce any V{S,Z}EXT_VL nodes with i1 types from
 ; combineBinOp_VLToVWBinOp_VL, since they can't be selected.
 define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb, ptr %p) {
-; RV32-LABEL: i1_zext:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT:    vmv.v.i v9, 0
-; RV32-NEXT:    vmerge.vim v9, v9, 1, v0
-; RV32-NEXT:    vadd.vv v8, v9, v8
-; RV32-NEXT:    li a1, 42
-; RV32-NEXT:    sh a1, 0(a0)
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: i1_zext:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
-; RV64-NEXT:    vadd.vi v8, v8, 1, v0.t
-; RV64-NEXT:    li a1, 42
-; RV64-NEXT:    sh a1, 0(a0)
-; RV64-NEXT:    ret
+; CHECK-LABEL: i1_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    li a1, 42
+; CHECK-NEXT:    sh a1, 0(a0)
+; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i1> %va to <vscale x 1 x i64>
   %vd = add <vscale x 1 x i64> %vc, %vb
 
@@ -1466,3 +1456,6 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
   ret <vscale x 2 x i32> %or
 }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87236


More information about the llvm-commits mailing list