[llvm] 45188c6 - [DAGCombiner] Use generalized pattern matcher in foldBoolSelectToLogic (#79101)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 29 18:26:56 PST 2024
Author: Liao Chunyu
Date: 2024-01-30T10:26:51+08:00
New Revision: 45188c64db68af92596acdb2d9022527f6aa4502
URL: https://github.com/llvm/llvm-project/commit/45188c64db68af92596acdb2d9022527f6aa4502
DIFF: https://github.com/llvm/llvm-project/commit/45188c64db68af92596acdb2d9022527f6aa4502.diff
LOG: [DAGCombiner] Use generalized pattern matcher in foldBoolSelectToLogic (#79101)
Support vp.select.
TODO: Other functions could possibly be supported as well, e.g. SimplifySelect().
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 87184fe409ead..b17724cd07209 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -478,6 +478,7 @@ namespace {
SDValue visitCTPOP(SDNode *N);
SDValue visitSELECT(SDNode *N);
SDValue visitVSELECT(SDNode *N);
+ SDValue visitVP_SELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
SDValue visitSETCC(SDNode *N);
SDValue visitSETCCCARRY(SDNode *N);
@@ -927,6 +928,9 @@ class VPMatchContext {
assert(Root->isVPOpcode());
if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
RootMaskOp = Root->getOperand(*RootMaskPos);
+ else if (Root->getOpcode() == ISD::VP_SELECT)
+ RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
+ Root->getOperand(0).getValueType());
if (auto RootVLenPos =
ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
@@ -11420,35 +11424,42 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
return SDValue();
}
+template <class MatchContextClass>
static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
- assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
- "Expected a (v)select");
+ assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
+ N->getOpcode() == ISD::VP_SELECT) &&
+ "Expected a (v)(vp.)select");
SDValue Cond = N->getOperand(0);
SDValue T = N->getOperand(1), F = N->getOperand(2);
EVT VT = N->getValueType(0);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ MatchContextClass matcher(DAG, TLI, N);
+
if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
return SDValue();
// select Cond, Cond, F --> or Cond, F
// select Cond, 1, F --> or Cond, F
if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
- return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
+ return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
// select Cond, T, Cond --> and Cond, T
// select Cond, T, 0 --> and Cond, T
if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
- return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
+ return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
// select Cond, T, 1 --> or (not Cond), T
if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
- SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
- return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
+ SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+ DAG.getAllOnesConstant(SDLoc(N), VT));
+ return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
}
// select Cond, 0, F --> and (not Cond), F
if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
- SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
- return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
+ SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+ DAG.getAllOnesConstant(SDLoc(N), VT));
+ return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
}
return SDValue();
@@ -11524,7 +11535,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (SDValue V = DAG.simplifySelect(N0, N1, N2))
return V;
- if (SDValue V = foldBoolSelectToLogic(N, DAG))
+ if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
return V;
// select (not Cond), N1, N2 -> select Cond, N2, N1
@@ -12138,6 +12149,13 @@ SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
+ // Boolean (i1) vp.select -> AND/OR/XOR logic, built through VPMatchContext.
+ SDValue Folded = foldBoolSelectToLogic<VPMatchContext>(N, DAG);
+ // An empty SDValue here means "no combine"; hand it back unchanged.
+ return Folded;
+}
+
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -12148,7 +12166,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (SDValue V = DAG.simplifySelect(N0, N1, N2))
return V;
- if (SDValue V = foldBoolSelectToLogic(N, DAG))
+ if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
return V;
// vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
@@ -26401,6 +26419,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
return visitVP_FSUB(N);
case ISD::VP_FMA:
return visitFMA<VPMatchContext>(N);
+ case ISD::VP_SELECT:
+ return visitVP_SELECT(N);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
index 9e7df5eab8dda..1b568bf8801b1 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
@@ -745,3 +745,83 @@ define <vscale x 16 x double> @select_nxv16f64(<vscale x 16 x i1> %a, <vscale x
%v = call <vscale x 16 x double> @llvm.vp.select.nxv16f64(<vscale x 16 x i1> %a, <vscale x 16 x double> %b, <vscale x 16 x double> %c, i32 %evl)
ret <vscale x 16 x double> %v
}
+
+define <vscale x 2 x i1> @select_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmand.mm v0, v0, v8
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl) ; select x, y, 0 --> and x, y. NOTE(review): same call as @select_x_zero and %m is unused; consider dropping one of the two tests.
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmorn.mm v0, v8, v0
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl) ; select x, y, 1 --> or (not x), y. NOTE(review): same call as @select_x_one and %m is unused; consider dropping one of the two tests.
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_x_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmand.mm v0, v0, v8
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl) ; false operand is zero: select x, y, 0 --> and x, y
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_x_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmorn.mm v0, v8, v0
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl) ; false operand is all-ones: select x, y, 1 --> or (not x), y
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_zero_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_zero_x:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmandn.mm v0, v8, v0
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, i32 %evl) ; true operand is zero: select x, 0, y --> and (not x), y
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_one_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_one_x:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmor.mm v0, v0, v8
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %y, i32 %evl) ; true operand is all-ones: select x, 1, y --> or x, y
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_cond_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_cond_x:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmor.mm v0, v0, v8
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl) ; true operand is the condition: select x, x, y --> or x, y (%m is unused)
+ ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_x_cond(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_x_cond:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vmand.mm v0, v0, v8
+; CHECK-NEXT: ret
+ %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %x, i32 %evl) ; false operand is the condition: select x, y, x --> and x, y (%m is unused)
+ ret <vscale x 2 x i1> %a
+}
More information about the llvm-commits mailing list