[llvm] [ConstantFolding] Fold intrinsics of scalable vectors with splatted operands (PR #141845)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 13:14:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
As noted in https://github.com/llvm/llvm-project/pull/141821#issuecomment-2917328924, we currently constant-fold intrinsics on fixed-length vectors via their scalar counterparts, but we don't do the same for scalable vectors.
This handles the scalable vector case when the operands are splats.
One snag: ConstantVector::getSplat returned undef when passed a poison scalar. This happened because PoisonValue is a subclass of UndefValue, so the `isa<UndefValue>` check matched it. This patch also fixes that by checking for PoisonValue before UndefValue.
---
Full diff: https://github.com/llvm/llvm-project/pull/141845.diff
5 Files Affected:
- (modified) llvm/lib/Analysis/ConstantFolding.cpp (+29-1)
- (modified) llvm/lib/IR/Constants.cpp (+3-1)
- (modified) llvm/test/Transforms/InstSimplify/ConstProp/abs.ll (+8)
- (modified) llvm/test/Transforms/InstSimplify/ConstProp/fma.ll (+8)
- (modified) llvm/test/Transforms/InstSimplify/exp10.ll (+1-2)
``````````diff
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..40302fbc8ee52 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall(
default:
break;
}
- return nullptr;
+
+ // If trivially vectorizable, try folding it via the scalar call if all
+ // operands are splats.
+
+ // TODO: ConstantFoldFixedVectorCall should probably check this too?
+ if (!isTriviallyVectorizable(IntrinsicID))
+ return nullptr;
+
+ SmallVector<Constant *, 4> SplatOps;
+ for (auto [I, Op] : enumerate(Operands)) {
+ if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) {
+ SplatOps.push_back(Op);
+ continue;
+ }
+ // TODO: Should getSplatValue return a poison scalar for a poison vector?
+ if (isa<PoisonValue>(Op)) {
+ SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType()));
+ continue;
+ }
+ Constant *Splat = Op->getSplatValue();
+ if (!Splat)
+ return nullptr;
+ SplatOps.push_back(Splat);
+ }
+ Constant *Folded = ConstantFoldScalarCall(
+ Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call);
+ if (!Folded)
+ return nullptr;
+ return ConstantVector::getSplat(SVTy->getElementCount(), Folded);
}
static std::pair<Constant *, Constant *>
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index b2087d3651143..fa453309b34ee 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (V->isNullValue())
return ConstantAggregateZero::get(VTy);
- else if (isa<UndefValue>(V))
+ if (isa<PoisonValue>(V))
+ return PoisonValue::get(VTy);
+ if (isa<UndefValue>(V))
return UndefValue::get(VTy);
Type *IdxTy = Type::getInt64Ty(VTy->getContext());
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 37233b0f29342..615ab10248b2a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -43,3 +43,11 @@ define <8 x i8> @vec_const() {
%r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> <i8 -127, i8 -126, i8 -42, i8 -1, i8 0, i8 1, i8 42, i8 127>, i1 1)
ret <8 x i8> %r
}
+
+define <vscale x 1 x i8> @scalable_vec_const() {
+; CHECK-LABEL: @scalable_vec_const(
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 42)
+;
+ %r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
+ ret <vscale x 1 x i8> %r
+}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index d3ade92a6db05..2f56c2df0ca8f 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -16,6 +16,14 @@ define double @PR20832() {
ret double %1
}
+define <vscale x 1 x double> @scalable_vector() {
+; CHECK-LABEL: @scalable_vector(
+; CHECK-NEXT: ret <vscale x 1 x double> splat (double 5.600000e+01)
+;
+ %1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
+ ret <vscale x 1 x double> %1
+}
+
; Test builtin fma with all finite non-zero constants.
define double @test_all_finite() {
; CHECK-LABEL: @test_all_finite(
diff --git a/llvm/test/Transforms/InstSimplify/exp10.ll b/llvm/test/Transforms/InstSimplify/exp10.ll
index a546bb1255d85..c415c419aad84 100644
--- a/llvm/test/Transforms/InstSimplify/exp10.ll
+++ b/llvm/test/Transforms/InstSimplify/exp10.ll
@@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() {
define <vscale x 2 x float> @exp10_zero_scalable_vector() {
; CHECK-LABEL: define <vscale x 2 x float> @exp10_zero_scalable_vector() {
-; CHECK-NEXT: [[RET:%.*]] = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
-; CHECK-NEXT: ret <vscale x 2 x float> [[RET]]
+; CHECK-NEXT: ret <vscale x 2 x float> splat (float 1.000000e+00)
;
%ret = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
ret <vscale x 2 x float> %ret
``````````
</details>
https://github.com/llvm/llvm-project/pull/141845
More information about the llvm-commits mailing list