[llvm] [DAG] Support saturated truncate (PR #99418)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 18:44:59 PDT 2024


https://github.com/ParkHanbum created https://github.com/llvm/llvm-project/pull/99418

A `truncate` is `saturated` if no additional conversion is required
between the target and return values. If the target is `saturated`
when truncating a `vector`, there is an opportunity to optimize it.

Previously, each architecture attempted this optimization on its own,
so there was redundant code.

This patch implements the common logic by adding `ISD::TRUNCATE_[US]SAT`
to indicate a saturated truncate.

>From 66b93322069570310c029954dc17dfb7f7d99ec4 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 13:52:29 +0900
Subject: [PATCH 1/3] [DAG] Support saturated truncate

A `truncate` is `saturated` if no additional conversion is required
between the target and return values. If the target is `saturated`
when truncating a `vector`, there is an opportunity to optimize it.

Previously, each architecture attempted this optimization on its own,
so there was redundant code.

This patch implements the common logic by adding `ISD::TRUNCATE_[US]SAT`
to indicate a saturated truncate.

Fixes #85903
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |   3 +
 .../include/llvm/Target/TargetSelectionDAG.td |   2 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 132 +++++++++++++++++-
 .../SelectionDAG/SelectionDAGDumper.cpp       |   2 +
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |   4 +
 5 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
 
   /// TRUNCATE - Completely drop the high bits.
   TRUNCATE,
+  /// TRUNCATE_[SU]SAT - Truncate with signed/unsigned saturation.
+  TRUNCATE_SSAT,
+  TRUNCATE_USAT,
 
   /// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
   /// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext       : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
 def zext       : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
 def anyext     : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
 def trunc      : SDNode<"ISD::TRUNCATE"   , SDTIntTruncOp>;
+def truncssat  : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat  : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
 def bitconvert : SDNode<"ISD::BITCAST"    , SDTUnaryOp>;
 def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
 def freeze     : SDNode<"ISD::FREEZE"     , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
     SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
     SDValue visitTRUNCATE(SDNode *N);
+    SDValue visitTRUNCATE_SSAT(SDNode *N);
+    SDValue visitTRUNCATE_USAT(SDNode *N);
     SDValue visitBITCAST(SDNode *N);
     SDValue visitFREEZE(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::ZERO_EXTEND_VECTOR_INREG:
   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
   case ISD::TRUNCATE:           return visitTRUNCATE(N);
+  case ISD::TRUNCATE_SSAT:      return visitTRUNCATE_SSAT(N);
+  case ISD::TRUNCATE_USAT:      return visitTRUNCATE_USAT(N);
   case ISD::BITCAST:            return visitBITCAST(N);
   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
   case ISD::FADD:               return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
   unsigned CastOpcode = Cast->getOpcode();
   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
-          CastOpcode == ISD::FP_ROUND) &&
+          CastOpcode == ISD::TRUNCATE_SSAT ||
+          CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
          "Unexpected opcode for vector select narrowing/widening");
 
   // We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+  SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
+  if (FPInstr.getOpcode() == ISD::FP_TO_SINT ||
+      FPInstr.getOpcode() == ISD::FP_TO_UINT) {
+    EVT FPVT = FPInstr.getOperand(0).getValueType();
+    if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+                                                          FPVT, VT))
+      return SDValue();
+    SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+                              FPInstr.getOperand(0),
+                              DAG.getValueType(VT.getScalarType()));
+    return Sat;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); }
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+///   Return the source value x to be truncated or SDValue() if the pattern was
+///   not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+///   where C1 >= 0 and C2 is unsigned max of destination type.
+///
+///    (truncate (smax (smin (x, C2), C1)) to dest_type)
+///   where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+///   These two patterns are equivalent to:
+///   (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+///   So return the smax(x, C1) value to be truncated or SDValue() if the
+///   pattern was not matched.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                 const SDLoc &DL) {
+  EVT InVT = In.getValueType();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+         "Unexpected types for truncate operation");
+
+  // Match min/max and return limit value as a parameter.
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt C1, C2;
+  if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+    // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
+    // to the element size of the destination type.
+    if (C2.isMask(VT.getScalarSizeInBits()))
+      return UMin;
+
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+    if (MatchMinMax(SMin, ISD::SMAX, C1))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+        return SMin;
+
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+          C2.uge(C1))
+        return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
+
+  return SDValue();
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+///                  signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+///                  signed_min_of_dest_type)) to dest_type).
+/// Both limits must be constant splats, sign-extended to the source width.
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  auto MatchMinMax = [](SDValue V, unsigned Opcode,
+                        const APInt &Limit) -> SDValue {
+    APInt C;
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt SignedMax, SignedMin;
+  SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+  SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
+    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
+      return SMax;
+    }
+  }
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
+      return SMin;
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   bool isLE = DAG.getDataLayout().isLittleEndian();
   SDLoc DL(N);
 
