[llvm] [RISCV][VP] Introduce vp saturating addition/subtraction and RISC-V support. (PR #82370)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 20 07:23:33 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-risc-v

Author: Yeting Kuo (yetingk)

<details>
<summary>Changes</summary>

This patch also moves the MatchContext framework from DAGCombiner into an individual header file so that the framework can be used from other files in llvm/lib/CodeGen/SelectionDAG/.

---

Patch is 455.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82370.diff


15 Files Affected:

- (modified) llvm/docs/LangRef.rst (+199) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+20) 
- (modified) llvm/include/llvm/IR/VPIntrinsics.def (+24) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+2-135) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+23-13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+8-8) 
- (added) llvm/lib/CodeGen/SelectionDAG/MatchContext.h (+175) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+11-1) 
- (added) llvm/test/CodeGen/RISCV/rvv/vsadd-vp.ll (+2015) 
- (added) llvm/test/CodeGen/RISCV/rvv/vsaddu-vp.ll (+2014) 
- (added) llvm/test/CodeGen/RISCV/rvv/vssub-vp.ll (+2067) 
- (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.ll (+2065) 
- (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.s () 
- (modified) llvm/unittests/IR/VPIntrinsicTest.cpp (+8) 


``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index fd2e3aacd0169c..676ba8a41362bf 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -16718,6 +16718,7 @@ an operation is greater than the maximum value, the result is set (or
 "clamped") to this maximum. If it is below the minimum, it is clamped to this
 minimum.
 
+.. _int_sadd_sat:
 
 '``llvm.sadd.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -16767,6 +16768,8 @@ Examples
       %res = call i4 @llvm.sadd.sat.i4(i4 -4, i4 -5)  ; %res = -8
 
 
+.. _int_uadd_sat:
+
 '``llvm.uadd.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -16814,6 +16817,8 @@ Examples
       %res = call i4 @llvm.uadd.sat.i4(i4 8, i4 8)  ; %res = 15
 
 
+.. _int_ssub_sat:
+
 '``llvm.ssub.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -16862,6 +16867,8 @@ Examples
       %res = call i4 @llvm.ssub.sat.i4(i4 4, i4 -5)  ; %res = 7
 
 
+.. _int_usub_sat:
+
 '``llvm.usub.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -23579,6 +23586,198 @@ Examples:
       %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
 
 
+.. _int_vp_sadd_sat:
+
+'``llvm.vp.sadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.sadd.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.sadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.sadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.sadd.sat``' intrinsic performs sadd.sat (:ref:`sadd.sat <int_sadd_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_uadd_sat:
+
+'``llvm.vp.uadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.uadd.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.uadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.uadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.uadd.sat``' intrinsic performs uadd.sat (:ref:`uadd.sat <int_uadd_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_ssub_sat:
+
+'``llvm.vp.ssub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.ssub.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.ssub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.ssub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.ssub.sat``' intrinsic performs ssub.sat (:ref:`ssub.sat <int_ssub_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_usub_sat:
+
+'``llvm.vp.usub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.usub.sat.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.usub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.usub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operands and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.usub.sat``' intrinsic performs usub.sat (:ref:`usub.sat <int_usub_sat>`) of the first and second
+vector operands on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
 .. _int_vp_fshl:
 
 '``llvm.vp.fshl.*``' Intrinsics
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8c0d4d5db32d88..d7c1ce153a6c80 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1933,6 +1933,26 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
                                LLVMMatchType<0>,
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                llvm_i32_ty]>;
+  def int_vp_sadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_uadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_ssub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_usub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
 
   // Floating-point arithmetic
   def int_vp_fadd : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 3b32b60609f536..4089acf9ec3f05 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -293,6 +293,30 @@ BEGIN_REGISTER_VP(vp_fshr, 3, 4, VP_FSHR, -1)
 VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshr)
 VP_PROPERTY_FUNCTIONAL_SDOPC(FSHR)
 END_REGISTER_VP(vp_fshr, VP_FSHR)
+
+// llvm.vp.sadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_sadd_sat, 2, 3, VP_SADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(sadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SADDSAT)
+END_REGISTER_VP(vp_sadd_sat, VP_SADDSAT)
+
+// llvm.vp.uadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_uadd_sat, 2, 3, VP_UADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(uadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(UADDSAT)
+END_REGISTER_VP(vp_uadd_sat, VP_UADDSAT)
+
+// llvm.vp.ssub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_ssub_sat, 2, 3, VP_SSUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(ssub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SSUBSAT)
+END_REGISTER_VP(vp_ssub_sat, VP_SSUBSAT)
+
+// llvm.vp.usub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_usub_sat, 2, 3, VP_USUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(usub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(USUBSAT)
+END_REGISTER_VP(vp_usub_sat, VP_USUBSAT)
 ///// } Integer Arithmetic
 
 ///// Floating-Point Arithmetic {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2a09e44e192979..318e1c12c3d56f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -76,6 +76,8 @@
 #include <utility>
 #include <variant>
 
+#include "MatchContext.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "dagcombine"
@@ -888,141 +890,6 @@ class WorklistInserter : public SelectionDAG::DAGUpdateListener {
   void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
 };
 
-class EmptyMatchContext {
-  SelectionDAG &DAG;
-  const TargetLowering &TLI;
-
-public:
-  EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI) {}
-
-  bool match(SDValue OpN, unsigned Opcode) const {
-    return Opcode == OpN->getOpcode();
-  }
-
-  // Same as SelectionDAG::getNode().
-  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
-    return DAG.getNode(std::forward<ArgT>(Args)...);
-  }
-
-  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
-                                bool LegalOnly = false) const {
-    return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
-  }
-};
-
-class VPMatchContext {
-  SelectionDAG &DAG;
-  const TargetLowering &TLI;
-  SDValue RootMaskOp;
-  SDValue RootVectorLenOp;
-
-public:
-  VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
-    assert(Root->isVPOpcode());
-    if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
-      RootMaskOp = Root->getOperand(*RootMaskPos);
-    else if (Root->getOpcode() == ISD::VP_SELECT)
-      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
-                                          Root->getOperand(0).getValueType());
-
-    if (auto RootVLenPos =
-            ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
-      RootVectorLenOp = Root->getOperand(*RootVLenPos);
-  }
-
-  /// whether \p OpVal is a node that is functionally compatible with the
-  /// NodeType \p Opc
-  bool match(SDValue OpVal, unsigned Opc) const {
-    if (!OpVal->isVPOpcode())
-      return OpVal->getOpcode() == Opc;
-
-    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
-                                           !OpVal->getFlags().hasNoFPExcept());
-    if (BaseOpc != Opc)
-      return false;
-
-    // Make sure the mask of OpVal is true mask or is same as Root's.
-    unsigned VPOpcode = OpVal->getOpcode();
-    if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
-      SDValue MaskOp = OpVal.getOperand(*MaskPos);
-      if (RootMaskOp != MaskOp &&
-          !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
-        return false;
-    }
-
-    // Make sure the EVL of OpVal is same as Root's.
-    if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
-      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
-        return false;
-    return true;
-  }
-
-  // Specialize based on number of operands.
-  // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
-  // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
-  // DAG.getNode(Opcode, DL, VT); }
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {Operand, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDValue N3) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
-                  SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
-                       Flags);
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
-                       Flags);
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDValue N3, SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
-  }
-
-  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
-                                bool LegalOnly = false) const {
-    unsigned VPOp = ISD::getVPForBaseOpcode(Op);
-    return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
-  }
-};
-
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index a4ba261686c688..b5b33bb279e8b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -217,7 +217,15 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SSUBSAT:
   case ISD::USUBSAT:
   case ISD::SSHLSAT:
-  case ISD::USHLSAT:     Res = PromoteIntRes_ADDSUBSHLSAT(N); break;
+  case ISD::USHLSAT:
+    Res = PromoteIntRes_ADDSUBSHLSAT<EmptyMatchContext>(N);
+    break;
+  case ISD::VP_SADDSAT:
+  case ISD::VP_UADDSAT:
+  case ISD::VP_SSUBSAT:
+  case ISD::VP_USUBSAT:
+    Res = PromoteIntRes_ADDSUBSHLSAT<VPMatchContext>(N);
+    break;
 
   case ISD::SMULFIX:
   case ISD::SMULFIXSAT:
@@ -934,6 +942,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
   return DAG.getBoolExtOrTrunc(Res.getValue(1), dl, NVT, VT);
 }
 
+template <class MatchContextClass>
 SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   // If the promoted type is legal, we can convert this to:
   //   1. ANY_EXTEND iN to iM
@@ -945,9 +954,10 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   SDLoc dl(N);
   SDValue Op1 = N->getOperand(0);
   SDValue Op2 = N->getOperand(1);
+  MatchContextClass matcher(DAG, TLI, N);
   unsigned OldBits = Op1.getScalarValueSizeInBits();
 
-  unsigned Opcode = N->getOpcode();
+  unsigned Opcode = matcher.getRootBaseOpcode();
   bool IsShift = Opcode == ISD::USHLSAT || Opcode == ISD::SSHLSAT;
 
   SDValue Op1Promoted, Op2Promoted;
@@ -968,18 +978,18 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
     APInt MaxVal = APInt::getAllOnes(OldBits).zext(NewBits);
     SDValue SatMax = DAG.getConstant(MaxVal, dl, PromotedType);
     SDValue Add =
-        DAG.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
-    return DAG.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
+        matcher.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
+    return matcher.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
   }
 
   // USUBSAT can always be promoted as long as we have zero-extended the args.
   if (Opcode == ISD::USUBSAT)
-    return DAG.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
-                       Op2Promoted);
+    return matcher.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
+                           Op2Promoted);
 
   // Shift cannot use a min/max expansion, we can't detect overflow if all of
   // the bits have been shifted out.
-  if (IsShift || TLI.isOperationLegal(Opcode, PromotedType)) {
+  if (IsShift || matcher.isOperationLegal(Opcode, PromotedType)) {
     unsigned ShiftOp;
     switch (Opcode) {
     case ISD::SADDSAT:
@@ -1002,11 +1012,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
         DAG.getNode(ISD::SHL, dl, PromotedType, Op1Promoted, ShiftAmount);
     if (!IsShift)
       Op2Promoted =
-          DAG.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
+          matcher.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
 
     SDValue Result =
-        DAG.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
-    return DAG.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
+        matcher.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
+    return matcher.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
   }
 
   unsigned AddOp = Opcode == ISD::SADDSAT ? ISD::ADD : ISD::SUB;
@@ -1015,9 +1025,9 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   SDValue SatMin = DAG.getConstant(MinVal, dl, PromotedType);
   SDValue SatMax = DAG.getConstant(MaxVal,...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82370


More information about the llvm-commits mailing list