[llvm] [X86] Fix duplicated compute in recursive search. (PR #130226)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 9 19:57:41 PDT 2025


https://github.com/haonanya1 updated https://github.com/llvm/llvm-project/pull/130226

>From 4574283c27179ae0cb5b1ebbf7e7053551a7a79b Mon Sep 17 00:00:00 2001
From: "Yang, Haonan" <haonan.yang at intel.com>
Date: Fri, 7 Mar 2025 04:23:08 +0100
Subject: [PATCH 1/2] [X86] Fix duplicated compute in recursive search.

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 74 ++++++++++++++++++-------
 1 file changed, 53 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index deab638b7e546..1e58222f9bea1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -63,6 +63,7 @@
 #include <bitset>
 #include <cctype>
 #include <numeric>
+#include <tuple>
 using namespace llvm;
 
 #define DEBUG_TYPE "x86-isel"
@@ -44745,31 +44746,59 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op,
 
 // Helper to peek through bitops/trunc/setcc to determine size of source vector.
 // Allows combineBitcastvxi1 to determine what size vector generated a <X x i1>.
-static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size,
-                                      bool AllowTruncate) {
+static bool
+checkBitcastSrcVectorSize(SDValue Src, unsigned Size, bool AllowTruncate,
+                          std::map<std::tuple<SDValue, unsigned, bool>, bool>
+                              &BitcastSrcVectorSizeMap) {
+  auto Tp = std::make_tuple(Src, Size, AllowTruncate);
+  if (BitcastSrcVectorSizeMap.count(Tp))
+    return BitcastSrcVectorSizeMap[Tp];
   switch (Src.getOpcode()) {
   case ISD::TRUNCATE:
-    if (!AllowTruncate)
+    if (!AllowTruncate) {
+      BitcastSrcVectorSizeMap[Tp] = false;
       return false;
+    }
     [[fallthrough]];
-  case ISD::SETCC:
-    return Src.getOperand(0).getValueSizeInBits() == Size;
-  case ISD::FREEZE:
-    return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate);
+  case ISD::SETCC: {
+    auto Ret = Src.getOperand(0).getValueSizeInBits() == Size;
+    BitcastSrcVectorSizeMap[Tp] = Ret;
+    return Ret;
+  }
+  case ISD::FREEZE: {
+    auto Ret = checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate,
+                                         BitcastSrcVectorSizeMap);
+    BitcastSrcVectorSizeMap[Tp] = Ret;
+    return Ret;
+  }
   case ISD::AND:
   case ISD::XOR:
-  case ISD::OR:
-    return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate) &&
-           checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate);
+  case ISD::OR: {
+    auto Ret1 = checkBitcastSrcVectorSize(
+        Src.getOperand(0), Size, AllowTruncate, BitcastSrcVectorSizeMap);
+    auto Ret2 = checkBitcastSrcVectorSize(
+        Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
+    BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2;
+    return Ret1 && Ret2;
+  }
   case ISD::SELECT:
-  case ISD::VSELECT:
-    return Src.getOperand(0).getScalarValueSizeInBits() == 1 &&
-           checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate) &&
-           checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate);
-  case ISD::BUILD_VECTOR:
-    return ISD::isBuildVectorAllZeros(Src.getNode()) ||
-           ISD::isBuildVectorAllOnes(Src.getNode());
+  case ISD::VSELECT: {
+    auto Ret1 = checkBitcastSrcVectorSize(
+        Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
+    auto Ret2 = checkBitcastSrcVectorSize(
+        Src.getOperand(2), Size, AllowTruncate, BitcastSrcVectorSizeMap);
+    auto Ret3 = Src.getOperand(0).getScalarValueSizeInBits() == 1;
+    BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2 && Ret3;
+    return Ret1 && Ret2 && Ret3;
+  }
+  case ISD::BUILD_VECTOR: {
+    auto Ret = ISD::isBuildVectorAllZeros(Src.getNode()) ||
+               ISD::isBuildVectorAllOnes(Src.getNode());
+    BitcastSrcVectorSizeMap[Tp] = Ret;
+    return Ret;
   }
+  }
+  BitcastSrcVectorSizeMap[Tp] = false;
   return false;
 }
 