+  if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+      if (SDValue SSatVal = detectSSatPattern(N0, VT))
+        return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+    }
+
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+      if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+        return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+    }
+  }
+
   // trunc(undef) = undef
   if (N0.isUndef())
     return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::SIGN_EXTEND_VECTOR_INREG:   return "sign_extend_vector_inreg";
   case ISD::ZERO_EXTEND_VECTOR_INREG:   return "zero_extend_vector_inreg";
   case ISD::TRUNCATE:                   return "truncate";
+  case ISD::TRUNCATE_SSAT:              return "truncate_ssat";
+  case ISD::TRUNCATE_USAT:              return "truncate_usat";
   case ISD::FP_ROUND:                   return "fp_round";
   case ISD::STRICT_FP_ROUND:            return "strict_fp_round";
   case ISD::FP_EXTEND:                  return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
     // Absolute difference
     setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
 
+    // Saturated trunc
+    setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+    setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
                        Expand);

>From c80af0333a0c241dfe6d333e711e7b11be20ea08 Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:05:55 +0900
Subject: [PATCH 2/3] [AArch64] Support saturated truncate

Add support for `ISD::TRUNCATE_[US]SAT`.
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp |  2 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td     | 16 ++++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::AVGCEILU, VT, Legal);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
     }
 
     // Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
 // trunc(umin(X, 255)) -> UQXTRN v8i8
 def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
           (UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+          (UQXTNv8i8 V128:$Vn)>;
 // trunc(umin(X, 65535)) -> UQXTRN v4i16
 def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
           (UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+          (UQXTNv4i16 V128:$Vn)>;
 // trunc(smin(smax(X, -128), 128)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
 def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                              (v8i16 VImm80)))),
           (SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+          (SQXTNv8i8 V128:$Vn)>;
 // trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
 def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                               (v4i32 VImm8000)))),
           (SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+          (SQXTNv4i16 V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
                  (v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                                           (v8i16 VImm80)))))),
           (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+                 (v8i8 V64:$Vd),
+                 (v8i8 (truncssat (v8i16 V128:$Vn))))),
+          (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
                  (v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                                            (v4i32 VImm8000)))))),
           (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+                 (v4i16 V64:$Vd),
+                 (v4i16 (truncssat (v4i32 V128:$Vn))))),
+          (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // Select BSWAP vector instructions into REV instructions
 def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))), 

>From 108a26c9a7a771d23a7ada07172dccabf5e0a89c Mon Sep 17 00:00:00 2001
From: hanbeom <kese111 at gmail.com>
Date: Tue, 16 Jul 2024 14:14:40 +0900
Subject: [PATCH 3/3] [RISCV] Support saturated truncate

Add support for `ISD::TRUNCATE_[US]SAT`.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 40 +++++++++++++++------
 llvm/lib/Target/RISCV/RISCVISelLowering.h   |  2 ++
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
       // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
+      setOperationAction(
+          {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
 
       // Custom-lower insert/extract operations to simplify patterns.
       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
         setOperationAction(ISD::SELECT, VT, Custom);
 
-        setOperationAction(ISD::TRUNCATE, VT, Custom);
+        setOperationAction(
+            {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+            Custom);
 
         setOperationAction(ISD::BITCAST, VT, Custom);
 
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
 
   if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
-      Subtarget.hasStdExtV())
+      Subtarget.hasStdExtV()) {
     setTargetDAGCombine(ISD::TRUNCATE);
+    setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+    setTargetDAGCombine(ISD::TRUNCATE_USAT);
+  }
 
   if (Subtarget.hasStdExtZbkb())
     setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
   }
   case ISD::TRUNCATE:
+  case ISD::TRUNCATE_SSAT:
+  case ISD::TRUNCATE_USAT:
     // Only custom-lower vector truncates
     if (!Op.getSimpleValueType().isVector())
       return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
 
   LLVMContext &Context = *DAG.getContext();
   const ElementCount Count = ContainerVT.getVectorElementCount();
+  unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+  if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+  else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
   do {
     SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
     EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-    Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
-                         Mask, VL);
+    Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
   } while (SrcEltVT != DstEltVT);
 
   if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
 // minimum value.
 static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
                                     const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
 
   MVT VT = N->getSimpleValueType(0);
 
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
 
   SDValue Val;
   unsigned ClipOpc;
-  if ((Val = DetectUSatPattern(Src)))
+
+  Val = N->getOperand(0);
+  if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
     ClipOpc = RISCVISD::VNCLIPU_VL;
-  else if ((Val = DetectSSatPattern(Src)))
+  else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
     ClipOpc = RISCVISD::VNCLIP_VL;
   else
     return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     return SDValue();
   case RISCVISD::TRUNCATE_VECTOR_VL:
+  case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+  case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
     return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDE1UP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
   // Truncates a RVV integer vector by one power-of-two. Carries both an extra
   // mask and VL operand.
   TRUNCATE_VECTOR_VL,
+  TRUNCATE_VECTOR_VL_SSAT,
+  TRUNCATE_VECTOR_VL_USAT,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the XLenVT
   // index (either constant or non-constant), the fourth is the mask, the fifth



More information about the llvm-commits mailing list