[llvm] [RISCV] Reverse (add x, (zext c)) back to (select c, (add x, 1), x) (PR #87236)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 05:50:33 PDT 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/87236

A common induction variable pattern that seems to be emitted by the loop vectorizer is a vadd.vv where one of the operands is a vmerge.vim of zeroes and ones:

```asm
          vmv.v.i v12, 0
  .LBB9_109:                              # %vector.body
          vsetvli zero, zero, e8, mf4, ta, ma
          vluxei64.v      v9, (t1), v10
          vmseq.vi        v0, v9, 0
          vsetvli zero, zero, e32, m1, ta, ma
          vmerge.vim      v9, v12, 1, v0
          vadd.vv v8, v8, v9
```

I'm not sure whether the loop vectorizer emits this form directly, but in any case InstCombine will transform `select c, (add x, 1), x` into `add x, (zext c)`.

On RISC-V, though, we don't have a native instruction for zero-extending i1 elements, so the zext gets lowered as a vmv.v.i and a vmerge.vim instead.

We can reverse this transform so that we pull the select outside of the binary op, which allows us to fold it into a masked op:

```asm
      vadd.vi v8, v8, 1, v0.t
```

In principle this transform is valid for any binary op whose identity is zero; this patch only handles add.

This pattern rarely shows up in the in-tree tests, but shows up more frequently in llvm-test-suite/SPEC CPU 2017:
[llvm-test-suite.diff.txt](https://github.com/llvm/llvm-project/files/14823564/zext-binop.diff.txt)

Alive2 proof: https://alive2.llvm.org/ce/z/VKFegj

>From fb91c9d24ec47ff8eb4811af6cf4172b00803a80 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 1 Apr 2024 19:15:31 +0800
Subject: [PATCH 1/2] Add tests for converting i1 zext into select in
 RISCVCodeGenPrepare

---
 .../CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll   | 13 +++++++++++++
 llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll | 12 ++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 4c5835afd49e64..13b636cbbadb7b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -42,3 +42,16 @@ vector.body:
 exit:
   ret float %acc
 }
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.v.i v9, 0
+; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %add ; %add is (add %b, (zext %a)) with an i1 mask %a
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 006fc269050b0a..2e6a1dcab31863 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -44,3 +44,15 @@ vector.body:
 exit:
   ret float %acc
 }
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: define <vscale x 2 x i32> @i1_zext_add(
+; CHECK-SAME: <vscale x 2 x i1> [[A:%.*]], <vscale x 2 x i32> [[B:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <vscale x 2 x i1> [[A]] to <vscale x 2 x i32>
+; CHECK-NEXT:    [[ADD:%.*]] = add <vscale x 2 x i32> [[B]], [[ZEXT]]
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[ADD]]
+;
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %add ; %add is (add %b, (zext %a)) with an i1 mask %a
+}

>From 8dd319b88f9b6d63a36f8aa93207b15a3bb36d5b Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 1 Apr 2024 18:44:11 +0800
Subject: [PATCH 2/2] [RISCV] Reverse (add x, (zext c)) back to (select c, (add
 x, 1), x)

A common induction variable pattern that seems to be emitted by the loop vectorizer is a vadd.vv where one of the operands is a vmerge.vim of zeroes and ones:

	  vmv.v.i v12, 0
  .LBB9_109:                              # %vector.body
	  vsetvli zero, zero, e8, mf4, ta, ma
	  vluxei64.v      v9, (t1), v10
	  vmseq.vi        v0, v9, 0
	  vsetvli zero, zero, e32, m1, ta, ma
	  vmerge.vim      v9, v12, 1, v0
	  vadd.vv v8, v8, v9

I'm not sure whether the loop vectorizer emits this form directly, but in any case InstCombine will transform `select c, (add x, 1), x` into `add x, (zext c)`.

On RISC-V, though, we don't have a native instruction for zero-extending i1 elements, so the zext gets lowered as a vmv.v.i and a vmerge.vim instead.

We can reverse this transform so that we pull the select outside of the binary op, which allows us to fold it into a masked op:

      vadd.vi v8, v8, 1, v0.t

In principle this transform is valid for any binary op whose identity is zero; this patch only handles add.

Alive2 proof: https://alive2.llvm.org/ce/z/VKFegj
---
 llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp | 47 +++++++++++++++++++
 .../CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll | 28 ++++++-----
 .../RISCV/rvv/riscv-codegenprepare-asm.ll     |  6 +--
 .../CodeGen/RISCV/rvv/riscv-codegenprepare.ll |  5 +-
 llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll   | 27 ++++-------
 5 files changed, 78 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index 53fcc527e615dd..d27649b31a985b 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -52,12 +52,59 @@ class RISCVCodeGenPrepare : public FunctionPass,
   }
 
   bool visitInstruction(Instruction &I) { return false; }
