[llvm] add narrowExtractedVectorUnaryOp to simplify cast nodes (PR #87977)

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 04:01:39 PDT 2024


https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/87977

>From 02d887d71f45721218fda03dd6428e6caa24aee9 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Mon, 8 Apr 2024 09:50:19 +0000
Subject: [PATCH 1/2] add narrowExtractedVectorUnaryOp to simplify cast nodes

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2f46b23a97c62c..50bb270a66a588 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -24083,6 +24083,55 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
                      BinOp->getFlags());
 }
 
+/// If we are extracting a subvector from a wide unary operator (currently only
+/// ISD::FP_TO_SINT is handled), try to emit a narrower unary operator instead.
+static SDValue narrowExtractedVectorUnaryOp(SDNode *Extract, SelectionDAG &DAG,
+                                          bool LegalOperations) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue UnaryOp = Extract->getOperand(0);
+  unsigned UnaryOpcode = UnaryOp.getOpcode();
+  
+  if (UnaryOpcode != ISD::FP_TO_SINT || UnaryOp->getNumValues() != 1)
+    return SDValue();
+
+  // The extract index must be a constant so that its value can be read below.
+  auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
+  if (!ExtractIndexC)
+    return SDValue();
+
+  EVT WideUVT = UnaryOp.getValueType();
+  if (!WideUVT.isFixedLengthVector())
+    return SDValue();
+  
+  EVT VT = Extract->getValueType(0);
+  unsigned ExtractIndex = ExtractIndexC->getZExtValue();
+  assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
+         "Extract index is not a multiple of the vector length.");
+
+  // Bail out unless the wide width is an exact multiple of the narrow width.
+  unsigned WideWidth = WideUVT.getSizeInBits();
+  unsigned NarrowWidth = VT.getSizeInBits();
+  if (WideWidth % NarrowWidth != 0)
+    return SDValue();
+
+  unsigned NarrowingRatio = WideWidth / NarrowWidth;
+  unsigned WideNumElts = WideUVT.getVectorNumElements();
+
+  // Bail out if the target does not support the unary op at the narrow type.
+  EVT NarrowUVT = EVT::getVectorVT(*DAG.getContext(), WideUVT.getScalarType(),
+                                   WideNumElts / NarrowingRatio);
+  if (!TLI.isOperationLegalOrCustomOrPromote(UnaryOpcode, NarrowUVT,
+                                             LegalOperations))
+    return SDValue();
+  
+  SDLoc DL(Extract);
+  auto ret = DAG.getNode(UnaryOpcode, DL, NarrowUVT, UnaryOp.getOperand(0));
+  dbgs() << "reduced node found: \n";
+  ret.dump(); // NOTE(review): unguarded debug output — confirm intentional.
+
+  return ret; // NOTE(review): operand is not narrowed/indexed here — verify.
+}
+
 /// If we are extracting a subvector produced by a wide binary operator try
 /// to use a narrow binary operator and/or avoid concatenation and extraction.
 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
@@ -24613,6 +24662,9 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
         N->getOperand(1));
   }
 
+  if (SDValue NarrowUOp = narrowExtractedVectorUnaryOp(N, DAG, LegalOperations))
+    return NarrowUOp;
+
   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
     return NarrowBOp;
 

>From e263e5f0c99aaf8d9451228518e376b24b5d7f57 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Mon, 8 Apr 2024 11:01:19 +0000
Subject: [PATCH 2/2] Added fold logic

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 25 ++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 50bb270a66a588..23dd91e07fb9f4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -24123,7 +24123,30 @@ static SDValue narrowExtractedVectorUnaryOp(SDNode *Extract, SelectionDAG &DAG,
   if (!TLI.isOperationLegalOrCustomOrPromote(UnaryOpcode, NarrowUVT,
                                              LegalOperations))
     return SDValue();
-  
+
+  EVT NarrowUEltVT = EVT::getVectorVT(*DAG.getContext(), UnaryOp.getOperand(0).getValueType().getScalarType(),
+                                   WideNumElts / NarrowingRatio);
+  // if (!TLI.isOperationLegalOrCustomOrPromote(ISD::EXTRACT_SUBVECTOR, NarrowUEltVT,
+  //                                            LegalOperations))
+  //   return SDValue();
+
+  // If extraction is cheap, we don't need to look at the unary op operand
+  // for concat ops. The narrow unary op alone makes this transform profitable.
+  if (TLI.isExtractSubvectorCheap(NarrowUVT, WideUVT, ExtractIndex) &&
+      UnaryOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
+    // extract (unaryop X), N --> unaryop (extract X, N)
+    SDLoc DL(Extract);
+    NarrowUVT.dump();
+    NarrowUEltVT.dump();
+    SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowUEltVT,
+                            UnaryOp.getOperand(0), Extract->getOperand(1));
+    SDValue NarrowUnaryOp = DAG.getNode(UnaryOpcode, DL, NarrowUVT, X);
+    X.dump();
+    NarrowUnaryOp.dump();
+    dbgs() << "reduced node found (smol)\n";
+    return DAG.getBitcast(VT, NarrowUnaryOp);
+  }
+
   SDLoc DL(Extract);
   auto ret = DAG.getNode(UnaryOpcode, DL, NarrowUVT, UnaryOp.getOperand(0));
   dbgs() << "reduced node found: \n";



More information about the llvm-commits mailing list