[llvm] [NVPTX] Customize getScalarizationOverhead (PR #128077)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 28 15:07:37 PDT 2025


https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/128077

>From 7bf8d930e361ac175ffa434ff4299561d6a7e797 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 20 Feb 2025 21:07:31 +0000
Subject: [PATCH 1/3] [SLPVectorizer][NVPTX] Customize getBuildVectorCost for
 NVPTX

We've observed that the SLPVectorizer is too conservative on NVPTX because it
over-estimates the cost to build a vector. PTX has a single `mov` instruction
that can build <2 x half> vectors from scalars; however, the SLPVectorizer
estimates the cost as two insertelement instructions.

To fix this, I add `TargetTransformInfo::getBuildVectorCost` so the
target can optionally specify the exact cost.
---
 .../llvm/Analysis/TargetTransformInfo.h       |  16 ++
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   6 +
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   6 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   7 +
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  23 ++-
 .../Transforms/Vectorize/SLPVectorizer.cpp    |   3 +
 .../Transforms/SLPVectorizer/NVPTX/v2f16.ll   | 156 ++++++++++++------
 7 files changed, 170 insertions(+), 47 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 99e21aca97631..7f45ed77775d6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1482,6 +1482,12 @@ class TargetTransformInfo {
   InstructionCost getInsertExtractValueCost(unsigned Opcode,
                                             TTI::TargetCostKind CostKind) const;
 
+  /// \return The cost of ISD::BUILD_VECTOR, or nullopt if the cost should be
+  /// inferred from insert element and shuffle ops.
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TargetCostKind CostKind) const;
+
   /// \return The cost of replication shuffle of \p VF elements typed \p EltTy
   /// \p ReplicationFactor times.
   ///
@@ -2219,6 +2225,10 @@ class TargetTransformInfo::Concept {
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index) = 0;
 
+  virtual std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TargetCostKind CostKind) = 0;
+
   virtual InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
@@ -2947,6 +2957,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      unsigned Index) override {
     return Impl.getVectorInstrCost(I, Val, CostKind, Index);
   }
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) override {
+    return Impl.getBuildVectorCost(VecTy, Operands, CostKind);
+  }
+
   InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 745758426c714..f9e765741595a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -741,6 +741,12 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *Val, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) const {
+    return std::nullopt;
+  }
+
   unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                                      const APInt &DemandedDstElts,
                                      TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index eacf75c24695f..93c51614857dc 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1432,6 +1432,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                        Op1);
   }
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) {
+    return std::nullopt;
+  }
+
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
                                             const APInt &DemandedDstElts,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4df551aca30a7..a9a9f7854f50c 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1124,6 +1124,13 @@ InstructionCost TargetTransformInfo::getInsertExtractValueCost(
   return Cost;
 }
 
+std::optional<InstructionCost>
+TargetTransformInfo::getBuildVectorCost(VectorType *VecTy,
+                                        ArrayRef<Value *> Operands,
+                                        TargetCostKind CostKind) const {
+  return TTIImpl->getBuildVectorCost(VecTy, Operands, CostKind);
+}
+
 InstructionCost TargetTransformInfo::getReplicationShuffleCost(
     Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedDstElts,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 6db36e958b28c..f623f2a540536 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -16,8 +16,9 @@
 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
 #define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
 
-#include "NVPTXTargetMachine.h"
 #include "MCTargetDesc/NVPTXBaseInfo.h"
+#include "NVPTXTargetMachine.h"
+#include "NVPTXUtilities.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -104,6 +105,26 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) {
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return std::nullopt;
+    auto VT = getTLI()->getValueType(DL, VecTy);
+    if (all_of(Operands, [](Value *Op) { return isa<Constant>(Op); }))
+      return TTI::TCC_Free;
+    if (Isv2x16VT(VT))
+      return 1; // Single vector mov
+    if (VT == MVT::v4i8) {
+      InstructionCost Cost = 3; // 3 x PRMT
+      for (auto *Op : Operands)
+        if (!isa<Constant>(Op))
+          Cost += 1; // zext operand to i32
+      return Cost;
+    }
+    return std::nullopt;
+  }
+
   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                TTI::UnrollingPreferences &UP,
                                OptimizationRemarkEmitter *ORE);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f29fb6780253b..f732ef7b88195 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11096,6 +11096,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
       return TTI::TCC_Free;
     auto *VecTy = getWidenedType(ScalarTy, VL.size());
