[llvm] [ARM][SLP] Fix cost function for SLP Vectorization of ZExt/SExt (PR #122713)
Nashe Mncube via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 30 05:33:08 PST 2025
https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/122713
>From dcf3cea6b1c297880a451a3772316175656279de Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 9 Jan 2025 14:04:39 +0000
Subject: [PATCH 1/2] [ARM][SLP] Fix incorrect cost function for SLP
Vectorization of ZExt/SExt
PR #117350 made changes to the SLP vectorizer which introduced
a regression on ARM vectorization benchmarks. The regression arose
because those changes assume that SExt/ZExt vector instructions have
constant cost. That assumption holds for RISC-V but not for ARM,
where the source and destination types of SExt/ZExt instructions are
taken into account when calculating vector cost.
Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
---
.../lib/Target/ARM/ARMTargetTransformInfo.cpp | 1 -
.../Transforms/SLPVectorizer/ARM/vadd-mve.ll | 29 +++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 639f3bf8fc62e30..0518059add28bf0 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1795,7 +1795,6 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
case ISD::ADD:
if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
-
// The legal cases are:
// VADDV u/s 8/16/32
// VADDLV u/s 32
diff --git a/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll b/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll
new file mode 100644
index 000000000000000..e0af0ca1e4f8ba3
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll
@@ -0,0 +1,29 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=slp-vectorizer --mtriple arm-none-eabi -mattr=+mve -S -o - | FileCheck %s
+
+define i64 @vadd_32_64(ptr readonly %a) {
+; CHECK-LABEL: define i64 @vadd_32_64(
+; CHECK-SAME: ptr readonly [[A:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[A]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = sext <4 x i32> [[TMP0]] to <4 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP1]])
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+entry:
+ %0 = load i32, ptr %a, align 4
+ %conv = sext i32 %0 to i64
+ %arrayidx1 = getelementptr inbounds nuw i8, ptr %a, i32 4
+ %1 = load i32, ptr %arrayidx1, align 4
+ %conv2 = sext i32 %1 to i64
+ %add = add nsw i64 %conv2, %conv
+ %arrayidx3 = getelementptr inbounds nuw i8, ptr %a, i32 8
+ %2 = load i32, ptr %arrayidx3, align 4
+ %conv4 = sext i32 %2 to i64
+ %add5 = add nsw i64 %add, %conv4
+ %arrayidx6 = getelementptr inbounds nuw i8, ptr %a, i32 12
+ %3 = load i32, ptr %arrayidx6, align 4
+ %conv7 = sext i32 %3 to i64
+ %add8 = add nsw i64 %add5, %conv7
+ ret i64 %add8
+}
>From 75be883e7fda0283d8ab48729d46d0bba997a8eb Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 30 Jan 2025 11:51:01 +0000
Subject: [PATCH 2/2] Adjust extended reduction cost
Change-Id: Ie2795e0e5ebb0589146eaf07c752410e307a36e6
---
llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 0518059add28bf0..d59476704058db0 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1795,6 +1795,7 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
case ISD::ADD:
if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
+
// The legal cases are:
// VADDV u/s 8/16/32
// VADDLV u/s 32
@@ -1806,7 +1807,7 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
(LT.second == MVT::v8i16 && RevVTSize <= 32) ||
(LT.second == MVT::v4i32 && RevVTSize <= 64)))
- return ST->getMVEVectorCostFactor(CostKind) * LT.first;
+ return 3 * ST->getMVEVectorCostFactor(CostKind) * LT.first;
}
break;
default:
More information about the llvm-commits
mailing list