[llvm] [RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext. (PR #93578)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 16:12:33 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/93578

>From e10966f1484d655c9ca532e3942cde38acc0328a Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 09:22:05 -0700
Subject: [PATCH 1/3] [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper
 function. NFC

I plan to add more combines for TRUNCATE_VECTOR_VL.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 103 ++++++++++----------
 1 file changed, 53 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+  // This would be benefit for the cases where X and Y are both the same value
+  // type of low precision vectors. Since the truncate would be lowered into
+  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+  // restriction, such pattern would be expanded into a series of "vsetvli"
+  // and "vnsrl" instructions later to reach this point.
+  auto IsTruncNode = [](SDValue V) {
+    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL = V.getOperand(2);
+    auto *C = dyn_cast<ConstantSDNode>(VL);
+    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                           (isa<RegisterSDNode>(VL) &&
+                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  };
+
+  SDValue Op = N->getOperand(0);
+
+  // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+  // to distinguish such pattern.
+  while (IsTruncNode(Op)) {
+    if (!Op.hasOneUse())
+      return SDValue();
+    Op = Op.getOperand(0);
+  }
+
+  if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+    return SDValue();
+
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
+  if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+      N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N10 = N1.getOperand(0);
+  if (!N00.getValueType().isVector() ||
+      N00.getValueType() != N10.getValueType() ||
+      N->getValueType(0) != N10.getValueType())
+    return SDValue();
+
+  unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+  SDValue SMin =
+      DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+                  DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+  return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
-  case RISCVISD::TRUNCATE_VECTOR_VL: {
-    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-    // This would be benefit for the cases where X and Y are both the same value
-    // type of low precision vectors. Since the truncate would be lowered into
-    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-    // restriction, such pattern would be expanded into a series of "vsetvli"
-    // and "vnsrl" instructions later to reach this point.
-    auto IsTruncNode = [](SDValue V) {
-      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-        return false;
-      SDValue VL = V.getOperand(2);
-      auto *C = dyn_cast<ConstantSDNode>(VL);
-      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                             (isa<RegisterSDNode>(VL) &&
-                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-             IsVLMAXForVMSET;
-    };
-
-    SDValue Op = N->getOperand(0);
-
-    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
-    // to distinguish such pattern.
-    while (IsTruncNode(Op)) {
-      if (!Op.hasOneUse())
-        return SDValue();
-      Op = Op.getOperand(0);
-    }
-
-    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
-      SDValue N0 = Op.getOperand(0);
-      SDValue N1 = Op.getOperand(1);
-      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
-          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
-        SDValue N00 = N0.getOperand(0);
-        SDValue N10 = N1.getOperand(0);
-        if (N00.getValueType().isVector() &&
-            N00.getValueType() == N10.getValueType() &&
-            N->getValueType(0) == N10.getValueType()) {
-          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
-          SDValue SMin = DAG.getNode(
-              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
-              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
-          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
-        }
-      }
-    }
-    break;
-  }
+  case RISCVISD::TRUNCATE_VECTOR_VL:
+    return combineTruncOfSraSext(N, DAG);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:

>From 4e227983c1e3c290724f09e4968610e7b0c21689 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 09:55:48 -0700
Subject: [PATCH 2/3] [RISCV] Verify the VL and Mask on the outer
 TRUNCATE_VECTOR_VL in combineTruncOfSraSext.

Previously, we checked the VL and mask of each additional TRUNCATE_VECTOR_VL
node we peeked through, but not those of the outermost node.

This patch moves the check to the outermost node and then verifies that
all the additional nodes have the same VL and mask.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 ++++++++++++---------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 47b1cc1ba6460..288e874276e07 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16088,22 +16088,25 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
 }
 
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  bool IsVLMAX = isAllOnesConstant(VL) ||
+                 (isa<RegisterSDNode>(VL) &&
+                  cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+  if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
+      Mask.getOperand(0) != VL)
+    return SDValue();
+
   // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
   // This would be benefit for the cases where X and Y are both the same value
   // type of low precision vectors. Since the truncate would be lowered into
   // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
   // restriction, such pattern would be expanded into a series of "vsetvli"
   // and "vnsrl" instructions later to reach this point.
-  auto IsTruncNode = [](SDValue V) {
-    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-      return false;
-    SDValue VL = V.getOperand(2);
-    auto *C = dyn_cast<ConstantSDNode>(VL);
-    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                           (isa<RegisterSDNode>(VL) &&
-                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  auto IsTruncNode = [&](SDValue V) {
+    return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
+           V.getOperand(1) == Mask && V.getOperand(2) == VL;
   };
 
   SDValue Op = N->getOperand(0);

>From 78777ec5442cd9a71c639ce685512e699241200b Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 16:12:22 -0700
Subject: [PATCH 3/3] fixup! move comment.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6a9a61e480294..f4da46f82a810 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16128,6 +16128,12 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+// This would be benefit for the cases where X and Y are both the same value
+// type of low precision vectors. Since the truncate would be lowered into
+// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+// restriction, such pattern would be expanded into a series of "vsetvli"
+// and "vnsrl" instructions later to reach this point.
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
   SDValue Mask = N->getOperand(1);
   SDValue VL = N->getOperand(2);
@@ -16139,12 +16145,6 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
       Mask.getOperand(0) != VL)
     return SDValue();
 
-  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-  // This would be benefit for the cases where X and Y are both the same value
-  // type of low precision vectors. Since the truncate would be lowered into
-  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-  // restriction, such pattern would be expanded into a series of "vsetvli"
-  // and "vnsrl" instructions later to reach this point.
   auto IsTruncNode = [&](SDValue V) {
     return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
            V.getOperand(1) == Mask && V.getOperand(2) == VL;



More information about the llvm-commits mailing list