[llvm] [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (PR #93574)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 09:29:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

I plan to add other combines on TRUNCATE_VECTOR_VL.

---
Full diff: https://github.com/llvm/llvm-project/pull/93574.diff


1 File Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+53-50) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+  // This would be benefit for the cases where X and Y are both the same value
+  // type of low precision vectors. Since the truncate would be lowered into
+  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+  // restriction, such pattern would be expanded into a series of "vsetvli"
+  // and "vnsrl" instructions later to reach this point.
+  auto IsTruncNode = [](SDValue V) {
+    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL = V.getOperand(2);
+    auto *C = dyn_cast<ConstantSDNode>(VL);
+    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                           (isa<RegisterSDNode>(VL) &&
+                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  };
+
+  SDValue Op = N->getOperand(0);
+
+  // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+  // to distinguish such pattern.
+  while (IsTruncNode(Op)) {
+    if (!Op.hasOneUse())
+      return SDValue();
+    Op = Op.getOperand(0);
+  }
+
+  if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+    return SDValue();
+
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
+  if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+      N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N10 = N1.getOperand(0);
+  if (!N00.getValueType().isVector() ||
+      N00.getValueType() != N10.getValueType() ||
+      N->getValueType(0) != N10.getValueType())
+    return SDValue();
+
+  unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+  SDValue SMin =
+      DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+                  DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+  return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
-  case RISCVISD::TRUNCATE_VECTOR_VL: {
-    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-    // This would be benefit for the cases where X and Y are both the same value
-    // type of low precision vectors. Since the truncate would be lowered into
-    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-    // restriction, such pattern would be expanded into a series of "vsetvli"
-    // and "vnsrl" instructions later to reach this point.
-    auto IsTruncNode = [](SDValue V) {
-      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-        return false;
-      SDValue VL = V.getOperand(2);
-      auto *C = dyn_cast<ConstantSDNode>(VL);
-      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                             (isa<RegisterSDNode>(VL) &&
-                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-             IsVLMAXForVMSET;
-    };
-
-    SDValue Op = N->getOperand(0);
-
-    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
-    // to distinguish such pattern.
-    while (IsTruncNode(Op)) {
-      if (!Op.hasOneUse())
-        return SDValue();
-      Op = Op.getOperand(0);
-    }
-
-    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
-      SDValue N0 = Op.getOperand(0);
-      SDValue N1 = Op.getOperand(1);
-      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
-          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
-        SDValue N00 = N0.getOperand(0);
-        SDValue N10 = N1.getOperand(0);
-        if (N00.getValueType().isVector() &&
-            N00.getValueType() == N10.getValueType() &&
-            N->getValueType(0) == N10.getValueType()) {
-          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
-          SDValue SMin = DAG.getNode(
-              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
-              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
-          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
-        }
-      }
-    }
-    break;
-  }
+  case RISCVISD::TRUNCATE_VECTOR_VL:
+    return combineTruncOfSraSext(N, DAG);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:

``````````

</details>


https://github.com/llvm/llvm-project/pull/93574


More information about the llvm-commits mailing list