[llvm] [NFC][Scalarizer][TargetTransformInfo] Add `isVectorIntrinsicWithOverloadTypeAtArg` api (PR #114849)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 4 10:30:23 PST 2024
https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/114849
This change allows target intrinsics to specify, and override, which of their types are overloaded.
This change will let us add scalarization for `asdouble`: #114847
>From b95d4f3c79f2a286bfb079599150abb04c76f39b Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 4 Nov 2024 17:38:36 +0000
Subject: [PATCH] [NFC][Scalarizer][TargetTransformInfo] Add
`isVectorIntrinsicWithOverloadTypeAtArg`
This change allows target intrinsics to specify overloaded types.
This change will let us add scalarization for `asdouble` (#114847).
---
llvm/include/llvm/Analysis/TargetTransformInfo.h | 18 ++++++++++++++++++
.../llvm/Analysis/TargetTransformInfoImpl.h | 6 ++++++
llvm/include/llvm/CodeGen/BasicTTIImpl.h | 6 ++++++
llvm/lib/Analysis/TargetTransformInfo.cpp | 6 ++++++
.../Target/DirectX/DirectXTargetTransformInfo.cpp | 8 ++++++++
.../Target/DirectX/DirectXTargetTransformInfo.h | 3 +++
llvm/lib/Transforms/Scalar/Scalarizer.cpp | 9 ++++++---
7 files changed, 53 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,14 @@ class TargetTransformInfo {
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) const;
+ /// Identifies if the vector form of the intrinsic is overloaded on the
+ /// type of the operand at index \p OpdIdx, or on the return type if
+ /// \p OpdIdx is -1. \p Default is the target-independent answer; targets
+ /// that do not override this hook return it unchanged.
+ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default) const;
+
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
@@ -1969,6 +1977,9 @@ class TargetTransformInfo::Concept {
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) = 0;
+ virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default) = 0;
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
@@ -2530,6 +2541,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}
+ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default) override {
+ return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx,
+ Default);
+ }
+
InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
return false;
}
+ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default) const {
+ return Default;
+ }
+
InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return false;
}
+ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default) const {
+ return Default;
+ }
+
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
bool Extract,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+ Intrinsic::ID ID, int OpdIdx, bool Default) const {
+ return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx,
+ Default);
+}
+
InstructionCost TargetTransformInfo::getScalarizationOverhead(
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..2be5fc4de2409c 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
}
}
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+ Intrinsic::ID ID, int OpdIdx, bool Default) {
+ switch (ID) {
+ default:
+ return Default;
+ }
+}
+
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
Intrinsic::ID ID) const {
switch (ID) {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
+ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+ int OpdIdx,
+ bool Default);
};
} // namespace llvm
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
SmallVector<llvm::Type *, 3> Tys;
// Add return type if intrinsic is overloaded on it.
- if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+ if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+ ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
Tys.push_back(VS->SplitTy);
if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
}
Scattered[I] = scatter(&CI, OpI, *OpVS);
- if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+ if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+ ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
OverloadIdx[I] = Tys.size();
Tys.push_back(OpVS->SplitTy);
}
} else {
ScalarOperands[I] = OpI;
- if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+ if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+ ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
Tys.push_back(OpI->getType());
}
}
More information about the llvm-commits
mailing list