[llvm] fe3f8ad - [X86] getIntrinsicInstrCost - begin generalizing BSWAP load/store-folding handling.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 17 10:01:42 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-17T18:01:12+01:00
New Revision: fe3f8ad8cc209c1f73a3c9465b46701120de51f7

URL: https://github.com/llvm/llvm-project/commit/fe3f8ad8cc209c1f73a3c9465b46701120de51f7
DIFF: https://github.com/llvm/llvm-project/commit/fe3f8ad8cc209c1f73a3c9465b46701120de51f7.diff

LOG: [X86] getIntrinsicInstrCost - begin generalizing BSWAP load/store-folding handling.

Move the load/store-folding 'free costs' inside the adjustTableCost helper so we can handle some additional intrinsics in the future.

The plan is to do something similar for the other cost callbacks as well (getArithmeticInstrCost etc.).

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 64cacd74153fe..de0144331dba3 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -4262,6 +4262,37 @@ X86TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   }
 
   if (ISD != ISD::DELETED_NODE) {
+    auto adjustTableCost = [&](int ISD, unsigned Cost,
+                               std::pair<InstructionCost, MVT> LT,
+                               FastMathFlags FMF) -> InstructionCost {
+      InstructionCost LegalizationCost = LT.first;
+      MVT MTy = LT.second;
+
+      // If there are no NANs to deal with, then these are reduced to a
+      // single MIN** or MAX** instruction instead of the MIN/CMP/SELECT that we
+      // assume is used in the non-fast case.
+      if (ISD == ISD::FMAXNUM || ISD == ISD::FMINNUM) {
+        if (FMF.noNaNs())
+          return LegalizationCost * 1;
+      }
+
+      // For cases where some ops can be folded into a load/store, assume free.
+      if (MTy.isScalarInteger()) {
+        if (ISD == ISD::BSWAP && ST->hasMOVBE() && ST->hasFastMOVBE()) {
+          if (const Instruction *II = ICA.getInst()) {
+            if (II->hasOneUse() && isa<StoreInst>(II->user_back()))
+              return TTI::TCC_Free;
+            if (auto *LI = dyn_cast<LoadInst>(II->getOperand(0))) {
+              if (LI->hasOneUse())
+                return TTI::TCC_Free;
+            }
+          }
+        }
+      }
+
+      return LegalizationCost * (int)Cost;
+    };
+
     // Legalize the type.
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(OpTy);
     MVT MTy = LT.second;
@@ -4280,180 +4311,132 @@ X86TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     if (ISD == ISD::FSQRT && CostKind == TTI::TCK_CodeSize)
       return LT.first;
 
-    auto adjustTableCost = [](int ISD, unsigned Cost,
-                              InstructionCost LegalizationCost,
-                              FastMathFlags FMF) {
-      // If there are no NANs to deal with, then these are reduced to a
-      // single MIN** or MAX** instruction instead of the MIN/CMP/SELECT that we
-      // assume is used in the non-fast case.
-      if (ISD == ISD::FMAXNUM || ISD == ISD::FMINNUM) {
-        if (FMF.noNaNs())
-          return LegalizationCost * 1;
-      }
-      return LegalizationCost * (int)Cost;
-    };
-
     if (ST->useGLMDivSqrtCosts())
       if (const auto *Entry = CostTableLookup(GLMCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->useSLMArithCosts())
       if (const auto *Entry = CostTableLookup(SLMCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasVBMI2())
       if (const auto *Entry = CostTableLookup(AVX512VBMI2CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasBITALG())
       if (const auto *Entry = CostTableLookup(AVX512BITALGCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasVPOPCNTDQ())
       if (const auto *Entry = CostTableLookup(AVX512VPOPCNTDQCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasGFNI())
       if (const auto *Entry = CostTableLookup(GFNICostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasCDI())
       if (const auto *Entry = CostTableLookup(AVX512CDCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasBWI())
       if (const auto *Entry = CostTableLookup(AVX512BWCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasAVX512())
       if (const auto *Entry = CostTableLookup(AVX512CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasXOP())
       if (const auto *Entry = CostTableLookup(XOPCostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasAVX2())
       if (const auto *Entry = CostTableLookup(AVX2CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasAVX())
       if (const auto *Entry = CostTableLookup(AVX1CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasSSE42())
       if (const auto *Entry = CostTableLookup(SSE42CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasSSE41())
       if (const auto *Entry = CostTableLookup(SSE41CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasSSSE3())
       if (const auto *Entry = CostTableLookup(SSSE3CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasSSE2())
       if (const auto *Entry = CostTableLookup(SSE2CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasSSE1())
       if (const auto *Entry = CostTableLookup(SSE1CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (ST->hasBMI()) {
       if (ST->is64Bit())
         if (const auto *Entry = CostTableLookup(BMI64CostTbl, ISD, MTy))
           if (auto KindCost = Entry->Cost[CostKind])
-            return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                   ICA.getFlags());
+            return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
       if (const auto *Entry = CostTableLookup(BMI32CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
     }
 
     if (ST->hasLZCNT()) {
       if (ST->is64Bit())
         if (const auto *Entry = CostTableLookup(LZCNT64CostTbl, ISD, MTy))
           if (auto KindCost = Entry->Cost[CostKind])
-            return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                   ICA.getFlags());
+            return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
       if (const auto *Entry = CostTableLookup(LZCNT32CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
     }
 
     if (ST->hasPOPCNT()) {
       if (ST->is64Bit())
         if (const auto *Entry = CostTableLookup(POPCNT64CostTbl, ISD, MTy))
           if (auto KindCost = Entry->Cost[CostKind])
-            return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                   ICA.getFlags());
+            return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
       if (const auto *Entry = CostTableLookup(POPCNT32CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
-    }
-
-    if (ISD == ISD::BSWAP && ST->hasMOVBE() && ST->hasFastMOVBE()) {
-      if (const Instruction *II = ICA.getInst()) {
-        if (II->hasOneUse() && isa<StoreInst>(II->user_back()))
-          return TTI::TCC_Free;
-        if (auto *LI = dyn_cast<LoadInst>(II->getOperand(0))) {
-          if (LI->hasOneUse())
-            return TTI::TCC_Free;
-        }
-      }
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
     }
 
     if (ST->is64Bit())
       if (const auto *Entry = CostTableLookup(X64CostTbl, ISD, MTy))
         if (auto KindCost = Entry->Cost[CostKind])
-          return adjustTableCost(Entry->ISD, *KindCost, LT.first,
-                                 ICA.getFlags());
+          return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
 
     if (const auto *Entry = CostTableLookup(X86CostTbl, ISD, MTy))
       if (auto KindCost = Entry->Cost[CostKind])
-        return adjustTableCost(Entry->ISD, *KindCost, LT.first, ICA.getFlags());
+        return adjustTableCost(Entry->ISD, *KindCost, LT, ICA.getFlags());
   }
 
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);


        


More information about the llvm-commits mailing list