[llvm] 2a1ef87 - [CostModel][X86] getCastInstrCost - attempt to match custom cast/conversion before legalized types.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 1 04:22:30 PDT 2021


Author: Simon Pilgrim
Date: 2021-07-01T12:06:40+01:00
New Revision: 2a1ef8784ad9a78583a2f1f3bba536ee57b6b13b

URL: https://github.com/llvm/llvm-project/commit/2a1ef8784ad9a78583a2f1f3bba536ee57b6b13b
DIFF: https://github.com/llvm/llvm-project/commit/2a1ef8784ad9a78583a2f1f3bba536ee57b6b13b.diff

LOG: [CostModel][X86] getCastInstrCost - attempt to match custom cast/conversion before legalized types.

Move the (SSE-only) generic, legalized type conversion matching after the specific, custom conversion cases, allowing us to properly provide cost overrides.

The next step will be to clean up some of the weird existing costs and then to enable AVX+ legalized costs, which will let us strip out a lot of the cost table entries.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/test/Analysis/CostModel/X86/cast.ll
    llvm/test/Analysis/CostModel/X86/sitofp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index d245324cc9ce6..e400000f83160 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1519,9 +1519,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     return Cost;
   };
 
+  // The cost tables include both specific, custom (non-legal) src/dst type
+  // conversions and generic, legalized types. We test for customs first, before
+  // falling back to legalization.
   // FIXME: Need a better design of the cost table to handle non-simple types of
   // potential massive combinations (elem_num x src_type x dst_type).
-
   static const TypeConversionCostTblEntry AVX512BWConversionTbl[] {
     { ISD::SIGN_EXTEND, MVT::v32i16, MVT::v32i8, 1 },
     { ISD::ZERO_EXTEND, MVT::v32i16, MVT::v32i8, 1 },
@@ -2173,85 +2175,87 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     { ISD::TRUNCATE,    MVT::v2i32,  MVT::v2i64,  1 }, // PSHUFD
   };
 
-  std::pair<InstructionCost, MVT> LTSrc = TLI->getTypeLegalizationCost(DL, Src);
-  std::pair<InstructionCost, MVT> LTDest =
-      TLI->getTypeLegalizationCost(DL, Dst);
-
-  if (ST->hasSSE41() && !ST->hasAVX())
-    if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
-                                                   LTDest.second, LTSrc.second))
-      return AdjustCost(LTSrc.first * Entry->Cost);
-
-  if (ST->hasSSE2() && !ST->hasAVX())
-    if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
-                                                   LTDest.second, LTSrc.second))
-      return AdjustCost(LTSrc.first * Entry->Cost);
-
+  // Attempt to map directly to (simple) MVT types to let us match custom entries.
   EVT SrcTy = TLI->getValueType(DL, Src);
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   // The function getSimpleVT only handles simple value types.
-  if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind));
-
-  MVT SimpleSrcTy = SrcTy.getSimpleVT();
-  MVT SimpleDstTy = DstTy.getSimpleVT();
+  if (SrcTy.isSimple() && DstTy.isSimple()) {
+    MVT SimpleSrcTy = SrcTy.getSimpleVT();
+    MVT SimpleDstTy = DstTy.getSimpleVT();
+
+    if (ST->useAVX512Regs()) {
+      if (ST->hasBWI())
+        if (const auto *Entry = ConvertCostTableLookup(
+                AVX512BWConversionTbl, ISD, SimpleDstTy, SimpleSrcTy))
+          return AdjustCost(Entry->Cost);
+
+      if (ST->hasDQI())
+        if (const auto *Entry = ConvertCostTableLookup(
+                AVX512DQConversionTbl, ISD, SimpleDstTy, SimpleSrcTy))
+          return AdjustCost(Entry->Cost);
+
+      if (ST->hasAVX512())
+        if (const auto *Entry = ConvertCostTableLookup(
+                AVX512FConversionTbl, ISD, SimpleDstTy, SimpleSrcTy))
+          return AdjustCost(Entry->Cost);
+    }
 
-  if (ST->useAVX512Regs()) {
     if (ST->hasBWI())
-      if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD,
-                                                     SimpleDstTy, SimpleSrcTy))
+      if (const auto *Entry = ConvertCostTableLookup(
+              AVX512BWVLConversionTbl, ISD, SimpleDstTy, SimpleSrcTy))
         return AdjustCost(Entry->Cost);
 
     if (ST->hasDQI())
-      if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD,
-                                                     SimpleDstTy, SimpleSrcTy))
+      if (const auto *Entry = ConvertCostTableLookup(
+              AVX512DQVLConversionTbl, ISD, SimpleDstTy, SimpleSrcTy))
         return AdjustCost(Entry->Cost);
 
     if (ST->hasAVX512())
