[llvm] [ARM][SLP] Fix cost function for SLP Vectorization of ZExt/SExt (PR #122713)

Nashe Mncube via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 21 05:09:40 PST 2025


https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/122713

>From 28b9d6a18ead88930e0b8836f97c1161dd78aac2 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 9 Jan 2025 14:04:39 +0000
Subject: [PATCH] [ARM][SLP] Fix incorrect cost function for SLP Vectorization
 of ZExt/SExt

PR #117350 made changes to the SLP vectorizer that introduced a
regression on ARM vectorization benchmarks. The changes assumed that
vector SExt/ZExt instructions have a constant cost. That assumption
holds for RISC-V, but not for ARM, where the cost depends on the
source and destination types of the SExt/ZExt instruction.
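
For illustration, consider an extended reduction of the following
shape (a hand-written sketch, not taken from the affected benchmarks):

  %ext = sext <4 x i8> %v to <4 x i32>
  %sum = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %ext)

Under the constant-cost assumption the SExt is effectively free when
folded into the reduction; with this patch it contributes the MVE
conversion cost (two instructions for v4i8 -> v4i32 per the table
below), so the vectorized form no longer looks artificially cheap.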

---
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |  26 +-
 ...nsive-arithmetic-extended-reduction-mve.ll | 285 ++++++++++++++++++
 2 files changed, 309 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/SLPVectorizer/ARM/expensive-arithmetic-extended-reduction-mve.ll

diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 639f3bf8fc62e3..3e282639449f88 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1791,11 +1791,33 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
 
+  auto CastCost = [=]() -> unsigned {
+    // MVE extend costs, taken from codegen tests. i8->i16 or i16->i32 is one
+    // instruction, i8->i32 is two. An i64 zext is a VAND with a constant;
+    // sexts are linearised, so they cost more.
+    static const TypeConversionCostTblEntry MVEVectorConversionTbl[] = {
+        {ISD::SIGN_EXTEND, MVT::v8i16, MVT::v8i8, 1},
+        {ISD::ZERO_EXTEND, MVT::v8i16, MVT::v8i8, 1},
+        {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i8, 2},
+        {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i8, 2},
+        {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 1},
+        {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 1},
+    };
+
+    if (ST->hasMVEIntegerOps()) {
+      if (const auto *Entry = ConvertCostTableLookup(
+              MVEVectorConversionTbl,
+              IsUnsigned ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
+              ResVT.getSimpleVT(), ValVT.getSimpleVT()))
+        return Entry->Cost;
+    }
+    return 0;
+  };
+
   switch (ISD) {
   case ISD::ADD:
     if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
       std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
-
       // The legal cases are:
       //   VADDV u/s 8/16/32
       //   VADDLV u/s 32
@@ -1807,7 +1829,7 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
           ((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
            (LT.second == MVT::v8i16 && RevVTSize <= 32) ||
            (LT.second == MVT::v4i32 && RevVTSize <= 64)))
-        return ST->getMVEVectorCostFactor(CostKind) * LT.first;
+        return CastCost() + ST->getMVEVectorCostFactor(CostKind) * LT.first;
     }
     break;
   default:
diff --git a/llvm/test/Transforms/SLPVectorizer/ARM/expensive-arithmetic-extended-reduction-mve.ll b/llvm/test/Transforms/SLPVectorizer/ARM/expensive-arithmetic-extended-reduction-mve.ll
new file mode 100644
index 00000000000000..f84bc7dc076f1b
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/ARM/expensive-arithmetic-extended-reduction-mve.ll
@@ -0,0 +1,285 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="default<O1>,slp-vectorizer" -S -mtriple=arm-none-eabi --mattr=+mve | FileCheck %s
+
+
+define dso_local i64 @vadd(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local range(i64 -8589934592, 8589934589) i64 @vadd(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <4 x i32> [[TMP2]] to <4 x i64>
+; CHECK-NEXT:    [[TMP21:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP3]])
+; CHECK-NEXT:    ret i64 [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds i32, ptr %3, i32 0
+  %5 = load i32, ptr %4, align 4
+  %6 = sext i32 %5 to i64
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds i32, ptr %7, i32 1
+  %9 = load i32, ptr %8, align 4
+  %10 = sext i32 %9 to i64
+  %11 = add nsw i64 %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds i32, ptr %12, i32 2
+  %14 = load i32, ptr %13, align 4
+  %15 = sext i32 %14 to i64
+  %16 = add nsw i64 %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds i32, ptr %17, i32 3
+  %19 = load i32, ptr %18, align 4
+  %20 = sext i32 %19 to i64
+  %21 = add nsw i64 %16, %20
+  ret i64 %21
+}
+
+define dso_local i64 @vmul(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local i64 @vmul(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP5:%.*]] = load i32, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i32 [[TMP9]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw i64 [[TMP10]], [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = sext i32 [[TMP14]] to i64
+; CHECK-NEXT:    [[TMP16:%.*]] = mul nsw i64 [[TMP11]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP20:%.*]] = sext i32 [[TMP19]] to i64
+; CHECK-NEXT:    [[TMP21:%.*]] = mul nsw i64 [[TMP16]], [[TMP20]]
+; CHECK-NEXT:    ret i64 [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds i32, ptr %3, i32 0
+  %5 = load i32, ptr %4, align 4
+  %6 = sext i32 %5 to i64
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds i32, ptr %7, i32 1
+  %9 = load i32, ptr %8, align 4
+  %10 = sext i32 %9 to i64
+  %11 = mul nsw i64 %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds i32, ptr %12, i32 2
+  %14 = load i32, ptr %13, align 4
+  %15 = sext i32 %14 to i64
+  %16 = mul nsw i64 %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds i32, ptr %17, i32 3
+  %19 = load i32, ptr %18, align 4
+  %20 = sext i32 %19 to i64
+  %21 = mul nsw i64 %16, %20
+  ret i64 %21
+}
+
+define dso_local i64 @vand(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local range(i64 -2147483648, 2147483648) i64 @vand(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP9]], [[TMP2]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = and i32 [[TMP5]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = and i32 [[TMP10]], [[TMP19]]
+; CHECK-NEXT:    [[TMP21:%.*]] = sext i32 [[TMP11]] to i64
+; CHECK-NEXT:    ret i64 [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds i32, ptr %3, i32 0
+  %5 = load i32, ptr %4, align 4
+  %6 = sext i32 %5 to i64
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds i32, ptr %7, i32 1
+  %9 = load i32, ptr %8, align 4
+  %10 = sext i32 %9 to i64
+  %11 = and i64 %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds i32, ptr %12, i32 2
+  %14 = load i32, ptr %13, align 4
+  %15 = sext i32 %14 to i64
+  %16 = and i64 %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds i32, ptr %17, i32 3
+  %19 = load i32, ptr %18, align 4
+  %20 = sext i32 %19 to i64
+  %21 = and i64 %16, %20
+  ret i64 %21
+}
+
+define dso_local i64 @vor(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local range(i64 -2147483648, 2147483648) i64 @vor(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = or i32 [[TMP9]], [[TMP2]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = or i32 [[TMP5]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = or i32 [[TMP10]], [[TMP19]]
+; CHECK-NEXT:    [[TMP21:%.*]] = sext i32 [[TMP11]] to i64
+; CHECK-NEXT:    ret i64 [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds i32, ptr %3, i32 0
+  %5 = load i32, ptr %4, align 4
+  %6 = sext i32 %5 to i64
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds i32, ptr %7, i32 1
+  %9 = load i32, ptr %8, align 4
+  %10 = sext i32 %9 to i64
+  %11 = or i64 %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds i32, ptr %12, i32 2
+  %14 = load i32, ptr %13, align 4
+  %15 = sext i32 %14 to i64
+  %16 = or i64 %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds i32, ptr %17, i32 3
+  %19 = load i32, ptr %18, align 4
+  %20 = sext i32 %19 to i64
+  %21 = or i64 %16, %20
+  ret i64 %21
+}
+
+define dso_local i64 @vxor(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local range(i64 -2147483648, 2147483648) i64 @vxor(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = xor i32 [[TMP9]], [[TMP2]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = xor i32 [[TMP5]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load i32, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = xor i32 [[TMP10]], [[TMP19]]
+; CHECK-NEXT:    [[TMP21:%.*]] = sext i32 [[TMP11]] to i64
+; CHECK-NEXT:    ret i64 [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds i32, ptr %3, i32 0
+  %5 = load i32, ptr %4, align 4
+  %6 = sext i32 %5 to i64
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds i32, ptr %7, i32 1
+  %9 = load i32, ptr %8, align 4
+  %10 = sext i32 %9 to i64
+  %11 = xor i64 %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds i32, ptr %12, i32 2
+  %14 = load i32, ptr %13, align 4
+  %15 = sext i32 %14 to i64
+  %16 = xor i64 %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds i32, ptr %17, i32 3
+  %19 = load i32, ptr %18, align 4
+  %20 = sext i32 %19 to i64
+  %21 = xor i64 %16, %20
+  ret i64 %21
+}
+
+define dso_local double @vfadd(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local double @vfadd(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP5:%.*]] = load float, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = fpext float [[TMP5]] to double
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = fpext float [[TMP9]] to double
+; CHECK-NEXT:    [[TMP11:%.*]] = fadd double [[TMP6]], [[TMP10]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load float, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = fpext float [[TMP14]] to double
+; CHECK-NEXT:    [[TMP16:%.*]] = fadd double [[TMP11]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load float, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP20:%.*]] = fpext float [[TMP19]] to double
+; CHECK-NEXT:    [[TMP21:%.*]] = fadd double [[TMP16]], [[TMP20]]
+; CHECK-NEXT:    ret double [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds float, ptr %3, i32 0
+  %5 = load float, ptr %4, align 4
+  %6 = fpext float %5 to double
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds float, ptr %7, i32 1
+  %9 = load float, ptr %8, align 4
+  %10 = fpext float %9 to double
+  %11 = fadd double %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds float, ptr %12, i32 2
+  %14 = load float, ptr %13, align 4
+  %15 = fpext float %14 to double
+  %16 = fadd double %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds float, ptr %17, i32 3
+  %19 = load float, ptr %18, align 4
+  %20 = fpext float %19 to double
+  %21 = fadd double %16, %20
+  ret double %21
+}
+
+define dso_local double @vfmul(ptr noundef %0) #0 {
+; CHECK-LABEL: define dso_local double @vfmul(
+; CHECK-SAME: ptr nocapture noundef readonly [[TMP0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP5:%.*]] = load float, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = fpext float [[TMP5]] to double
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr [[TMP8]], align 4
+; CHECK-NEXT:    [[TMP10:%.*]] = fpext float [[TMP9]] to double
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul double [[TMP6]], [[TMP10]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 8
+; CHECK-NEXT:    [[TMP14:%.*]] = load float, ptr [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = fpext float [[TMP14]] to double
+; CHECK-NEXT:    [[TMP16:%.*]] = fmul double [[TMP11]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0]], i32 12
+; CHECK-NEXT:    [[TMP19:%.*]] = load float, ptr [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP20:%.*]] = fpext float [[TMP19]] to double
+; CHECK-NEXT:    [[TMP21:%.*]] = fmul double [[TMP16]], [[TMP20]]
+; CHECK-NEXT:    ret double [[TMP21]]
+;
+  %2 = alloca ptr, align 4
+  store ptr %0, ptr %2, align 4
+  %3 = load ptr, ptr %2, align 4
+  %4 = getelementptr inbounds float, ptr %3, i32 0
+  %5 = load float, ptr %4, align 4
+  %6 = fpext float %5 to double
+  %7 = load ptr, ptr %2, align 4
+  %8 = getelementptr inbounds float, ptr %7, i32 1
+  %9 = load float, ptr %8, align 4
+  %10 = fpext float %9 to double
+  %11 = fmul double %6, %10
+  %12 = load ptr, ptr %2, align 4
+  %13 = getelementptr inbounds float, ptr %12, i32 2
+  %14 = load float, ptr %13, align 4
+  %15 = fpext float %14 to double
+  %16 = fmul double %11, %15
+  %17 = load ptr, ptr %2, align 4
+  %18 = getelementptr inbounds float, ptr %17, i32 3
+  %19 = load float, ptr %18, align 4
+  %20 = fpext float %19 to double
+  %21 = fmul double %16, %20
+  ret double %21
+}
+
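
For local inspection, the per-instruction costs can be dumped with
something along these lines (a sketch; note that
getExtendedReductionCost itself is only exercised through callers such
as the SLP vectorizer, and exact flag spellings may vary between LLVM
versions):

  opt -mtriple=arm-none-eabi -mattr=+mve -passes='print<cost-model>' \
      -disable-output expensive-arithmetic-extended-reduction-mve.ll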



More information about the llvm-commits mailing list