+  bool visitAdd(BinaryOperator &BO);
   bool visitAnd(BinaryOperator &BO);
   bool visitIntrinsicInst(IntrinsicInst &I);
 };
 
 } // end anonymous namespace
 
+/// InstCombine canonicalizes a select between an add of one and its operand
+/// into an add of the zero-extended condition:
+///
+/// select c, (add x, 1), x -> add x, (zext c)
+///
+/// On RISC-V though, a zext of an i1 vector will be lowered as a vmv.v.i and a
+/// vmerge.vim that feed an unmasked add:
+///
+///       vmv.v.i v12, 0
+///       vmerge.vim      v9, v12, 1, v0
+///       vadd.vv v8, v8, v9
+///
+/// Reverse this transform so that we pull the select outside of the add,
+/// which allows us to fold it into a masked op:
+///
+///       vadd.vi v8, v8, 1, v0.t
+bool RISCVCodeGenPrepare::visitAdd(BinaryOperator &BO) {
+  // Only vector adds are rewritten: the goal is to turn the select into a mask
+  // on the vector op.
+  auto *Ty = dyn_cast<VectorType>(BO.getType());
+  if (!Ty)
+    return false;
+
+  using namespace PatternMatch;
+
+  // Match (add X, (zext Mask)) with the zext on either side, where Mask is a
+  // vector of i1. The zext must have one use so that it is dead afterwards.
+  Value *Mask, *X;
+  if (!match(&BO, m_c_Add(m_OneUse(m_ZExt(m_Value(Mask))), m_Value(X))) ||
+      !Mask->getType()->isIntOrIntVectorTy(1))
+    return false;
+
+  // (add X, (zext Mask)) --> (select Mask, (add X, splat(1)), X)
+  IRBuilder<> Builder(&BO);
+  Value *One = Builder.CreateVectorSplat(
+      Ty->getElementCount(), ConstantInt::get(Ty->getElementType(), 1));
+  Value *Inc = Builder.CreateAdd(X, One);
+  Value *Select = Builder.CreateSelect(Mask, Inc, X);
+  Select->takeName(&BO);
+
+  // Once BO is erased the zext has no users left; it is not erased here.
+  BO.replaceAllUsesWith(Select);
+  BO.eraseFromParent();
+  return true;
+}
+
 // Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
 // but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
 // the upper 32 bits with ones.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
