[llvm] [ARM][SLP] Fix cost function for SLP Vectorization of ZExt/SExt (PR #122713)

Nashe Mncube via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 30 05:33:08 PST 2025


https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/122713

>From dcf3cea6b1c297880a451a3772316175656279de Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 9 Jan 2025 14:04:39 +0000
Subject: [PATCH 1/2] [ARM][SLP] Fix incorrect cost function for SLP
 Vectorization of ZExt/SExt

PR #117350 made changes to the SLP vectorizer that introduced
a regression in ARM vectorization benchmarks. The regression was
caused by those changes assuming that SExt/ZExt vector instructions
have a constant cost. That assumption holds for RISC-V but not for
ARM, where the source and destination types of SExt/ZExt
instructions are taken into account when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
---
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |  1 -
 .../Transforms/SLPVectorizer/ARM/vadd-mve.ll  | 29 +++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll

diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 639f3bf8fc62e30..0518059add28bf0 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1795,7 +1795,6 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
   case ISD::ADD:
     if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
       std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
-
       // The legal cases are:
       //   VADDV u/s 8/16/32
       //   VADDLV u/s 32
diff --git a/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll b/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll
new file mode 100644
index 000000000000000..e0af0ca1e4f8ba3
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/ARM/vadd-mve.ll
@@ -0,0 +1,29 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=slp-vectorizer --mtriple arm-none-eabi -mattr=+mve -S -o - | FileCheck %s
+
+define i64 @vadd_32_64(ptr readonly %a) {
+; CHECK-LABEL: define i64 @vadd_32_64(
+; CHECK-SAME: ptr readonly [[A:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[A]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <4 x i32> [[TMP0]] to <4 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP1]])
+; CHECK-NEXT:    ret i64 [[TMP2]]
+;
+entry:
+  %0 = load i32, ptr %a, align 4
+  %conv = sext i32 %0 to i64
+  %arrayidx1 = getelementptr inbounds nuw i8, ptr %a, i32 4
+  %1 = load i32, ptr %arrayidx1, align 4
+  %conv2 = sext i32 %1 to i64
+  %add = add nsw i64 %conv2, %conv
+  %arrayidx3 = getelementptr inbounds nuw i8, ptr %a, i32 8
+  %2 = load i32, ptr %arrayidx3, align 4
+  %conv4 = sext i32 %2 to i64
+  %add5 = add nsw i64 %add, %conv4
+  %arrayidx6 = getelementptr inbounds nuw i8, ptr %a, i32 12
+  %3 = load i32, ptr %arrayidx6, align 4
+  %conv7 = sext i32 %3 to i64
+  %add8 = add nsw i64 %add5, %conv7
+  ret i64 %add8
+}

>From 75be883e7fda0283d8ab48729d46d0bba997a8eb Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 30 Jan 2025 11:51:01 +0000
Subject: [PATCH 2/2] Adjust extended reduction cost

Change-Id: Ie2795e0e5ebb0589146eaf07c752410e307a36e6
---
 llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 0518059add28bf0..d59476704058db0 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1795,6 +1795,7 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
   case ISD::ADD:
     if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
       std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
+
       // The legal cases are:
       //   VADDV u/s 8/16/32
       //   VADDLV u/s 32
@@ -1806,7 +1807,7 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
           ((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
            (LT.second == MVT::v8i16 && RevVTSize <= 32) ||
            (LT.second == MVT::v4i32 && RevVTSize <= 64)))
-        return ST->getMVEVectorCostFactor(CostKind) * LT.first;
+        return 3 * ST->getMVEVectorCostFactor(CostKind) * LT.first;
     }
     break;
   default:



More information about the llvm-commits mailing list