+    if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
+        Cost.has_value())
+      return *Cost;
     InstructionCost GatherCost = 0;
     SmallVector<Value *> Gathers(VL);
     if (!Root && isSplat(VL)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
index 13773bf901b9b..c74909d7ceb2a 100644
--- a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
+++ b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
@@ -1,59 +1,123 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefixes=VECTOR,SM90
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefixes=VECTOR,SM80
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefixes=VECTOR,SM70
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefixes=NOVECTOR,SM50
 
 define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
-; CHECK-LABEL: @fusion(
-; CHECK-NEXT:    [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; CHECK-NEXT:    [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
-; CHECK-NEXT:    store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
-; CHECK-NEXT:    ret void
+; VECTOR-LABEL: @fusion(
+; VECTOR-NEXT:    [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; VECTOR-NEXT:    [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; VECTOR-NEXT:    [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; VECTOR-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; VECTOR-NEXT:    [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT:    [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT:    [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
+; VECTOR-NEXT:    [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
+; VECTOR-NEXT:    [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
+; VECTOR-NEXT:    store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
+; VECTOR-NEXT:    ret void
 ;
 ; NOVECTOR-LABEL: @fusion(
-; NOVECTOR-NEXT:    [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; NOVECTOR-NEXT:    [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; NOVECTOR-NEXT:    [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; NOVECTOR-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; NOVECTOR-NEXT:    [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
-; NOVECTOR-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT:    [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
+; NOVECTOR-NEXT:    [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; NOVECTOR-NEXT:    [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; NOVECTOR-NEXT:    [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; NOVECTOR-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; NOVECTOR-NEXT:    [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
+; NOVECTOR-NEXT:    [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT:    [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
+; NOVECTOR-NEXT:    [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
+; NOVECTOR-NEXT:    [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
+; NOVECTOR-NEXT:    [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT:    store half [[TMP9]], ptr [[TMP6]], align 8
+; NOVECTOR-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
+; NOVECTOR-NEXT:    [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
 ; NOVECTOR-NEXT:    [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
 ; NOVECTOR-NEXT:    [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
-; NOVECTOR-NEXT:    [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT:    store half [[TMP14]], ptr [[TMP16]], align 8
-; NOVECTOR-NEXT:    [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
-; NOVECTOR-NEXT:    [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
-; NOVECTOR-NEXT:    [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
-; NOVECTOR-NEXT:    [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
-; NOVECTOR-NEXT:    [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
-; NOVECTOR-NEXT:    store half [[TMP20]], ptr [[TMP21]], align 2
+; NOVECTOR-NEXT:    [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
+; NOVECTOR-NEXT:    store half [[TMP14]], ptr [[TMP15]], align 2
 ; NOVECTOR-NEXT:    ret void
 ;
-  %tmp = shl nuw nsw i32 %arg2, 6
-  %tmp4 = or i32 %tmp, %arg3
-  %tmp5 = shl nuw nsw i32 %tmp4, 2
-  %tmp6 = zext i32 %tmp5 to i64
-  %tmp7 = or disjoint i64 %tmp6, 1
-  %tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
-  %tmp12 = load half, ptr %tmp11, align 8
-  %tmp13 = fmul fast half %tmp12, 0xH5380
-  %tmp14 = fadd fast half %tmp13, 0xH57F0
-  %tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
-  store half %tmp14, ptr %tmp16, align 8
-  %tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
-  %tmp18 = load half, ptr %tmp17, align 2
-  %tmp19 = fmul fast half %tmp18, 0xH5380
-  %tmp20 = fadd fast half %tmp19, 0xH57F0
-  %tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
-  store half %tmp20, ptr %tmp21, align 2
+  %1 = shl nuw nsw i32 %arg2, 6
+  %4 = or i32 %1, %arg3
+  %5 = shl nuw nsw i32 %4, 2
+  %6 = zext i32 %5 to i64
+  %7 = or disjoint i64 %6, 1
+  %11 = getelementptr inbounds half, ptr %arg1, i64 %6
+  %12 = load half, ptr %11, align 8
+  %13 = fmul fast half %12, 0xH5380
+  %14 = fadd fast half %13, 0xH57F0
+  %16 = getelementptr inbounds half, ptr %arg, i64 %6
+  store half %14, ptr %16, align 8
+  %17 = getelementptr inbounds half, ptr %arg1, i64 %7
+  %18 = load half, ptr %17, align 2
+  %19 = fmul fast half %18, 0xH5380
+  %20 = fadd fast half %19, 0xH57F0
+  %21 = getelementptr inbounds half, ptr %arg, i64 %7
+  store half %20, ptr %21, align 2
   ret void
 }
 
+define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
+; VECTOR-LABEL: @add_f16(
+; VECTOR-NEXT:    [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; VECTOR-NEXT:    [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; VECTOR-NEXT:    [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; VECTOR-NEXT:    [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; VECTOR-NEXT:    [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
+; VECTOR-NEXT:    [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
+; VECTOR-NEXT:    [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
+; VECTOR-NEXT:    [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
+; VECTOR-NEXT:    [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
+; VECTOR-NEXT:    [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; VECTOR-NEXT:    [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; VECTOR-NEXT:    [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; VECTOR-NEXT:    [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; VECTOR-NEXT:    [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; VECTOR-NEXT:    store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; VECTOR-NEXT:    ret void
+;
+; NOVECTOR-LABEL: @add_f16(
+; NOVECTOR-NEXT:    [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; NOVECTOR-NEXT:    [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; NOVECTOR-NEXT:    [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; NOVECTOR-NEXT:    [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; NOVECTOR-NEXT:    [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
+; NOVECTOR-NEXT:    [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
+; NOVECTOR-NEXT:    [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; NOVECTOR-NEXT:    [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; NOVECTOR-NEXT:    [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; NOVECTOR-NEXT:    [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; NOVECTOR-NEXT:    [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; NOVECTOR-NEXT:    [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
+; NOVECTOR-NEXT:    [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
+; NOVECTOR-NEXT:    store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; NOVECTOR-NEXT:    ret void
+;
+  %5 = extractvalue { half, half } %1, 0
+  %6 = extractvalue { half, half } %1, 1
+  %7 = extractvalue { half, half } %2, 0
+  %8 = extractvalue { half, half } %2, 1
+  %9 = fadd half %5, %7
+  %10 = fadd half %6, %8
+  %11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %12 = shl i32 %11, 1
+  %13 = and i32 %12, 62
+  %14 = zext nneg i32 %13 to i64
+  %15 = getelementptr half, ptr addrspace(1) %0, i64 %14
+  %18 = insertelement <2 x half> poison, half %9, i64 0
+  %19 = insertelement <2 x half> %18, half %10, i64 1
+  store <2 x half> %19, ptr addrspace(1) %15, align 4
+  ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
 attributes #0 = { nounwind }
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; SM50: {{.*}}
+; SM70: {{.*}}
+; SM80: {{.*}}
+; SM90: {{.*}}

>From 9134ff1d237cf65e310a4fddfd12ade6a48c1559 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Fri, 21 Feb 2025 18:53:47 +0000
Subject: [PATCH 2/3] Revert TTI changes, customize getScalarizationOverhead
 instead

---
 .../llvm/Analysis/TargetTransformInfo.h       | 16 -------
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 ---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  6 ---
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  7 ---
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   | 46 +++++++++++++------
 .../Transforms/Vectorize/SLPVectorizer.cpp    |  3 --
 6 files changed, 31 insertions(+), 53 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7f45ed77775d6..99e21aca97631 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1482,12 +1482,6 @@ class TargetTransformInfo {
   InstructionCost getInsertExtractValueCost(unsigned Opcode,
                                             TTI::TargetCostKind CostKind) const;
 
-  /// \return The cost of ISD::BUILD_VECTOR, or nullopt if the cost should be
-  /// inferred from insert element and shuffle ops.
-  std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
-                     TargetCostKind CostKind) const;
-
   /// \return The cost of replication shuffle of \p VF elements typed \p EltTy
   /// \p ReplicationFactor times.
   ///
@@ -2225,10 +2219,6 @@ class TargetTransformInfo::Concept {
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index) = 0;
 
-  virtual std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
-                     TargetCostKind CostKind) = 0;
-
   virtual InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
@@ -2957,12 +2947,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      unsigned Index) override {
     return Impl.getVectorInstrCost(I, Val, CostKind, Index);
   }
-  std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
-                     TTI::TargetCostKind CostKind) override {
-    return Impl.getBuildVectorCost(VecTy, Operands, CostKind);
-  }
-
   InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index f9e765741595a..745758426c714 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -741,12 +741,6 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *Val, ArrayRef<Value *> Operands,
-                     TTI::TargetCostKind CostKind) const {
-    return std::nullopt;
-  }
-
   unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                                      const APInt &DemandedDstElts,
                                      TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 93c51614857dc..eacf75c24695f 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1432,12 +1432,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                        Op1);
   }
 
-  std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
-                     TTI::TargetCostKind CostKind) {
-    return std::nullopt;
-  }
-
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
                                             const APInt &DemandedDstElts,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a9a9f7854f50c..4df551aca30a7 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1124,13 +1124,6 @@ InstructionCost TargetTransformInfo::getInsertExtractValueCost(
   return Cost;
 }
 
-std::optional<InstructionCost>
-TargetTransformInfo::getBuildVectorCost(VectorType *VecTy,
-                                        ArrayRef<Value *> Operands,
-                                        TargetCostKind CostKind) const {
-  return TTIImpl->getBuildVectorCost(VecTy, Operands, CostKind);
-}
-
 InstructionCost TargetTransformInfo::getReplicationShuffleCost(
     Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedDstElts,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index f623f2a540536..9e77f628da7a7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -105,24 +105,48 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
 
-  std::optional<InstructionCost>
-  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
-                     TTI::TargetCostKind CostKind) {
-    if (CostKind != TTI::TCK_RecipThroughput)
-      return std::nullopt;
-    auto VT = getTLI()->getValueType(DL, VecTy);
-    if (all_of(Operands, [](Value *Op) { return isa<Constant>(Op); }))
-      return TTI::TCC_Free;
-    if (Isv2x16VT(VT))
-      return 1; // Single vector mov
-    if (VT == MVT::v4i8) {
+  /// Estimate the cost of inserting scalars into and/or extracting scalars
+  /// from the demanded elements of \p InTy.
+  ///
+  /// Insertion is costed specially in three cases: every demanded element of
+  /// \p VL is a constant (free), the type is a v2x16 vector (one mov), or the
+  /// type is v4i8 (3 PRMT plus one zext per demanded element). Once one of
+  /// these applies, \p Insert is cleared so the base implementation only adds
+  /// any remaining extract cost. Scalable vectors yield an invalid cost.
+  InstructionCost getScalarizationOverhead(VectorType *InTy,
+                                           const APInt &DemandedElts,
+                                           bool Insert, bool Extract,
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) {
+    if (!InTy->getElementCount().isFixed())
+      return InstructionCost::getInvalid();
+
+    auto VT = getTLI()->getValueType(DL, InTy);
+    auto NumElements = InTy->getElementCount().getFixedValue();
+    InstructionCost Cost = 0;
+    if (Insert && !VL.empty()) {
+      bool AllConstant = all_of(seq(NumElements), [&](int Idx) {
+        return !DemandedElts[Idx] || isa<Constant>(VL[Idx]);
+      });
+      if (AllConstant) {
+        Cost += TTI::TCC_Free;
+        Insert = false;
+      }
+    }
+    if (Insert && Isv2x16VT(VT)) {
+      // Can be built in a single mov
+      Cost += 1;
+      Insert = false;
+    }
+    if (Insert && VT == MVT::v4i8) {
-      InstructionCost Cost = 3; // 3 x PRMT
+      Cost += 3; // 3 x PRMT
-      for (auto *Op : Operands)
-        if (!isa<Constant>(Op))
+      for (auto Idx : seq(NumElements))
+        if (DemandedElts[Idx])
           Cost += 1; // zext operand to i32
-      return Cost;
+      Insert = false;
     }
-    return std::nullopt;
+    return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert,
+                                                  Extract, CostKind, VL);
   }
 
   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f732ef7b88195..f29fb6780253b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11096,9 +11096,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
       return TTI::TCC_Free;
     auto *VecTy = getWidenedType(ScalarTy, VL.size());
-    if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
-        Cost.has_value())
-      return *Cost;
     InstructionCost GatherCost = 0;
     SmallVector<Value *> Gathers(VL);
     if (!Root && isSplat(VL)) {

>From 52f75fcefd46845d2f563f5d537276bf33e333b9 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Fri, 21 Feb 2025 19:10:02 +0000
Subject: [PATCH 3/3] Cleanup unused check labels

---
 llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
index c74909d7ceb2a..71979f32080f2 100644
--- a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
+++ b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefixes=VECTOR,SM90
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefixes=VECTOR,SM80
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefixes=VECTOR,SM70
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefixes=NOVECTOR,SM50
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefix=VECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefix=VECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefix=VECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefix=NOVECTOR
 
 define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
 ; VECTOR-LABEL: @fusion(
@@ -116,8 +116,3 @@ declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
 
 attributes #0 = { nounwind }
 attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; SM50: {{.*}}
-; SM70: {{.*}}
-; SM80: {{.*}}
-; SM90: {{.*}}



More information about the llvm-commits mailing list