@@ -44925,6 +44954,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
   // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef)
   MVT SExtVT;
   bool PropagateSExt = false;
+  std::map<std::tuple<SDValue, unsigned, bool>, bool> BitcastSrcVectorSizeMap;
   switch (SrcVT.getSimpleVT().SimpleTy) {
   default:
     return SDValue();
@@ -44936,7 +44966,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
     // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
     // sign-extend to a 256-bit operation to avoid truncation.
     if (Subtarget.hasAVX() &&
-        checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2())) {
+        checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(),
+                                  BitcastSrcVectorSizeMap)) {
       SExtVT = MVT::v4i64;
       PropagateSExt = true;
     }
@@ -44948,8 +44979,9 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
     // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over
     // 256-bit because the shuffle is cheaper than sign extending the result of
     // the compare.
-    if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true) ||
-                               checkBitcastSrcVectorSize(Src, 512, true))) {
+    if (Subtarget.hasAVX() &&
+        (checkBitcastSrcVectorSize(Src, 256, true, BitcastSrcVectorSizeMap) ||
+         checkBitcastSrcVectorSize(Src, 512, true, BitcastSrcVectorSizeMap))) {
       SExtVT = MVT::v8i32;
       PropagateSExt = true;
     }
@@ -44974,7 +45006,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
       break;
     }
     // Split if this is a <64 x i8> comparison result.
-    if (checkBitcastSrcVectorSize(Src, 512, false)) {
+    if (checkBitcastSrcVectorSize(Src, 512, false, BitcastSrcVectorSizeMap)) {
       SExtVT = MVT::v64i8;
       break;
     }

>From b5aeeeebd53ff266f107e136f2ad2a837338440a Mon Sep 17 00:00:00 2001
From: "Yang, Haonan" <haonan.yang at intel.com>
Date: Mon, 10 Mar 2025 03:27:07 +0100
Subject: [PATCH 2/2] Add a recursive depth limit.

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 82 +++++++++----------------
 1 file changed, 29 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1e58222f9bea1..26b8cb029c9a8 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -63,7 +63,6 @@
 #include <bitset>
 #include <cctype>
 #include <numeric>
-#include <tuple>
 using namespace llvm;
 
 #define DEBUG_TYPE "x86-isel"
@@ -44746,59 +44745,39 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op,
 
 // Helper to peek through bitops/trunc/setcc to determine size of source vector.
 // Allows combineBitcastvxi1 to determine what size vector generated a <X x i1>.
-static bool
-checkBitcastSrcVectorSize(SDValue Src, unsigned Size, bool AllowTruncate,
-                          std::map<std::tuple<SDValue, unsigned, bool>, bool>
-                              &BitcastSrcVectorSizeMap) {
-  auto Tp = std::make_tuple(Src, Size, AllowTruncate);
-  if (BitcastSrcVectorSizeMap.count(Tp))
-    return BitcastSrcVectorSizeMap[Tp];
+static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size,
+                                      bool AllowTruncate, unsigned Depth) {
+  // Limit recursion.
+  if (Depth >= SelectionDAG::MaxRecursionDepth)
+    return false;
   switch (Src.getOpcode()) {
   case ISD::TRUNCATE:
-    if (!AllowTruncate) {
-      BitcastSrcVectorSizeMap[Tp] = false;
+    if (!AllowTruncate)
       return false;
-    }
     [[fallthrough]];
-  case ISD::SETCC: {
-    auto Ret = Src.getOperand(0).getValueSizeInBits() == Size;
-    BitcastSrcVectorSizeMap[Tp] = Ret;
-    return Ret;
-  }
-  case ISD::FREEZE: {
-    auto Ret = checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate,
-                                         BitcastSrcVectorSizeMap);
-    BitcastSrcVectorSizeMap[Tp] = Ret;
-    return Ret;
-  }
+  case ISD::SETCC:
+    return Src.getOperand(0).getValueSizeInBits() == Size;
+  case ISD::FREEZE:
+    return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate,
+                                     Depth + 1);
   case ISD::AND:
   case ISD::XOR:
-  case ISD::OR: {
-    auto Ret1 = checkBitcastSrcVectorSize(
-        Src.getOperand(0), Size, AllowTruncate, BitcastSrcVectorSizeMap);
-    auto Ret2 = checkBitcastSrcVectorSize(
-        Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
-    BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2;
-    return Ret1 && Ret2;
-  }
+  case ISD::OR:
+    return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate,
+                                     Depth + 1) &&
+           checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate,
+                                     Depth + 1);
   case ISD::SELECT:
-  case ISD::VSELECT: {
-    auto Ret1 = checkBitcastSrcVectorSize(
-        Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
-    auto Ret2 = checkBitcastSrcVectorSize(
-        Src.getOperand(2), Size, AllowTruncate, BitcastSrcVectorSizeMap);
-    auto Ret3 = Src.getOperand(0).getScalarValueSizeInBits() == 1;
-    BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2 && Ret3;
-    return Ret1 && Ret2 && Ret3;
-  }
-  case ISD::BUILD_VECTOR: {
-    auto Ret = ISD::isBuildVectorAllZeros(Src.getNode()) ||
-               ISD::isBuildVectorAllOnes(Src.getNode());
-    BitcastSrcVectorSizeMap[Tp] = Ret;
-    return Ret;
-  }
+  case ISD::VSELECT:
+    return Src.getOperand(0).getScalarValueSizeInBits() == 1 &&
+           checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate,
+                                     Depth + 1) &&
+           checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate,
+                                     Depth + 1);
+  case ISD::BUILD_VECTOR:
+    return ISD::isBuildVectorAllZeros(Src.getNode()) ||
+           ISD::isBuildVectorAllOnes(Src.getNode());
   }
-  BitcastSrcVectorSizeMap[Tp] = false;
   return false;
 }
 
@@ -44954,7 +44933,6 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
   // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef)
   MVT SExtVT;
   bool PropagateSExt = false;
-  std::map<std::tuple<SDValue, unsigned, bool>, bool> BitcastSrcVectorSizeMap;
   switch (SrcVT.getSimpleVT().SimpleTy) {
   default:
     return SDValue();
@@ -44966,8 +44944,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
     // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
     // sign-extend to a 256-bit operation to avoid truncation.
     if (Subtarget.hasAVX() &&
-        checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(),
-                                  BitcastSrcVectorSizeMap)) {
+        checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(), 0)) {
       SExtVT = MVT::v4i64;
       PropagateSExt = true;
     }
@@ -44979,9 +44956,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
     // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over
     // 256-bit because the shuffle is cheaper than sign extending the result of
     // the compare.
-    if (Subtarget.hasAVX() &&
-        (checkBitcastSrcVectorSize(Src, 256, true, BitcastSrcVectorSizeMap) ||
-         checkBitcastSrcVectorSize(Src, 512, true, BitcastSrcVectorSizeMap))) {
+    if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true, 0) ||
+                               checkBitcastSrcVectorSize(Src, 512, true, 0))) {
       SExtVT = MVT::v8i32;
       PropagateSExt = true;
     }
@@ -45006,7 +44982,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
       break;
     }
     // Split if this is a <64 x i8> comparison result.
-    if (checkBitcastSrcVectorSize(Src, 512, false, BitcastSrcVectorSizeMap)) {
+    if (checkBitcastSrcVectorSize(Src, 512, false, 0)) {
       SExtVT = MVT::v64i8;
       break;
     }



More information about the llvm-commits mailing list