[llvm] [RISCV][VP] Introduce vp saturating addition/subtraction and RISC-V support. (PR #82370)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 20 07:23:33 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-risc-v
Author: Yeting Kuo (yetingk)
<details>
<summary>Changes</summary>
This patch also moves the MatchContext framework out of DAGCombiner into an individual header file so that the framework can be used from other files in llvm/lib/CodeGen/SelectionDAG/.
---
Patch is 455.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82370.diff
15 Files Affected:
- (modified) llvm/docs/LangRef.rst (+199)
- (modified) llvm/include/llvm/IR/Intrinsics.td (+20)
- (modified) llvm/include/llvm/IR/VPIntrinsics.def (+24)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+2-135)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+23-13)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+8-8)
- (added) llvm/lib/CodeGen/SelectionDAG/MatchContext.h (+175)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+11-1)
- (added) llvm/test/CodeGen/RISCV/rvv/vsadd-vp.ll (+2015)
- (added) llvm/test/CodeGen/RISCV/rvv/vsaddu-vp.ll (+2014)
- (added) llvm/test/CodeGen/RISCV/rvv/vssub-vp.ll (+2067)
- (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.ll (+2065)
- (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.s ()
- (modified) llvm/unittests/IR/VPIntrinsicTest.cpp (+8)
``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index fd2e3aacd0169c..676ba8a41362bf 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -16718,6 +16718,7 @@ an operation is greater than the maximum value, the result is set (or
"clamped") to this maximum. If it is below the minimum, it is clamped to this
minimum.
+.. _int_sadd_sat:
'``llvm.sadd.sat.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -16767,6 +16768,8 @@ Examples
%res = call i4 @llvm.sadd.sat.i4(i4 -4, i4 -5) ; %res = -8
+.. _int_uadd_sat:
+
'``llvm.uadd.sat.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -16814,6 +16817,8 @@ Examples
%res = call i4 @llvm.uadd.sat.i4(i4 8, i4 8) ; %res = 15
+.. _int_ssub_sat:
+
'``llvm.ssub.sat.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -16862,6 +16867,8 @@ Examples
%res = call i4 @llvm.ssub.sat.i4(i4 4, i4 -5) ; %res = 7
+.. _int_usub_sat:
+
'``llvm.usub.sat.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -23579,6 +23586,198 @@ Examples:
%also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+.. _int_vp_sadd_sat:
+
+'``llvm.vp.sadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <16 x i32> @llvm.vp.sadd.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+ declare <vscale x 4 x i32> @llvm.vp.sadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+ declare <256 x i64> @llvm.vp.sadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.sadd.sat``' intrinsic performs sadd.sat (:ref:`sadd.sat <int_sadd_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+ %r = call <4 x i32> @llvm.vp.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+ ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+ %t = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+ %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_uadd_sat:
+
+'``llvm.vp.uadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <16 x i32> @llvm.vp.uadd.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+ declare <vscale x 4 x i32> @llvm.vp.uadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+ declare <256 x i64> @llvm.vp.uadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.uadd.sat``' intrinsic performs uadd.sat (:ref:`uadd.sat <int_uadd_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+ %r = call <4 x i32> @llvm.vp.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+ ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+ %t = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+ %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_ssub_sat:
+
+'``llvm.vp.ssub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <16 x i32> @llvm.vp.ssub.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+ declare <vscale x 4 x i32> @llvm.vp.ssub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+ declare <256 x i64> @llvm.vp.ssub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.ssub.sat``' intrinsic performs ssub.sat (:ref:`ssub.sat <int_ssub_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+ %r = call <4 x i32> @llvm.vp.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+ ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+ %t = call <4 x i32> @llvm.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+ %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_usub_sat:
+
+'``llvm.vp.usub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <16 x i32> @llvm.vp.usub.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+ declare <vscale x 4 x i32> @llvm.vp.usub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+ declare <256 x i64> @llvm.vp.usub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.usub.sat``' intrinsic performs usub.sat (:ref:`usub.sat <int_usub_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+ %r = call <4 x i32> @llvm.vp.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+ ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+ %t = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+ %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
.. _int_vp_fshl:
'``llvm.vp.fshl.*``' Intrinsics
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8c0d4d5db32d88..d7c1ce153a6c80 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1933,6 +1933,26 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
LLVMMatchType<0>,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_i32_ty]>;
+ def int_vp_sadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+ [ LLVMMatchType<0>,
+ LLVMMatchType<0>,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+ llvm_i32_ty]>;
+ def int_vp_uadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+ [ LLVMMatchType<0>,
+ LLVMMatchType<0>,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+ llvm_i32_ty]>;
+ def int_vp_ssub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+ [ LLVMMatchType<0>,
+ LLVMMatchType<0>,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+ llvm_i32_ty]>;
+ def int_vp_usub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+ [ LLVMMatchType<0>,
+ LLVMMatchType<0>,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+ llvm_i32_ty]>;
// Floating-point arithmetic
def int_vp_fadd : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 3b32b60609f536..4089acf9ec3f05 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -293,6 +293,30 @@ BEGIN_REGISTER_VP(vp_fshr, 3, 4, VP_FSHR, -1)
VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshr)
VP_PROPERTY_FUNCTIONAL_SDOPC(FSHR)
END_REGISTER_VP(vp_fshr, VP_FSHR)
+
+// llvm.vp.sadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_sadd_sat, 2, 3, VP_SADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(sadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SADDSAT)
+END_REGISTER_VP(vp_sadd_sat, VP_SADDSAT)
+
+// llvm.vp.uadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_uadd_sat, 2, 3, VP_UADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(uadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(UADDSAT)
+END_REGISTER_VP(vp_uadd_sat, VP_UADDSAT)
+
+// llvm.vp.ssub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_ssub_sat, 2, 3, VP_SSUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(ssub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SSUBSAT)
+END_REGISTER_VP(vp_ssub_sat, VP_SSUBSAT)
+
+// llvm.vp.usub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_usub_sat, 2, 3, VP_USUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(usub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(USUBSAT)
+END_REGISTER_VP(vp_usub_sat, VP_USUBSAT)
///// } Integer Arithmetic
///// Floating-Point Arithmetic {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2a09e44e192979..318e1c12c3d56f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -76,6 +76,8 @@
#include <utility>
#include <variant>
+#include "MatchContext.h"
+
using namespace llvm;
#define DEBUG_TYPE "dagcombine"
@@ -888,141 +890,6 @@ class WorklistInserter : public SelectionDAG::DAGUpdateListener {
void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};
-class EmptyMatchContext {
- SelectionDAG &DAG;
- const TargetLowering &TLI;
-
-public:
- EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
- : DAG(DAG), TLI(TLI) {}
-
- bool match(SDValue OpN, unsigned Opcode) const {
- return Opcode == OpN->getOpcode();
- }
-
- // Same as SelectionDAG::getNode().
- template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
- return DAG.getNode(std::forward<ArgT>(Args)...);
- }
-
- bool isOperationLegalOrCustom(unsigned Op, EVT VT,
- bool LegalOnly = false) const {
- return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
- }
-};
-
-class VPMatchContext {
- SelectionDAG &DAG;
- const TargetLowering &TLI;
- SDValue RootMaskOp;
- SDValue RootVectorLenOp;
-
-public:
- VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
- : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
- assert(Root->isVPOpcode());
- if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
- RootMaskOp = Root->getOperand(*RootMaskPos);
- else if (Root->getOpcode() == ISD::VP_SELECT)
- RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
- Root->getOperand(0).getValueType());
-
- if (auto RootVLenPos =
- ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
- RootVectorLenOp = Root->getOperand(*RootVLenPos);
- }
-
- /// whether \p OpVal is a node that is functionally compatible with the
- /// NodeType \p Opc
- bool match(SDValue OpVal, unsigned Opc) const {
- if (!OpVal->isVPOpcode())
- return OpVal->getOpcode() == Opc;
-
- auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
- !OpVal->getFlags().hasNoFPExcept());
- if (BaseOpc != Opc)
- return false;
-
- // Make sure the mask of OpVal is true mask or is same as Root's.
- unsigned VPOpcode = OpVal->getOpcode();
- if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
- SDValue MaskOp = OpVal.getOperand(*MaskPos);
- if (RootMaskOp != MaskOp &&
- !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
- return false;
- }
-
- // Make sure the EVL of OpVal is same as Root's.
- if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
- if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
- return false;
- return true;
- }
-
- // Specialize based on number of operands.
- // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
- // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
- // DAG.getNode(Opcode, DL, VT); }
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT,
- {Operand, RootMaskOp, RootVectorLenOp});
- }
-
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
- SDValue N2) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, RootMaskOp, RootVectorLenOp});
- }
-
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
- SDValue N2, SDValue N3) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp});
- }
-
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
- SDNodeFlags Flags) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
- Flags);
- }
-
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
- SDValue N2, SDNodeFlags Flags) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
- Flags);
- }
-
- SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
- SDValue N2, SDValue N3, SDNodeFlags Flags) {
- unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
- assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
- ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
- }
-
- bool isOperationLegalOrCustom(unsigned Op, EVT VT,
- bool LegalOnly = false) const {
- unsigned VPOp = ISD::getVPForBaseOpcode(Op);
- return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
- }
-};
-
} // end anonymous namespace
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index a4ba261686c688..b5b33bb279e8b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -217,7 +217,15 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::SSUBSAT:
case ISD::USUBSAT:
case ISD::SSHLSAT:
- case ISD::USHLSAT: Res = PromoteIntRes_ADDSUBSHLSAT(N); break;
+ case ISD::USHLSAT:
+ Res = PromoteIntRes_ADDSUBSHLSAT<EmptyMatchContext>(N);
+ break;
+ case ISD::VP_SADDSAT:
+ case ISD::VP_UADDSAT:
+ case ISD::VP_SSUBSAT:
+ case ISD::VP_USUBSAT:
+ Res = PromoteIntRes_ADDSUBSHLSAT<VPMatchContext>(N);
+ break;
case ISD::SMULFIX:
case ISD::SMULFIXSAT:
@@ -934,6 +942,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
return DAG.getBoolExtOrTrunc(Res.getValue(1), dl, NVT, VT);
}
+template <class MatchContextClass>
SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
// If the promoted type is legal, we can convert this to:
// 1. ANY_EXTEND iN to iM
@@ -945,9 +954,10 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
SDLoc dl(N);
SDValue Op1 = N->getOperand(0);
SDValue Op2 = N->getOperand(1);
+ MatchContextClass matcher(DAG, TLI, N);
unsigned OldBits = Op1.getScalarValueSizeInBits();
- unsigned Opcode = N->getOpcode();
+ unsigned Opcode = matcher.getRootBaseOpcode();
bool IsShift = Opcode == ISD::USHLSAT || Opcode == ISD::SSHLSAT;
SDValue Op1Promoted, Op2Promoted;
@@ -968,18 +978,18 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
APInt MaxVal = APInt::getAllOnes(OldBits).zext(NewBits);
SDValue SatMax = DAG.getConstant(MaxVal, dl, PromotedType);
SDValue Add =
- DAG.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
- return DAG.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
+ matcher.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
+ return matcher.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
}
// USUBSAT can always be promoted as long as we have zero-extended the args.
if (Opcode == ISD::USUBSAT)
- return DAG.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
- Op2Promoted);
+ return matcher.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
+ Op2Promoted);
// Shift cannot use a min/max expansion, we can't detect overflow if all of
// the bits have been shifted out.
- if (IsShift || TLI.isOperationLegal(Opcode, PromotedType)) {
+ if (IsShift || matcher.isOperationLegal(Opcode, PromotedType)) {
unsigned ShiftOp;
switch (Opcode) {
case ISD::SADDSAT:
@@ -1002,11 +1012,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
DAG.getNode(ISD::SHL, dl, PromotedType, Op1Promoted, ShiftAmount);
if (!IsShift)
Op2Promoted =
- DAG.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
+ matcher.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
SDValue Result =
- DAG.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
- return DAG.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
+ matcher.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
+ return matcher.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
}
unsigned AddOp = Opcode == ISD::SADDSAT ? ISD::ADD : ISD::SUB;
@@ -1015,9 +1025,9 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
SDValue SatMin = DAG.getConstant(MinVal, dl, PromotedType);
SDValue SatMax = DAG.getConstant(MaxVal,...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/82370
More information about the llvm-commits
mailing list