[llvm] [AArch64] Add better fcmp costs for expanded predicates (PR #147940)

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 21 08:04:04 PDT 2025


================
@@ -4308,63 +4307,80 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
         return LT.first;
     }
 
-    static const TypeConversionCostTblEntry
-    VectorSelectTbl[] = {
-      { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
-      { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
-      { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
-      { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
-      { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
-      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
-      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
-      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
-      { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
-      { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
-      { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
-    };
+    static const TypeConversionCostTblEntry VectorSelectTbl[] = {
+        {Instruction::Select, MVT::v2i1, MVT::v2f32, 2},
+        {Instruction::Select, MVT::v2i1, MVT::v2f64, 2},
+        {Instruction::Select, MVT::v4i1, MVT::v4f32, 2},
+        {Instruction::Select, MVT::v4i1, MVT::v4f16, 2},
+        {Instruction::Select, MVT::v8i1, MVT::v8f16, 2},
+        {Instruction::Select, MVT::v16i1, MVT::v16i16, 16},
+        {Instruction::Select, MVT::v8i1, MVT::v8i32, 8},
+        {Instruction::Select, MVT::v16i1, MVT::v16i32, 16},
+        {Instruction::Select, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost},
+        {Instruction::Select, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost},
+        {Instruction::Select, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost}};
 
     EVT SelCondTy = TLI->getValueType(DL, CondTy);
     EVT SelValTy = TLI->getValueType(DL, ValTy);
     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
-      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
+      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, Opcode,
                                                      SelCondTy.getSimpleVT(),
                                                      SelValTy.getSimpleVT()))
         return Entry->Cost;
     }
   }
 
-  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
-    Type *ValScalarTy = ValTy->getScalarType();
-    if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
-        ValScalarTy->isBFloatTy()) {
-      auto *ValVTy = cast<FixedVectorType>(ValTy);
-
-      // Without dedicated instructions we promote [b]f16 compares to f32.
-      auto *PromotedTy =
-          VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
-
-      InstructionCost Cost = 0;
-      // Promote operands to float vectors.
-      Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
-                                   TTI::CastContextHint::None, CostKind);
-      // Compare float vectors.
+  if (Opcode == Instruction::FCmp) {
+    // Without dedicated instructions we promote f16 + bf16 compares to f32.
+    if ((!ST->hasFullFP16() && ValTy->getScalarType()->isHalfTy()) ||
+        ValTy->getScalarType()->isBFloatTy()) {
+      Type *PromotedTy =
+          ValTy->getWithNewType(Type::getFloatTy(ValTy->getContext()));
+      InstructionCost Cost =
+          getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
+                           TTI::CastContextHint::None, CostKind);
+      if (!Op1Info.isConstant() && !Op2Info.isConstant())
+        Cost *= 2;
       Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
                                  Op1Info, Op2Info);
-      // During codegen we'll truncate the vector result from i32 to i16.
-      Cost +=
-          getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
-                           VectorType::getInteger(PromotedTy),
-                           TTI::CastContextHint::None, CostKind);
+      if (ValTy->isVectorTy())
+        Cost += getCastInstrCost(
+            Instruction::Trunc, VectorType::getInteger(cast<VectorType>(ValTy)),
+            VectorType::getInteger(cast<VectorType>(PromotedTy)),
+            TTI::CastContextHint::None, CostKind);
       return Cost;
     }
+
+    auto LT = getTypeLegalizationCost(ValTy);
+    // Model unknown fp compares as a libcall.
+    if (LT.second.getScalarType() != MVT::f64 &&
+        LT.second.getScalarType() != MVT::f32 &&
+        LT.second.getScalarType() != MVT::f16)
+      return LT.first * getCallInstrCost(/*Function*/ nullptr, ValTy,
+                                         {ValTy, ValTy}, CostKind);
----------------
davemgreen wrote:

We usually treat accurately costing a libcall like this as more or less impossible, and just say every call costs 10. That is usually enough to prevent vectorization, which is what we are after for fp128 types.

https://github.com/llvm/llvm-project/pull/147940


More information about the llvm-commits mailing list