[llvm] [RISCV] Undo unprofitable zext of icmp combine (PR #134306)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 3 13:43:11 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)


InstCombine (`InstCombinerImpl::transformZExtICmp`) will rewrite a zext of an icmp that tests a single bit into a lshr plus trunc. For example, it turns this:

```llvm
define <vscale x 1 x i8> @f(<vscale x 1 x i64> %x) {
  %1 = and <vscale x 1 x i64> %x, splat (i64 8)
  %2 = icmp ne <vscale x 1 x i64> %1, splat (i64 0)
  %3 = zext <vscale x 1 x i1> %2 to <vscale x 1 x i8>
  ret <vscale x 1 x i8> %3
}
```

into this:

```llvm
define <vscale x 1 x i8> @f(<vscale x 1 x i64> %x) #0 {
  %1 = and <vscale x 1 x i64> %x, splat (i64 8)
  %.lobit = lshr exact <vscale x 1 x i64> %1, splat (i64 3)
  %2 = trunc nuw nsw <vscale x 1 x i64> %.lobit to <vscale x 1 x i8>
  ret <vscale x 1 x i8> %2
}
```

In a loop, this ends up being unprofitable for RISC-V because the codegen now goes from this:

```asm
f:                                      # @f
	.cfi_startproc
# %bb.0:
	vsetvli	a0, zero, e64, m1, ta, ma
	vand.vi	v8, v8, 8
	vmsne.vi	v0, v8, 0
	vsetvli	zero, zero, e8, mf8, ta, ma
	vmv.v.i	v8, 0
	vmerge.vim	v8, v8, 1, v0
	ret
```

to a series of narrowing vnsrl.wis:

```asm
f:                                      # @f
	.cfi_startproc
# %bb.0:
	vsetvli	a0, zero, e64, m1, ta, ma
	vand.vi	v8, v8, 8
	vsetvli	zero, zero, e32, mf2, ta, ma
	vnsrl.wi	v8, v8, 3
	vsetvli	zero, zero, e16, mf4, ta, ma
	vnsrl.wi	v8, v8, 0
	vsetvli	zero, zero, e8, mf8, ta, ma
	vnsrl.wi	v8, v8, 0
	ret
```

In the original form, the vmv.v.i is loop invariant and gets hoisted out of the loop, and the vmerge.vim usually gets folded into a masked instruction, so inside the loop you typically end up with just a vsetvli + vmsne.vi.

The truncate, by contrast, requires multiple instructions, introduces a vtype toggle for each one, and is measurably slower on the BPI-F3.

This patch reverses the transform in RISCVCodeGenPrepare for truncations where the source element type is more than twice as wide as the destination, i.e. truncations that a single vnsrl.wi can handle are left alone.
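
For example, here is a minimal sketch of the two cases, using the same IR shapes as the tests added below (the function names @wide_src and @narrow_src are just illustrative):

```llvm
; i64 source, i8 destination: an 8x narrow would need three vnsrl.wis,
; so the pass rewrites the pattern back into mask-and-compare form:
;   %m = and <vscale x 1 x i64> %x, splat (i64 4)
;   %c = icmp ne <vscale x 1 x i64> %m, zeroinitializer
;   %r = zext <vscale x 1 x i1> %c to <vscale x 1 x i8>
define <vscale x 1 x i8> @wide_src(<vscale x 1 x i64> %x) {
  %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
  ret <vscale x 1 x i8> %3
}

; i16 source, i8 destination: a single vnsrl.wi suffices,
; so the pass leaves this function untouched.
define <vscale x 1 x i8> @narrow_src(<vscale x 1 x i16> %x) {
  %1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
  ret <vscale x 1 x i8> %3
}
```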

Fixes #132245

---
Full diff: https://github.com/llvm/llvm-project/pull/134306.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp (+68-3) 
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll (+81) 
- (modified) llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll (+72) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index b5cb05f30fb26..e04f3b1d3478e 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
 
@@ -62,10 +63,74 @@ class RISCVCodeGenPrepare : public FunctionPass,
 
 } // end anonymous namespace
 
-// Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
-// but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
-// the upper 32 bits with ones.
+// InstCombinerImpl::transformZExtICmp will narrow a zext of an icmp with a
+// truncation. But RVV doesn't have a truncation instruction that narrows by
+// more than a factor of two.
+//
+// E.g. trunc <vscale x 1 x i64> %x to <vscale x 1 x i8> will generate:
+//
+//     vsetvli a0, zero, e32, m2, ta, ma
+//     vnsrl.wi v12, v8, 0
+//     vsetvli zero, zero, e16, m1, ta, ma
+//     vnsrl.wi v8, v12, 0
+//     vsetvli zero, zero, e8, mf2, ta, ma
+//     vnsrl.wi v8, v8, 0
+//
+// So reverse the combine so that we generate a vmseq/vmsne again:
+//
+// and (lshr (trunc X), ShAmt), 1
+// -->
+// zext (icmp ne (and X, (1 << ShAmt)), 0)
+//
+// and (lshr (not (trunc X)), ShAmt), 1
+// -->
+// zext (icmp eq (and X, (1 << ShAmt)), 0)
+static bool reverseZExtICmpCombine(BinaryOperator &BO) {
+  using namespace PatternMatch;
+
+  assert(BO.getOpcode() == BinaryOperator::And);
+
+  if (!BO.getType()->isVectorTy())
+    return false;
+  const APInt *ShAmt;
+  Value *Inner;
+  if (!match(&BO,
+             m_And(m_OneUse(m_LShr(m_OneUse(m_Value(Inner)), m_APInt(ShAmt))),
+                   m_One())))
+    return false;
+
+  Value *X;
+  bool IsNot;
+  if (match(Inner, m_Not(m_Trunc(m_Value(X)))))
+    IsNot = true;
+  else if (match(Inner, m_Trunc(m_Value(X))))
+    IsNot = false;
+  else
+    return false;
+
+  if (BO.getType()->getScalarSizeInBits() >=
+      X->getType()->getScalarSizeInBits() / 2)
+    return false;
+
+  IRBuilder<> Builder(&BO);
+  Value *Res = Builder.CreateAnd(
+      X, ConstantInt::get(X->getType(), 1 << ShAmt->getZExtValue()));
+  Res = Builder.CreateICmp(IsNot ? CmpInst::Predicate::ICMP_EQ
+                                 : CmpInst::Predicate::ICMP_NE,
+                           Res, ConstantInt::get(X->getType(), 0));
+  Res = Builder.CreateZExt(Res, BO.getType());
+  BO.replaceAllUsesWith(Res);
+  RecursivelyDeleteTriviallyDeadInstructions(&BO);
+  return true;
+}
+
 bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
+  if (reverseZExtICmpCombine(BO))
+    return true;
+
+  // Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
+  // but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
+  // the upper 32 bits with ones.
   if (!ST->is64Bit())
     return false;
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 4e5f6e0f65489..b6593eac6d92c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -498,3 +498,84 @@ vector.body:                                      ; preds = %vector.body, %entry
 for.cond.cleanup:                                 ; preds = %vector.body
   ret float %red
 }
+
+define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
+; CHECK-LABEL: reverse_zexticmp_i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vsrl.vi v8, v8, 2
+; CHECK-NEXT:    vand.vi v8, v8, 1
+; CHECK-NEXT:    ret
+  %1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: reverse_zexticmp_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vand.vi v8, v8, 4
+; CHECK-NEXT:    vmsne.vi v0, v8, 0
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT:    ret
+  %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: reverse_zexticmp_neg_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vand.vi v8, v8, 4
+; CHECK-NEXT:    vmseq.vi v0, v8, 0
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT:    ret
+  %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+  %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+  %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+  %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+  ret <vscale x 1 x i8> %4
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: reverse_zexticmp_i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vand.vi v8, v8, 4
+; CHECK-NEXT:    vmsne.vi v0, v8, 0
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT:    ret
+  %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: reverse_zexticmp_neg_i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vand.vi v8, v8, 4
+; CHECK-NEXT:    vmseq.vi v0, v8, 0
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT:    ret
+  %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+  %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+  %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+  %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+  ret <vscale x 1 x i8> %4
+}
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 8967fb8bf01ac..483e797151325 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -528,3 +528,75 @@ vector.body:                                      ; preds = %vector.body, %entry
 for.cond.cleanup:                                 ; preds = %vector.body
   ret float %red
 }
+
+define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i16(
+; CHECK-SAME: <vscale x 1 x i16> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <vscale x 1 x i16> [[X]] to <vscale x 1 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <vscale x 1 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT:    [[TMP3:%.*]] = and <vscale x 1 x i8> [[TMP2]], splat (i8 1)
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP3]]
+;
+  %1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i32(
+; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <vscale x 1 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP3]]
+;
+  %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(
+; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq <vscale x 1 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP4]]
+;
+  %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+  %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+  %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+  %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+  ret <vscale x 1 x i8> %4
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i64(
+; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <vscale x 1 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP3]]
+;
+  %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+  %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+  %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+  ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(
+; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq <vscale x 1 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[TMP4]]
+;
+  %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+  %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+  %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+  %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+  ret <vscale x 1 x i8> %4
+}

``````````



https://github.com/llvm/llvm-project/pull/134306

