[llvm] [AArch64][CostModel] Improve cost estimate of scalarizing a vector di… (PR #118055)
Sushant Gokhale via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 12 09:20:02 PST 2024
https://github.com/sushgokh updated https://github.com/llvm/llvm-project/pull/118055
>From 6596bc90efe6a225a1898be7c4c701d601495a67 Mon Sep 17 00:00:00 2001
From: sgokhale <sgokhale at nvidia.com>
Date: Fri, 29 Nov 2024 11:59:18 +0530
Subject: [PATCH] [AArch64][CostModel] Improve cost estimate of scalarizing a
vector division
In the backend, the last resort for finding the cost of a vector division is to use its scalar cost. However, without knowledge of the division operands, that cost can be off in certain cases.
For SLP, this patch tries to pass the scalars being vectorized to the backend so that it can estimate the scalar cost more accurately.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 14 ++++++++
.../Transforms/Vectorize/SLPVectorizer.cpp | 22 ++++++++++--
.../Transforms/SLPVectorizer/AArch64/div.ll | 36 ++++---------------
3 files changed, 40 insertions(+), 32 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d1536a276a9040..be3c3170caef13 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3472,6 +3472,20 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
Cost *= 4;
return Cost;
} else {
+ // If the information about individual scalars being vectorized is
+ // available, this yields better cost estimation.
+ if (auto *VTy = dyn_cast<FixedVectorType>(Ty); VTy && !Args.empty()) {
+ InstructionCost InsertExtractCost =
+ ST->getVectorInsertExtractBaseCost();
+ Cost = (3 * InsertExtractCost) * VTy->getNumElements();
+ for (int i = 0, Sz = Args.size(); i < Sz; i += 2) {
+ Cost += getArithmeticInstrCost(
+ Opcode, VTy->getScalarType(), CostKind,
+ TTI::getOperandInfo(Args[i]), TTI::getOperandInfo(Args[i + 1]));
+ }
+ return Cost;
+ }
+
// If one of the operands is a uniform constant then the cost for each
// element is Cost for insertion, extraction and division.
// Insertion cost = 2, Extraction Cost = 2, Division = cost for the
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 48a8520a966fc7..69588460d0cfff 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11561,9 +11561,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
- return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
- Op2Info, {}, nullptr, TLI) +
- CommonCost;
+ SmallVector<Value *, 16> Operands;
+ auto IsAllowedScalarTy = [](const Value *I) {
+ // Just to avoid a regression in a RISC-V unit test.
+ return I->getType()->getPrimitiveSizeInBits() >= 32;
+ };
+ if (all_of(E->Scalars, [ShuffleOrOp, IsAllowedScalarTy](Value *V) {
+ return !IsaPred<UndefValue, PoisonValue>(V) &&
+ IsAllowedScalarTy(V) &&
+ cast<Instruction>(V)->getOpcode() == ShuffleOrOp;
+ })) {
+ for (auto *Scalar : E->Scalars) {
+ Instruction *I = cast<Instruction>(Scalar);
+ auto IOperands = I->operand_values();
+ Operands.insert(Operands.end(), IOperands.begin(), IOperands.end());
+ }
+ }
+ return CommonCost +
+ TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
+ Op2Info, Operands, nullptr, TLI);
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/div.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/div.ll
index 57c877a58d5c0e..47630d9eb01fff 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/div.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/div.ll
@@ -554,35 +554,13 @@ define <4 x i32> @slp_v4i32_Op1_unknown_Op2_const_pow2(<4 x i32> %a)
; computes (a/const + x - y) * z
define <2 x i32> @vectorize_sdiv_v2i32(<2 x i32> %a, <2 x i32> %x, <2 x i32> %y, <2 x i32> %z)
-; NO-SVE-LABEL: define <2 x i32> @vectorize_sdiv_v2i32(
-; NO-SVE-SAME: <2 x i32> [[A:%.*]], <2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> [[Z:%.*]]) #[[ATTR0]] {
-; NO-SVE-NEXT: [[A0:%.*]] = extractelement <2 x i32> [[A]], i64 0
-; NO-SVE-NEXT: [[A1:%.*]] = extractelement <2 x i32> [[A]], i64 1
-; NO-SVE-NEXT: [[TMP1:%.*]] = sdiv i32 [[A0]], 2
-; NO-SVE-NEXT: [[TMP2:%.*]] = sdiv i32 [[A1]], 4
-; NO-SVE-NEXT: [[X0:%.*]] = extractelement <2 x i32> [[X]], i64 0
-; NO-SVE-NEXT: [[X1:%.*]] = extractelement <2 x i32> [[X]], i64 1
-; NO-SVE-NEXT: [[TMP3:%.*]] = add i32 [[TMP1]], [[X0]]
-; NO-SVE-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[X1]]
-; NO-SVE-NEXT: [[Y0:%.*]] = extractelement <2 x i32> [[Y]], i64 0
-; NO-SVE-NEXT: [[Y1:%.*]] = extractelement <2 x i32> [[Y]], i64 1
-; NO-SVE-NEXT: [[TMP5:%.*]] = sub i32 [[TMP3]], [[Y0]]
-; NO-SVE-NEXT: [[TMP6:%.*]] = sub i32 [[TMP4]], [[Y1]]
-; NO-SVE-NEXT: [[Z0:%.*]] = extractelement <2 x i32> [[Z]], i64 0
-; NO-SVE-NEXT: [[Z1:%.*]] = extractelement <2 x i32> [[Z]], i64 1
-; NO-SVE-NEXT: [[TMP7:%.*]] = mul i32 [[TMP5]], [[Z0]]
-; NO-SVE-NEXT: [[TMP8:%.*]] = mul i32 [[TMP6]], [[Z1]]
-; NO-SVE-NEXT: [[RES0:%.*]] = insertelement <2 x i32> poison, i32 [[TMP7]], i32 0
-; NO-SVE-NEXT: [[RES1:%.*]] = insertelement <2 x i32> [[RES0]], i32 [[TMP8]], i32 1
-; NO-SVE-NEXT: ret <2 x i32> [[RES1]]
-;
-; SVE-LABEL: define <2 x i32> @vectorize_sdiv_v2i32(
-; SVE-SAME: <2 x i32> [[A:%.*]], <2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> [[Z:%.*]]) #[[ATTR0]] {
-; SVE-NEXT: [[TMP1:%.*]] = sdiv <2 x i32> [[A]], <i32 2, i32 4>
-; SVE-NEXT: [[TMP2:%.*]] = add <2 x i32> [[TMP1]], [[X]]
-; SVE-NEXT: [[TMP3:%.*]] = sub <2 x i32> [[TMP2]], [[Y]]
-; SVE-NEXT: [[TMP4:%.*]] = mul <2 x i32> [[TMP3]], [[Z]]
-; SVE-NEXT: ret <2 x i32> [[TMP4]]
+; CHECK-LABEL: define <2 x i32> @vectorize_sdiv_v2i32(
+; CHECK-SAME: <2 x i32> [[A:%.*]], <2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> [[Z:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT: [[TMP1:%.*]] = sdiv <2 x i32> [[A]], <i32 2, i32 4>
+; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i32> [[TMP1]], [[X]]
+; CHECK-NEXT: [[TMP3:%.*]] = sub <2 x i32> [[TMP2]], [[Y]]
+; CHECK-NEXT: [[TMP4:%.*]] = mul <2 x i32> [[TMP3]], [[Z]]
+; CHECK-NEXT: ret <2 x i32> [[TMP4]]
;
{
%a0 = extractelement <2 x i32> %a, i64 0
More information about the llvm-commits
mailing list