index 954edf872aff8d..43479ed184039c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vaaddu.ll
@@ -168,13 +168,14 @@ define <8 x i64> @vaaddu_vv_v8i64_floor(<8 x i64> %x, <8 x i64> %y) {
 define <8 x i1> @vaaddu_vv_v8i1_floor(<8 x i1> %x, <8 x i1> %y) {
 ; CHECK-LABEL: vaaddu_vv_v8i1_floor:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT:    vmv1r.v v9, v0
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vmv.v.i v10, 0
 ; CHECK-NEXT:    vmv1r.v v0, v8
-; CHECK-NEXT:    vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT:    csrwi vxrm, 2
-; CHECK-NEXT:    vaaddu.vv v8, v10, v8
+; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv1r.v v0, v9
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    vsrl.vi v8, v8, 1
 ; CHECK-NEXT:    vand.vi v8, v8, 1
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
 ; CHECK-NEXT:    ret
@@ -421,13 +422,16 @@ define <8 x i64> @vaaddu_vv_v8i64_ceil(<8 x i64> %x, <8 x i64> %y) {
 define <8 x i1> @vaaddu_vv_v8i1_ceil(<8 x i1> %x, <8 x i1> %y) {
 ; CHECK-LABEL: vaaddu_vv_v8i1_ceil:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v10, v9, 1, v0
+; CHECK-NEXT:    vmv1r.v v9, v0
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vmv.v.i v10, 0
 ; CHECK-NEXT:    vmv1r.v v0, v8
-; CHECK-NEXT:    vmerge.vim v8, v9, 1, v0
-; CHECK-NEXT:    csrwi vxrm, 0
-; CHECK-NEXT:    vaaddu.vv v8, v10, v8
+; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv1r.v v0, v9
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    li a0, 1
+; CHECK-NEXT:    csrwi vxrm, 2
+; CHECK-NEXT:    vaaddu.vx v8, v8, a0
 ; CHECK-NEXT:    vand.vi v8, v8, 1
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 13b636cbbadb7b..14d0924f791793 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -46,10 +46,8 @@ exit:
 define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
 ; CHECK-LABEL: i1_zext_add:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
-; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
 ; CHECK-NEXT:    ret
   %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
   %add = add <vscale x 2 x i32> %b, %zext
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 2e6a1dcab31863..97f5f81123bfbb 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -49,8 +49,9 @@ define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32>
 ; CHECK-LABEL: define <vscale x 2 x i32> @i1_zext_add(
 ; CHECK-SAME: <vscale x 2 x i1> [[A:%.*]], <vscale x 2 x i32> [[B:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext <vscale x 2 x i1> [[A]] to <vscale x 2 x i32>
-; CHECK-NEXT:    [[ADD:%.*]] = add <vscale x 2 x i32> [[B]], [[ZEXT]]
-; CHECK-NEXT:    ret <vscale x 2 x i32> [[ADD]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <vscale x 2 x i32> [[B]], shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i64 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[TMP2:%.*]] = select <vscale x 2 x i1> [[A]], <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> [[B]]
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[TMP2]]
 ;
   %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
   %add = add <vscale x 2 x i32> %b, %zext
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..f97481abe4b7c5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1365,23 +1365,13 @@ define <vscale x 8 x i64> @vwaddu_wx_nxv8i64_nxv8i8(<vscale x 8 x i64> %va, i8 %
 ; Make sure that we don't introduce any V{S,Z}EXT_VL nodes with i1 types from
 ; combineBinOp_VLToVWBinOp_VL, since they can't be selected.
 define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb, ptr %p) {
-; RV32-LABEL: i1_zext:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT:    vmv.v.i v9, 0
-; RV32-NEXT:    vmerge.vim v9, v9, 1, v0
-; RV32-NEXT:    vadd.vv v8, v9, v8
-; RV32-NEXT:    li a1, 42
-; RV32-NEXT:    sh a1, 0(a0)
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: i1_zext:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
-; RV64-NEXT:    vadd.vi v8, v8, 1, v0.t
-; RV64-NEXT:    li a1, 42
-; RV64-NEXT:    sh a1, 0(a0)
-; RV64-NEXT:    ret
+; CHECK-LABEL: i1_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    li a1, 42
+; CHECK-NEXT:    sh a1, 0(a0)
+; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i1> %va to <vscale x 1 x i64>
   %vd = add <vscale x 1 x i64> %vc, %vb
 
@@ -1466,3 +1456,6 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
   ret <vscale x 2 x i32> %or
 }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}



More information about the llvm-commits mailing list