[llvm] [WASM] Fold bitselect with splat zero (PR #147305)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 7 06:57:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-webassembly

Author: jjasmine (badumbatish)

<details>
<summary>Changes</summary>

[WASM] Fold bitselect with argument <0>

Fixes #<!-- -->73454
Fold bitselect with a splat <0> in the first, second, or third
argument.

- For first argument <0>: vselect <0>, X, Y -> Y
- For second argument <0>: vselect Y, <0>, X -> AND(X, !Y)
- For third argument <0>: vselect X, Y, <0> -> AND(Y, X)

Detailed explanation in the implementation.


---
Full diff: https://github.com/llvm/llvm-project/pull/147305.diff


2 Files Affected:

- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+63) 
- (added) llvm/test/CodeGen/WebAssembly/simd-bitselect.ll (+48) 


``````````diff
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..8009bfcc861c3 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -191,6 +191,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
+    // Convert vselect of various zero arguments to AND
+    setTargetDAGCombine(ISD::VSELECT);
+
     // Convert vector to integer bitcasts to bitmask
     setTargetDAGCombine(ISD::BITCAST);
 
@@ -3210,6 +3213,64 @@ static SDValue performTruncateCombine(SDNode *N,
   return truncateVectorWithNARROW(OutVT, In, DL, DAG);
 }
 
+static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
+  // In the tablegen.td, vselect A B C -> bitselect B C A
+
+  // SCENARIO A
+  // vselect <0>, X, Y
+  // -> bitselect X, Y, <0>
+  // -> or (AND(X, <0>), AND(<Y>, !<0>))
+  // -> or (0, AND(<Y>, !<0>))
+  // -> AND(Y, !<0>)
+  // -> AND(Y, 1)
+  // -> Y
+
+  // SCENARIO B
+  // vselect Y, <0>, X
+  // -> bitselect <0>, X, Y
+  // -> or (AND(<0>, Y), AND(<X>, !<Y>))
+  // -> or (0, AND(<X>, !<Y>))
+  // -> AND(<X>, !<Y>)
+
+  // SCENARIO C
+  // vselect X, Y, <0>
+  // -> bitselect Y, <0>, X
+  // -> or (AND(Y, X), AND(<0>, !X))
+  // -> or (AND(Y, X), <0>)
+  // -> AND(Y, X)
+
+  using namespace llvm::SDPatternMatch;
+  assert(N->getOpcode() == ISD::VSELECT);
+
+  SDLoc DL(N);
+
+  SDValue Cond = N->getOperand(0), LHS = N->getOperand(1),
+          RHS = N->getOperand(2);
+  EVT NVT = N->getValueType(0);
+
+  APInt SplatValue;
+
+  // SCENARIO A
+  if (ISD::isConstantSplatVector(Cond.getNode(), SplatValue) &&
+      SplatValue.isZero())
+    return RHS;
+
+  // SCENARIO B
+  if (ISD::isConstantSplatVector(LHS.getNode(), SplatValue) &&
+      SplatValue.isZero())
+    return DAG.getNode(
+        ISD::AND, DL, NVT,
+        {RHS, DAG.getSExtOrTrunc(DAG.getNOT(DL, Cond, Cond.getValueType()), DL,
+                                 NVT)});
+
+  // SCENARIO C
+  if (ISD::isConstantSplatVector(RHS.getNode(), SplatValue) &&
+      SplatValue.isZero())
+    return DAG.getNode(ISD::AND, DL, NVT,
+                       {LHS, DAG.getSExtOrTrunc(Cond, DL, NVT)});
+  return SDValue();
+}
+
 static SDValue performBitcastCombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
   using namespace llvm::SDPatternMatch;
@@ -3505,6 +3566,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   switch (N->getOpcode()) {
   default:
     return SDValue();
+  case ISD::VSELECT:
+    return performVSelectCombine(N, DCI.DAG);
   case ISD::BITCAST:
     return performBitcastCombine(N, DCI);
   case ISD::SETCC:
diff --git a/llvm/test/CodeGen/WebAssembly/simd-bitselect.ll b/llvm/test/CodeGen/WebAssembly/simd-bitselect.ll
new file mode 100644
index 0000000000000..7803709dc8b4a
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-bitselect.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s  -O3 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+target triple = "wasm32-unknown-unknown"
+
+define void @bitselect_first_zero(ptr %output, ptr  %input) {
+; CHECK-LABEL: bitselect_first_zero:
+; CHECK:         .functype bitselect_first_zero (i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    v128.load $push6=, 0($1)
+; CHECK-NEXT:    local.tee $push5=, $2=, $pop6
+; CHECK-NEXT:    v128.const $push0=, 2139095040, 2139095040, 2139095040, 2139095040
+; CHECK-NEXT:    v128.and $push1=, $2, $pop0
+; CHECK-NEXT:    v128.const $push2=, 0, 0, 0, 0
+; CHECK-NEXT:    i32x4.ne $push3=, $pop1, $pop2
+; CHECK-NEXT:    v128.and $push4=, $pop5, $pop3
+; CHECK-NEXT:    v128.store 0($0), $pop4
+; CHECK-NEXT:    return
+start:
+  %input.val = load <4 x i32>, ptr %input, align 16
+  %0 = and <4 x i32> %input.val, splat (i32 2139095040)
+  %1 = icmp eq <4 x i32> %0, zeroinitializer
+  %2 = select <4 x i1> %1, <4 x i32> zeroinitializer, <4 x i32> %input.val
+  store <4 x i32> %2, ptr %output, align 16
+  ret void
+}
+
+
+define void @bitselect_second_zero(ptr %output, ptr %input) {
+; CHECK-LABEL: bitselect_second_zero:
+; CHECK:         .functype bitselect_second_zero (i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    v128.load $push6=, 0($1)
+; CHECK-NEXT:    local.tee $push5=, $2=, $pop6
+; CHECK-NEXT:    v128.const $push0=, 2139095040, 2139095040, 2139095040, 2139095040
+; CHECK-NEXT:    v128.and $push1=, $2, $pop0
+; CHECK-NEXT:    v128.const $push2=, 0, 0, 0, 0
+; CHECK-NEXT:    i32x4.eq $push3=, $pop1, $pop2
+; CHECK-NEXT:    v128.and $push4=, $pop5, $pop3
+; CHECK-NEXT:    v128.store 0($0), $pop4
+; CHECK-NEXT:    return
+start:
+  %input.val = load <4 x i32>, ptr %input, align 16
+  %0 = and <4 x i32> %input.val, splat (i32 2139095040)
+  %1 = icmp eq <4 x i32> %0, zeroinitializer
+  %2 = select  <4 x i1> %1, <4 x i32> %input.val, <4 x i32> zeroinitializer
+  store <4 x i32> %2, ptr %output, align 16
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/147305


More information about the llvm-commits mailing list