[llvm] [BasicTTI] Use getTypeLegalizationCost to generalize vector cast cost. (PR #107303)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 17 07:26:02 PDT 2024


================
@@ -1171,9 +1171,50 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                           CostKind, I));
       }
 
+      // Check if the wider type of Src and Dst needs to be legalized. If it
+      // does, compute the cost of the cast based on vectors with the same
+      // number of elements as the legalized widest type. Don't directly return
+      // the cost, as scalarization may be more profitable, which has its cost
+      // computed below.
+      // TODO: Use the more general logic below to replace the specific logic
+      // above to also handle cost for legalization via splitting above.
+      InstructionCost LegalizedCost = InstructionCost::getInvalid();
+      if (SrcVTy->getElementCount().isVector() &&
+          DstVTy->getElementCount().isVector()) {
+        Type *WiderTy =
+            Src->getScalarSizeInBits() < Dst->getScalarSizeInBits() ? Dst : Src;
+        auto [WiderLegalCost, WiderLegalTy] = getTypeLegalizationCost(WiderTy);
+        if (WiderLegalTy.isVector() &&
+            TLI->getValueType(DL, WiderTy) != WiderLegalTy) {
+          Type *SplitDstTy = VectorType::get(
+              DstVTy->getElementType(), WiderLegalTy.getVectorElementCount());
+          Type *SplitSrcTy = VectorType::get(
+              SrcVTy->getElementType(), WiderLegalTy.getVectorElementCount());
+          if (SplitDstTy != Dst && SplitSrcTy != Src) {
+            bool SplitSrc = TLI->getTypeAction(Src->getContext(),
+                                               TLI->getValueType(DL, Src)) ==
+                            TargetLowering::TypeSplitVector;
+            bool SplitDst = TLI->getTypeAction(Dst->getContext(),
+                                               TLI->getValueType(DL, Dst)) ==
+                            TargetLowering::TypeSplitVector;
+            T *TTI = static_cast<T *>(this);
+            // If both types need to be split then the split is free.
+            InstructionCost SplitCost =
+                (SplitSrc || SplitSrc) && (!SplitSrc || !SplitDst)
+                    ? TTI->getVectorSplitCost()
+                    : 0;
+
+            LegalizedCost = WiderLegalCost *
+                            (SplitCost + TTI->getCastInstrCost(
+                                             Opcode, SplitDstTy, SplitSrcTy,
+                                             CCH, CostKind, I));
+          }
+        }
+      }
+
       // Scalarization cost is Invalid, can't assume any num elements.
       if (isa<ScalableVectorType>(DstVTy))
-        return InstructionCost::getInvalid();
+        return LegalizedCost;
----------------
david-arm wrote:

Is the comment above still valid?

`// Scalarization cost is Invalid, can't assume any num elements.`


https://github.com/llvm/llvm-project/pull/107303


More information about the llvm-commits mailing list