[llvm] [DAGCombiner] Use generalized pattern matcher in foldBoolSelectToLogic (PR #79101)

Liao Chunyu via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 29 18:12:04 PST 2024


https://github.com/ChunyuLiao updated https://github.com/llvm/llvm-project/pull/79101

>From 6c0a7a8125693274a8a7d598ed689e6f2e52488a Mon Sep 17 00:00:00 2001
From: Liao Chunyu <chunyu at iscas.ac.cn>
Date: Wed, 24 Jan 2024 12:12:23 +0800
Subject: [PATCH] [DAGCombiner] Use generalized pattern matcher in
 foldBoolSelectToLogic to support vp.select

TODO: Possibly other functions could be supported as well, e.g. simplifySelect().
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 40 +++++++---
 llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll     | 80 +++++++++++++++++++
 2 files changed, 110 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 87184fe409eade..b17724cd07209b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -478,6 +478,7 @@ namespace {
     SDValue visitCTPOP(SDNode *N);
     SDValue visitSELECT(SDNode *N);
     SDValue visitVSELECT(SDNode *N);
+    SDValue visitVP_SELECT(SDNode *N);
     SDValue visitSELECT_CC(SDNode *N);
     SDValue visitSETCC(SDNode *N);
     SDValue visitSETCCCARRY(SDNode *N);
@@ -927,6 +928,9 @@ class VPMatchContext {
     assert(Root->isVPOpcode());
     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
       RootMaskOp = Root->getOperand(*RootMaskPos);
+    else if (Root->getOpcode() == ISD::VP_SELECT)
+      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
+                                          Root->getOperand(0).getValueType());
 
     if (auto RootVLenPos =
             ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
@@ -11420,35 +11424,42 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
   return SDValue();
 }
 
+template <class MatchContextClass>
 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
-  assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
-         "Expected a (v)select");
+  assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
+          N->getOpcode() == ISD::VP_SELECT) &&
+         "Expected a (v)(vp.)select");
   SDValue Cond = N->getOperand(0);
   SDValue T = N->getOperand(1), F = N->getOperand(2);
   EVT VT = N->getValueType(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  MatchContextClass matcher(DAG, TLI, N);
+
   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
     return SDValue();
 
   // select Cond, Cond, F --> or Cond, F
   // select Cond, 1, F    --> or Cond, F
   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
-    return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
 
   // select Cond, T, Cond --> and Cond, T
   // select Cond, T, 0    --> and Cond, T
   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
 
   // select Cond, T, 1 --> or (not Cond), T
   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
-    SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
-    return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
+    SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+                                      DAG.getAllOnesConstant(SDLoc(N), VT));
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
   }
 
   // select Cond, 0, F --> and (not Cond), F
   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
-    SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
+    SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+                                      DAG.getAllOnesConstant(SDLoc(N), VT));
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
   }
 
   return SDValue();
@@ -11524,7 +11535,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
     return V;
 
-  if (SDValue V = foldBoolSelectToLogic(N, DAG))
+  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
     return V;
 
   // select (not Cond), N1, N2 -> select Cond, N2, N1
@@ -12138,6 +12149,13 @@ SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
+  // Only the i1 select-to-logic fold is attempted for VP_SELECT so far.
+  // vp.select has no mask operand, so VPMatchContext uses an all-ones value
+  // as its root mask when it is constructed for a VP_SELECT node.
+  return foldBoolSelectToLogic<VPMatchContext>(N, DAG);
+}
+
 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -12148,7 +12166,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
     return V;
 
-  if (SDValue V = foldBoolSelectToLogic(N, DAG))
+  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
     return V;
 
   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
@@ -26401,6 +26419,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
       return visitVP_FSUB(N);
     case ISD::VP_FMA:
       return visitFMA<VPMatchContext>(N);
+    case ISD::VP_SELECT:
+      return visitVP_SELECT(N);
     }
     return SDValue();
   }
diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
index 9e7df5eab8dda9..1b568bf8801b10 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
@@ -745,3 +745,64 @@ define <vscale x 16 x double> @select_nxv16f64(<vscale x 16 x i1> %a, <vscale x
  %v = call <vscale x 16 x double> @llvm.vp.select.nxv16f64(<vscale x 16 x i1> %a, <vscale x 16 x double> %b, <vscale x 16 x double> %c, i32 %evl)
  ret <vscale x 16 x double> %v
 }
+
+; i1 vp.select with a constant or repeated operand should fold to mask logic.
+define <vscale x 2 x i1> @select_x_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_x_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmorn.mm v0, v8, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_zero_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_zero_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmandn.mm v0, v8, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_one_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_one_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmor.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_cond_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_cond_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmor.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_x_cond(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_x_cond:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %x, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}



More information about the llvm-commits mailing list