[llvm] [SelectionDAG][RISCV] Move VP_REDUCE* legalization to LegalizeDAG.cpp. (PR #90522)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 29 21:48:57 PDT 2024
https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/90522
>From 96cc3ac6bd6e8d59d04c64137a4732aa5fac1ce8 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 29 Apr 2024 13:48:21 -0700
Subject: [PATCH 1/2] [SelectionDAG][RISCV] Move VP_REDUCE* legalization to
LegalizeDAG.cpp.
LegalizeVectorOps is responsible for legalizing nodes that perform an
operation on each element and may need to be scalarized.
This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR,
SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.
This patch drops any nodes with a scalar result from LegalizeVectorOps
and handles them in LegalizeDAG instead.
This required moving the reduction promotion to LegalizeDAG. I have
removed the support for integer promotion, as it was incorrect for integer
min/max reductions. Since it was untested, it seemed best to assert on it
until it is actually needed.
There are a couple of regressions that can be fixed with a small DAG combine,
which I will do as a follow-up.
---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 64 +++++++++++++++-
.../SelectionDAG/LegalizeVectorOps.cpp | 73 ++-----------------
.../rvv/fixed-vectors-reduction-int-vp.ll | 26 ++++---
.../CodeGen/RISCV/rvv/vreductions-int-vp.ll | 7 +-
4 files changed, 87 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..9dd40531abb005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -180,6 +180,13 @@ class SelectionDAGLegalize {
SmallVectorImpl<SDValue> &Results);
SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl);
+ /// Implements vector reduce operation promotion.
+ ///
+ /// All vector operands are promoted to a vector type with larger element
+ /// type, and the start value is promoted to a larger scalar type. Then the
+ /// result is truncated back to the original scalar type.
+ void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+
SDValue ExpandPARITY(SDValue Op, const SDLoc &dl);
SDValue ExpandExtractFromVectorThroughStack(SDValue Op);
@@ -2979,6 +2986,49 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT));
}
+void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
+ SmallVectorImpl<SDValue> &Results) {
+ MVT VecVT = Node->getOperand(1).getSimpleValueType();
+ MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
+ MVT ScalarVT = Node->getSimpleValueType(0);
+ MVT NewScalarVT = NewVecVT.getVectorElementType();
+
+ SDLoc DL(Node);
+ SmallVector<SDValue, 4> Operands(Node->getNumOperands());
+
+ // Promote the start value (operand 0) to the promoted element type.
+ // FIXME: Support integer (a plain ANY_EXTEND is wrong for min/max reductions).
+ assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
+ "Only FP promotion is supported");
+ Operands[0] =
+ DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
+
+ for (unsigned j = 1; j != Node->getNumOperands(); ++j)
+ if (Node->getOperand(j).getValueType().isVector() &&
+ !(ISD::isVPOpcode(Node->getOpcode()) &&
+ ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
+ // Promote the (non-mask) vector operand to NewVecVT.
+ // FIXME: Support integer vector operands.
+ assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
+ "Only FP promotion is supported");
+ Operands[j] =
+ DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
+ } else {
+ Operands[j] = Node->getOperand(j); // Mask and non-vector operands (e.g. EVL) pass through.
+ }
+
+ SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
+ Node->getFlags());
+
+ if (ScalarVT.isFloatingPoint())
+ Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+ else
+ Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+
+ Results.push_back(Res);
+}
+
bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
LLVM_DEBUG(dbgs() << "Trying to expand node\n");
SmallVector<SDValue, 8> Results;
@@ -4955,7 +5005,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
Node->getOpcode() == ISD::STRICT_FSETCC ||
- Node->getOpcode() == ISD::STRICT_FSETCCS)
+ Node->getOpcode() == ISD::STRICT_FSETCCS ||
+ Node->getOpcode() == ISD::VP_REDUCE_FADD ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMUL ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMAX ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMIN ||
+ Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD)
OVT = Node->getOperand(1).getSimpleValueType();
if (Node->getOpcode() == ISD::BR_CC ||
Node->getOpcode() == ISD::SELECT_CC)
@@ -5613,6 +5668,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
break;
}
+ case ISD::VP_REDUCE_FADD:
+ case ISD::VP_REDUCE_FMUL:
+ case ISD::VP_REDUCE_FMAX:
+ case ISD::VP_REDUCE_FMIN:
+ case ISD::VP_REDUCE_SEQ_FADD:
+ PromoteReduction(Node, Results);
+ break;
}
// Replace the original node with the legalized result.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..423df9ae6b2a55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -176,13 +176,6 @@ class VectorLegalizer {
/// truncated back to the original type.
void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
- /// Implements vector reduce operation promotion.
- ///
- /// All vector operands are promoted to a vector type with larger element
- /// type, and the start value is promoted to a larger scalar type. Then the
- /// result is truncated back to the original scalar type.
- void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
-
/// Implements vector setcc operation promotion.
///
/// All vector operands are promoted to a vector type with larger element
@@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
if (Action != TargetLowering::Legal) \
break; \
} \
+ /* Defer non-vector results to LegalizeDAG. */ \
+ if (!Node->getValueType(0).isVector()) { \
+ Action = TargetLowering::Legal; \
+ break; \
+ } \
Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \
} break;
#include "llvm/IR/VPIntrinsics.def"
@@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
return true;
}
-void VectorLegalizer::PromoteReduction(SDNode *Node,
- SmallVectorImpl<SDValue> &Results) {
- MVT VecVT = Node->getOperand(1).getSimpleValueType();
- MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
- MVT ScalarVT = Node->getSimpleValueType(0);
- MVT NewScalarVT = NewVecVT.getVectorElementType();
-
- SDLoc DL(Node);
- SmallVector<SDValue, 4> Operands(Node->getNumOperands());
-
- // promote the initial value.
- if (Node->getOperand(0).getValueType().isFloatingPoint())
- Operands[0] =
- DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
- else
- Operands[0] =
- DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-
- for (unsigned j = 1; j != Node->getNumOperands(); ++j)
- if (Node->getOperand(j).getValueType().isVector() &&
- !(ISD::isVPOpcode(Node->getOpcode()) &&
- ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
- // promote the vector operand.
- if (Node->getOperand(j).getValueType().isFloatingPoint())
- Operands[j] =
- DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
- else
- Operands[j] =
- DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
- else
- Operands[j] = Node->getOperand(j); // Skip VL operand.
-
- SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
- Node->getFlags());
-
- if (ScalarVT.isFloatingPoint())
- Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
- DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
- else
- Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
-
- Results.push_back(Res);
-}
-
void VectorLegalizer::PromoteSETCC(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
MVT VecVT = Node->getOperand(0).getSimpleValueType();
@@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// Promote the operation by extending the operand.
PromoteFP_TO_INT(Node, Results);
return;
- case ISD::VP_REDUCE_ADD:
- case ISD::VP_REDUCE_MUL:
- case ISD::VP_REDUCE_AND:
- case ISD::VP_REDUCE_OR:
- case ISD::VP_REDUCE_XOR:
- case ISD::VP_REDUCE_SMAX:
- case ISD::VP_REDUCE_SMIN:
- case ISD::VP_REDUCE_UMAX:
- case ISD::VP_REDUCE_UMIN:
- case ISD::VP_REDUCE_FADD:
- case ISD::VP_REDUCE_FMUL:
- case ISD::VP_REDUCE_FMAX:
- case ISD::VP_REDUCE_FMIN:
- case ISD::VP_REDUCE_SEQ_FADD:
- // Promote the operation by extending the operand.
- PromoteReduction(Node, Results);
- return;
case ISD::VP_SETCC:
case ISD::SETCC:
// Promote the operation by extending the operand.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index 02a989a9699606..b874a4477f5d17 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
; CHECK-LABEL: vpreduce_xor_v64i32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, ma
-; CHECK-NEXT: li a3, 32
; CHECK-NEXT: vslidedown.vi v24, v0, 4
-; CHECK-NEXT: mv a2, a1
-; CHECK-NEXT: bltu a1, a3, .LBB49_2
+; CHECK-NEXT: addi a2, a1, -32
+; CHECK-NEXT: sltu a3, a1, a2
+; CHECK-NEXT: addi a3, a3, -1
+; CHECK-NEXT: li a4, 32
+; CHECK-NEXT: and a2, a3, a2
+; CHECK-NEXT: bltu a1, a4, .LBB49_2
; CHECK-NEXT: # %bb.1:
-; CHECK-NEXT: li a2, 32
+; CHECK-NEXT: li a1, 32
; CHECK-NEXT: .LBB49_2:
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
; CHECK-NEXT: vmv.s.x v25, a0
-; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vredxor.vs v25, v8, v25, v0.t
-; CHECK-NEXT: addi a0, a1, -32
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: addi a1, a1, -1
-; CHECK-NEXT: and a0, a1, a0
-; CHECK-NEXT: vsetvli zero, a0, e32, m8, ta, ma
-; CHECK-NEXT: vmv1r.v v0, v24
-; CHECK-NEXT: vredxor.vs v25, v16, v25, v0.t
; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v8, a0
+; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT: vmv1r.v v0, v24
+; CHECK-NEXT: vredxor.vs v8, v16, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
ret i32 %r
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
index 7bcf37b1af3c8f..95b64cb662a614 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
@@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
; CHECK-NEXT: vmv.s.x v25, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vredmaxu.vs v25, v8, v25, v0.t
+; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v8, a0
; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
; CHECK-NEXT: vmv1r.v v0, v24
-; CHECK-NEXT: vredmaxu.vs v25, v16, v25, v0.t
-; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vredmaxu.vs v8, v16, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, <vscale x 32 x i32> %v, <vscale x 32 x i1> %m, i32 %evl)
ret i32 %r
>From 9ad0edc888af3c2db46933a000a1995eaba6cbb8 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 29 Apr 2024 21:45:22 -0700
Subject: [PATCH 2/2] fixup! remove dead code
---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9dd40531abb005..398b5fee990b5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3020,11 +3020,9 @@ void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
Node->getFlags());
- if (ScalarVT.isFloatingPoint())
- Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
- DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
- else
- Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+ assert(ScalarVT.isFloatingPoint() && "Only FP promotion is supported");
+ Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
Results.push_back(Res);
}
More information about the llvm-commits
mailing list