[llvm] [RISCV] Combine `(setcc (riscv_selectcc A, B, ...), Y)` to just `(setcc A, B)` when possible (PR #90538)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 29 18:04:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-risc-v

Author: Min-Yih Hsu (mshockwave)

<details>
<summary>Changes</summary>

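Given `(seteq (riscv_selectcc LHS, RHS, CC, X, Y), X)`, we can turn it into `(setcc LHS, RHS, CC)`, provided that X and Y are known to be different (e.g. distinct constants); if X == Y the comparison is always true and the fold would be wrong.
I think we can generalize this to ISD::SELECT_CC as well.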

-------
Right now this PR is stacked on top of #90502 (although it doesn't have to be) so that the resulting test changes are visible. The main patch is 018b08a239f0739ca923ca691dc770ebc60c4da6. I'll add more tests tomorrow.

---

Patch is 33.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90538.diff


14 Files Affected:

- (modified) llvm/docs/LangRef.rst (+48) 
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+5) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+6) 
- (modified) llvm/include/llvm/IR/VPIntrinsics.def (+9) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+9) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+10) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+7) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+42) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+8-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+33) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+109-10) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+1) 
- (added) llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll (+234) 


``````````diff
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 37662f79145d67..f79c1fd9278de3 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24001,6 +24001,54 @@ Examples:
       %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
 
 
+.. _int_vp_cttz_elts:
+
+'``llvm.vp.cttz.elts.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic. You can use ``llvm.vp.cttz.elts`` on any
+vector of integer elements, both fixed width and scalable.
+
+::
+
+      declare i32  @llvm.vp.cttz.elts.i32.v16i32 (<16 x i32> <op>, i1 <is_zero_poison>, <16 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.nxv4i32 (<vscale x 4 x i32> <op>, i1 <is_zero_poison>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.v256i1 (<256 x i1> <op>, i1 <is_zero_poison>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+The '``llvm.vp.cttz.elts``' intrinsic counts the number of trailing zero
+elements of a vector. It is the vector-predicated version of
+'``llvm.experimental.cttz.elts``'.
+
+Arguments:
+""""""""""
+
+The first argument is the vector to be counted. This argument must be a vector
+with integer element type. The return type must also be an integer type which is
+wide enough to hold the maximum number of elements of the source vector. The
+behavior of this intrinsic is undefined if the return type is not wide enough
+for the number of elements in the input vector.
+
+The second argument is a constant flag that indicates whether the intrinsic
+returns a valid result if the first argument is all zero.
+
+The third argument is the vector mask and has the same number of elements as the
+input vector type. The fourth argument is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.cttz.elts``' intrinsic counts the trailing (least
+significant / lowest-numbered) zero elements in the first operand on each
+enabled lane. If every enabled lane of the first argument is zero and the
+second argument is true, the result is poison; if the second argument is
+false, the result is the explicit vector length (i.e. the fourth operand).
+
 .. _int_vp_sadd_sat:
 
 '``llvm.vp.sadd.sat.*``' Intrinsics
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 661b2841c6ac72..7ed08cfa8a2022 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5307,6 +5307,11 @@ class TargetLowering : public TargetLoweringBase {
   /// \returns The expansion result or SDValue() if it fails.
   SDValue expandVPCTTZ(SDNode *N, SelectionDAG &DAG) const;
 
+  /// Expand VP_CTTZ_ELTS/VP_CTTZ_ELTS_ZERO_UNDEF nodes.
+  /// \param N Node to expand
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand ABS nodes. Expands vector/scalar ABS nodes,
   /// vector nodes can only succeed if all operations are legal/custom.
   /// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index a2678d69ce4062..28116e5316c96b 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2255,6 +2255,12 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
                                llvm_i1_ty,
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                llvm_i32_ty]>;
+
+  def int_vp_cttz_elts : DefaultAttrsIntrinsic<[ llvm_anyint_ty ],
+                                  [ llvm_anyvector_ty,
+                                    llvm_i1_ty,
+                                    LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>,
+                                    llvm_i32_ty]>;
 }
 
 def int_get_active_lane_mask:
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 1c2708a9e85437..f1cc8bcae467be 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -282,6 +282,15 @@ BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF, -1, vp_cttz_zero_undef, 1, 2)
 END_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF)
 END_REGISTER_VP_INTRINSIC(vp_cttz)
 
+// llvm.vp.cttz.elts(x,is_zero_poison,mask,vl)
+BEGIN_REGISTER_VP_INTRINSIC(vp_cttz_elts, 2, 3)
+VP_PROPERTY_NO_FUNCTIONAL
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS, 0, vp_cttz_elts, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS)
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF, 0, vp_cttz_elts_zero_undef, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF)
+END_REGISTER_VP_INTRINSIC(vp_cttz_elts)
+
 // llvm.vp.fshl(x,y,z,mask,vlen)
 BEGIN_REGISTER_VP(vp_fshl, 3, 4, VP_FSHL, -1)
 VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshl)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..5322ea3b6a2d97 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1220,6 +1220,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
     Action = TLI.getOperationAction(
         Node->getOpcode(), Node->getOperand(1).getValueType());
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Action = TLI.getOperationAction(Node->getOpcode(),
+                                    Node->getOperand(0).getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TLI.getCustomOperationAction(*Node);
@@ -4234,6 +4239,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::VECREDUCE_FMINIMUM:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Results.push_back(TLI.expandVPCTTZElements(Node, DAG));
+    break;
   case ISD::GLOBAL_OFFSET_TABLE:
   case ISD::GlobalAddress:
   case ISD::GlobalTLSAddress:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 55f9737bc94dd5..0aa36deda79dcc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -76,6 +76,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_CTTZ:
   case ISD::CTTZ_ZERO_UNDEF:
   case ISD::CTTZ:        Res = PromoteIntRes_CTTZ(N); break;
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS:
+    Res = PromoteIntRes_VP_CttzElements(N);
+    break;
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
   case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
@@ -724,6 +728,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTTZ(SDNode *N) {
                      N->getOperand(2));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT NewVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  return DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N) {
   SDLoc dl(N);
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4a2c7b355eb528..49be824deb5134 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -309,6 +309,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_CTLZ(SDNode *N);
   SDValue PromoteIntRes_CTPOP_PARITY(SDNode *N);
   SDValue PromoteIntRes_CTTZ(SDNode *N);
+  SDValue PromoteIntRes_VP_CttzElements(SDNode *N);
   SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
@@ -912,6 +913,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_FP_ROUND(SDNode *N);
   SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
+  SDValue SplitVecOp_VP_CttzElements(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp
@@ -1019,6 +1021,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
   SDValue WidenVecOp_VP_REDUCE(SDNode *N);
   SDValue WidenVecOp_ExpOp(SDNode *N);
+  SDValue WidenVecOp_VP_CttzElements(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..26cd5482168f9f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -510,6 +510,13 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       if (Action != TargetLowering::Legal)                                     \
         break;                                                                 \
     }                                                                          \
+    /* Defer non-vector results to LegalizeDAG. */                             \
+    /* Remove this after #90522 is landed */                                   \
+    if (ISD::VPID == ISD::VP_CTTZ_ELTS ||                                      \
+        ISD::VPID == ISD::VP_CTTZ_ELTS_ZERO_UNDEF) {                           \
+      Action = TargetLowering::Legal;                                          \
+      break;                                                                   \
+    }                                                                          \
     Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT);            \
   } break;
 #include "llvm/IR/VPIntrinsics.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 985c9f16ab97cd..cab4dc5f3c1565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3098,6 +3098,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = SplitVecOp_VP_REDUCE(N, OpNo);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = SplitVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4056,6 +4060,29 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+
+  SDValue Lo, Hi;
+  SDValue VecOp = N->getOperand(0);
+  GetSplitVector(VecOp, Lo, Hi);
+
+  auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
+  auto [EVLLo, EVLHi] =
+      DAG.SplitEVL(N->getOperand(2), VecOp.getValueType(), DL);
+  SDValue VLo = DAG.getZExtOrTrunc(EVLLo, DL, ResVT);
+
+  // if VP_CTTZ_ELTS(Lo) != EVLLo => VP_CTTZ_ELTS(Lo).
+  // else => EVLLo + (VP_CTTZ_ELTS(Hi) or VP_CTTZ_ELTS_ZERO_UNDEF(Hi)).
+  SDValue ResLo = DAG.getNode(ISD::VP_CTTZ_ELTS, DL, ResVT, Lo, MaskLo, EVLLo);
+  SDValue ResLoNotEVL =
+      DAG.getSetCC(DL, getSetCCResultType(ResVT), ResLo, VLo, ISD::SETNE);
+  SDValue ResHi = DAG.getNode(N->getOpcode(), DL, ResVT, Hi, MaskHi, EVLHi);
+  return DAG.getSelect(DL, ResVT, ResLoNotEVL, ResLo,
+                       DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//
@@ -6161,6 +6188,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = WidenVecOp_VP_REDUCE(N);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = WidenVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -6924,6 +6955,17 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
                      DAG.getVectorIdxConstant(0, DL));
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Source = GetWidenedVector(N->getOperand(0));
+  EVT SrcVT = Source.getValueType();
+  SDValue Mask =
+      GetWidenedMask(N->getOperand(1), SrcVT.getVectorElementCount());
+
+  return DAG.getNode(N->getOpcode(), DL, N->getValueType(0),
+                     {Source, Mask, N->getOperand(2)}, N->getFlags());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Widening Utilities
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5caf868c83a296..cfd82a342433fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8076,6 +8076,11 @@ static unsigned getISDForVPIntrinsic(const VPIntrinsic &VPIntrin) {
     ResOPC = IsZeroUndef ? ISD::VP_CTTZ_ZERO_UNDEF : ISD::VP_CTTZ;
     break;
   }
+  case Intrinsic::vp_cttz_elts: {
+    bool IsZeroPoison = cast<ConstantInt>(VPIntrin.getArgOperand(1))->isOne();
+    ResOPC = IsZeroPoison ? ISD::VP_CTTZ_ELTS_ZERO_UNDEF : ISD::VP_CTTZ_ELTS;
+    break;
+  }
 #define HELPER_MAP_VPID_TO_VPSD(VPID, VPSD)                                    \
   case Intrinsic::VPID:                                                        \
     ResOPC = ISD::VPSD;                                                        \
@@ -8428,7 +8433,9 @@ void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
   case ISD::VP_CTLZ:
   case ISD::VP_CTLZ_ZERO_UNDEF:
   case ISD::VP_CTTZ:
-  case ISD::VP_CTTZ_ZERO_UNDEF: {
+  case ISD::VP_CTTZ_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS: {
     SDValue Result =
         DAG.getNode(Opcode, DL, VTs, {OpValues[0], OpValues[2], OpValues[3]});
     setValue(&VPIntrin, Result);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cdc1227fd572dc..336d89fbcf638e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9074,6 +9074,39 @@ SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
 }
 
+SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
+                                             SelectionDAG &DAG) const {
+  // %cond = to_bool_vec %source
+  // %splat = splat /*val=*/VL
+  // %tz = step_vector
+  // %v = vp.select %cond, /*true=*/tz, /*false=*/%splat
+  // %r = vp.reduce.umin %v
+  SDLoc DL(N);
+  SDValue Source = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue EVL = N->getOperand(2);
+  EVT SrcVT = Source.getValueType();
+  EVT ResVT = N->getValueType(0);
+  EVT ResVecVT =
+      EVT::getVectorVT(*DAG.getContext(), ResVT, SrcVT.getVectorElementCount());
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                             SrcVT.getVectorElementCount());
+    Source = DAG.getNode(ISD::VP_SETCC, DL, SrcVT, Source, AllZero,
+                         DAG.getCondCode(ISD::SETNE), Mask, EVL);
+  }
+
+  SDValue ExtEVL = DAG.getZExtOrTrunc(EVL, DL, ResVT);
+  SDValue Splat = DAG.getSplat(ResVecVT, DL, ExtEVL);
+  SDValue StepVec = DAG.getStepVector(DL, ResVecVT);
+  SDValue Select =
+      DAG.getNode(ISD::VP_SELECT, DL, ResVecVT, Source, StepVec, Splat, EVL);
+  return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
+}
+
 SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                   bool IsNegative) const {
   SDLoc dl(N);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68f4ec5ef49f31..c61e477d79e110 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
@@ -698,7 +699,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
         ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
         ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT};
+        ISD::VP_USUBSAT,     ISD::VP_CTTZ_ELTS,   ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,
@@ -759,6 +760,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
           Expand);
 
+      setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
+                         Custom);
+
       setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
 
       setOperationAction(
@@ -5341,6 +5345,44 @@ RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
   return Res;
 }
 
+SDValue RISCVTargetLowering::lowerVPCttzElements(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue Source = Op->getOperand(0);
+  MVT SrcVT = Source.getSimpleValueType();
+  SDValue Mask = Op->getOperand(1);
+  SDValue EVL = Op->getOperand(2);
+
+  if (SrcVT.isFixedLengthVector()) {
+    MVT ContainerVT = getContainerForFixedLengthVector(SrcVT);
+    Source = convertToScalableVector(ContainerVT, Source, DAG, Subtarget);
+    Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
+                                   Subtarget);
+    SrcVT = ContainerVT;
+  }
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorElementCount());
+    Source = DAG.getNode(RISCVISD::SETCC_VL, DL, SrcVT,
+                         {Source, AllZero, DAG.getCondCode(ISD::SETNE),
+                          DAG.getUNDEF(SrcVT), Mask, EVL});
+  }
+
+  SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Source, Mask, EVL);
+  if (Op->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
+    // In this case, we can interpret poison as -1, so nothing to do further.
+    return Res;
+
+  // Convert -1 to VL.
+  SDValue SetCC =
+      DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+  Res = DAG.getSelect(DL, XLenVT, SetCC, EVL, Res);
+  return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
+}
+
 // While RVV has alignment restrictions, we should always be able to load as a
 // legal equivalently-sized byte-typed vector instead. This method is
 // responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
@@ -6595,6 +6637,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
     return lowerVPREDUCE(Op, DAG);
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    return lowerVPCttzElements(Op, DAG);
   case ISD::UNDEF: {
     MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
     return convertFromScalableVector(Op.getSimpleValueType(),
@@ -13634,9 +13679,69 @@ st...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90538


More information about the llvm-commits mailing list