-      if (const auto *Entry = ConvertCostTableLookup(AVX512FConversionTbl, ISD,
+      if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
                                                      SimpleDstTy, SimpleSrcTy))
         return AdjustCost(Entry->Cost);
-  }
 
-  if (ST->hasBWI())
-    if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
+    if (ST->hasAVX2()) {
+      if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
+                                                     SimpleDstTy, SimpleSrcTy))
+        return AdjustCost(Entry->Cost);
+    }
 
-  if (ST->hasDQI())
-    if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
+    if (ST->hasAVX()) {
+      if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD,
+                                                     SimpleDstTy, SimpleSrcTy))
+        return AdjustCost(Entry->Cost);
+    }
 
-  if (ST->hasAVX512())
-    if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
-
-  if (ST->hasAVX2()) {
-    if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
-  }
+    if (ST->hasSSE41()) {
+      if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
+                                                     SimpleDstTy, SimpleSrcTy))
+        return AdjustCost(Entry->Cost);
+    }
 
-  if (ST->hasAVX()) {
-    if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
+    if (ST->hasSSE2()) {
+      if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
+                                                     SimpleDstTy, SimpleSrcTy))
+        return AdjustCost(Entry->Cost);
+    }
   }
 
-  if (ST->hasSSE41()) {
+  // Fall back to legalized types.
+  // TODO: Add AVX support.
+  std::pair<InstructionCost, MVT> LTSrc = TLI->getTypeLegalizationCost(DL, Src);
+  std::pair<InstructionCost, MVT> LTDest =
+      TLI->getTypeLegalizationCost(DL, Dst);
+
+  if (ST->hasSSE41() && !ST->hasAVX())
     if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
-  }
+                                                   LTDest.second, LTSrc.second))
+      return AdjustCost(LTSrc.first * Entry->Cost);
 
-  if (ST->hasSSE2()) {
+  if (ST->hasSSE2() && !ST->hasAVX())
     if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
-                                                   SimpleDstTy, SimpleSrcTy))
-      return AdjustCost(Entry->Cost);
-  }
+                                                   LTDest.second, LTSrc.second))
+      return AdjustCost(LTSrc.first * Entry->Cost);
 
   return AdjustCost(
       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));

diff  --git a/llvm/test/Analysis/CostModel/X86/cast.ll b/llvm/test/Analysis/CostModel/X86/cast.ll
index 60fd218a19f6c..5377c68761a3b 100644
--- a/llvm/test/Analysis/CostModel/X86/cast.ll
+++ b/llvm/test/Analysis/CostModel/X86/cast.ll
@@ -391,7 +391,7 @@ define void @sitofp4(<4 x i1> %a, <4 x i8> %b, <4 x i16> %c, <4 x i32> %d) {
 ; SSE41-NEXT:  Cost Model: Found an estimated cost of 15 for instruction: %C1 = sitofp <4 x i16> %c to <4 x float>
 ; SSE41-NEXT:  Cost Model: Found an estimated cost of 80 for instruction: %C2 = sitofp <4 x i16> %c to <4 x double>
 ; SSE41-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %D1 = sitofp <4 x i32> %d to <4 x float>
-; SSE41-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %D2 = sitofp <4 x i32> %d to <4 x double>
+; SSE41-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %D2 = sitofp <4 x i32> %d to <4 x double>
 ; SSE41-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
 ;
 ; AVX1-LABEL: 'sitofp4'

diff  --git a/llvm/test/Analysis/CostModel/X86/sitofp.ll b/llvm/test/Analysis/CostModel/X86/sitofp.ll
index 718467c1e7832..de4b2c2276896 100644
--- a/llvm/test/Analysis/CostModel/X86/sitofp.ll
+++ b/llvm/test/Analysis/CostModel/X86/sitofp.ll
@@ -79,7 +79,7 @@ define i32 @sitofp_i32_double() {
 ; SSE42-LABEL: 'sitofp_i32_double'
 ; SSE42-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %cvt_i32_f64 = sitofp i32 undef to double
 ; SSE42-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %cvt_v2i32_v2f64 = sitofp <2 x i32> undef to <2 x double>
-; SSE42-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %cvt_v4i32_v4f64 = sitofp <4 x i32> undef to <4 x double>
+; SSE42-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %cvt_v4i32_v4f64 = sitofp <4 x i32> undef to <4 x double>
 ; SSE42-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %cvt_v8i32_v8f64 = sitofp <8 x i32> undef to <8 x double>
 ; SSE42-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret i32 undef
 ;


        


More information about the llvm-commits mailing list