[llvm] [CostModel] Make sure getCmpSelInstrCost is passed a CondTy (PR #135535)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Sun Apr 13 02:13:17 PDT 2025
https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/135535
Some code paths in getCmpSelInstrCost already require CondTy to be a valid (non-null) type. Update several callers to always pass one, and drop the CondTy null check in BasicTTIImplBase's scalarized vector path.
>From 332ce17d5050ad89b05b0c6b1a43c11e0b3d99e6 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Sun, 13 Apr 2025 10:11:31 +0100
Subject: [PATCH] [CostModel] Make sure getCmpSelInstrCost is passed a CondTy
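Some code paths in getCmpSelInstrCost already require CondTy to be a valid
(non-null) type. Update several callers to always pass one, computed with
CmpInst::makeCmpResultType, and drop the CondTy null check in
BasicTTIImplBase's scalarized vector path.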
---
llvm/include/llvm/CodeGen/BasicTTIImpl.h | 8 +++-----
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 3 ++-
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 8 ++++----
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 8 ++++----
4 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index eacf75c24695f..983fb16f255ec 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1384,11 +1384,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return InstructionCost::getInvalid();
unsigned Num = cast<FixedVectorType>(ValVTy)->getNumElements();
- if (CondTy)
- CondTy = CondTy->getScalarType();
- InstructionCost Cost =
- thisT()->getCmpSelInstrCost(Opcode, ValVTy->getScalarType(), CondTy,
- VecPred, CostKind, Op1Info, Op2Info, I);
+ InstructionCost Cost = thisT()->getCmpSelInstrCost(
+ Opcode, ValVTy->getScalarType(), CondTy->getScalarType(), VecPred,
+ CostKind, Op1Info, Op2Info, I);
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index eac7e7c209c95..6d9fd98cb20a8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3094,7 +3094,8 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
if (ThenV == OrigV)
continue;
- Cost += TTI.getCmpSelInstrCost(Instruction::Select, PN.getType(), nullptr,
+ Cost += TTI.getCmpSelInstrCost(Instruction::Select, PN.getType(),
+ CmpInst::makeCmpResultType(PN.getType()),
CmpInst::BAD_ICMP_PREDICATE, CostKind);
// Don't convert to selects if we could remove undefined behavior instead.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0acca63503afa..2b61d0c5441ed 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6974,10 +6974,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
}
VectorTy = toVectorTy(ValTy, VF);
- return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr,
- cast<CmpInst>(I)->getPredicate(), CostKind,
- {TTI::OK_AnyValue, TTI::OP_None},
- {TTI::OK_AnyValue, TTI::OP_None}, I);
+ return TTI.getCmpSelInstrCost(
+ I->getOpcode(), VectorTy, CmpInst::makeCmpResultType(VectorTy),
+ cast<CmpInst>(I)->getPredicate(), CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None}, {TTI::OK_AnyValue, TTI::OP_None}, I);
}
case Instruction::Store:
case Instruction::Load: {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2cff343d915cf..ebedea1d65a9a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1753,10 +1753,10 @@ InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
case Instruction::FCmp: {
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
- return Ctx.TTI.getCmpSelInstrCost(Opcode, VectorTy, nullptr, getPredicate(),
- Ctx.CostKind,
- {TTI::OK_AnyValue, TTI::OP_None},
- {TTI::OK_AnyValue, TTI::OP_None}, CtxI);
+ return Ctx.TTI.getCmpSelInstrCost(
+ Opcode, VectorTy, CmpInst::makeCmpResultType(VectorTy), getPredicate(),
+ Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
+ {TTI::OK_AnyValue, TTI::OP_None}, CtxI);
}
default:
llvm_unreachable("Unsupported opcode for instruction");
More information about the llvm-commits mailing list