[llvm] [ConstantFolding] Fold intrinsics of scalable vectors with splatted operands (PR #141845)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 13:14:17 PDT 2025
https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/141845
As noted in https://github.com/llvm/llvm-project/pull/141821#issuecomment-2917328924, whilst we currently constant fold intrinsics of fixed-length vectors via their scalar counterpart, we don't do the same for scalable vectors.
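This patch handles the scalable-vector case when every vector operand is a splat (or poison). In that case the scalar intrinsic is folded once and the result is splatted back to the original element count. Operands that the intrinsic takes as scalars, such as the i1 flag of llvm.abs, are passed through unchanged.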
One snag: ConstantVector::getSplat returned undef when passed poison. PoisonValue is a subclass of UndefValue, so it was caught by the isa<UndefValue> check. This patch also fixes that by checking for PoisonValue before UndefValue.
>From 6c7885506c362d60dcba6facfa118be040adb69b Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 28 May 2025 21:09:08 +0100
Subject: [PATCH 1/2] Precommit tests
---
llvm/test/Transforms/InstSimplify/ConstProp/abs.ll | 9 +++++++++
llvm/test/Transforms/InstSimplify/ConstProp/fma.ll | 9 +++++++++
2 files changed, 18 insertions(+)
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 37233b0f29342..94783adfafac7 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -43,3 +43,12 @@ define <8 x i8> @vec_const() {
%r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> <i8 -127, i8 -126, i8 -42, i8 -1, i8 0, i8 1, i8 42, i8 127>, i1 1)
ret <8 x i8> %r
}
+
+define <vscale x 1 x i8> @scalable_vec_const() {
+; CHECK-LABEL: @scalable_vec_const(
+; CHECK-NEXT: [[R:%.*]] = call <vscale x 1 x i8> @llvm.abs.nxv1i8(<vscale x 1 x i8> splat (i8 -42), i1 true)
+; CHECK-NEXT: ret <vscale x 1 x i8> [[R]]
+;
+ %r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
+ ret <vscale x 1 x i8> %r
+}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index d3ade92a6db05..de0e3437e91bc 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -16,6 +16,15 @@ define double @PR20832() {
ret double %1
}
+define <vscale x 1 x double> @scalable_vector() {
+; CHECK-LABEL: @scalable_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 1 x double> @llvm.fma.nxv1f64(<vscale x 1 x double> splat (double 7.000000e+00), <vscale x 1 x double> splat (double 8.000000e+00), <vscale x 1 x double> zeroinitializer)
+; CHECK-NEXT: ret <vscale x 1 x double> [[TMP1]]
+;
+ %1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
+ ret <vscale x 1 x double> %1
+}
+
; Test builtin fma with all finite non-zero constants.
define double @test_all_finite() {
; CHECK-LABEL: @test_all_finite(
>From 5b7d6d594b6470551875a31dd40a8c9cc08adc85 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 28 May 2025 21:10:05 +0100
Subject: [PATCH 2/2] [ConstantFolding] Fold intrinsics of scalable vectors
with splatted operands
---
llvm/lib/Analysis/ConstantFolding.cpp | 30 ++++++++++++++++++-
llvm/lib/IR/Constants.cpp | 4 ++-
.../Transforms/InstSimplify/ConstProp/abs.ll | 3 +-
.../Transforms/InstSimplify/ConstProp/fma.ll | 3 +-
llvm/test/Transforms/InstSimplify/exp10.ll | 3 +-
5 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..40302fbc8ee52 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall(
default:
break;
}
- return nullptr;
+
+ // If trivially vectorizable, try folding it via the scalar call if all
+ // operands are splats.
+
+ // TODO: ConstantFoldFixedVectorCall should probably check this too?
+ if (!isTriviallyVectorizable(IntrinsicID))
+ return nullptr;
+
+ SmallVector<Constant *, 4> SplatOps;
+ for (auto [I, Op] : enumerate(Operands)) {
+ if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) {
+ SplatOps.push_back(Op);
+ continue;
+ }
+ // TODO: Should getSplatValue return a poison scalar for a poison vector?
+ if (isa<PoisonValue>(Op)) {
+ SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType()));
+ continue;
+ }
+ Constant *Splat = Op->getSplatValue();
+ if (!Splat)
+ return nullptr;
+ SplatOps.push_back(Splat);
+ }
+ Constant *Folded = ConstantFoldScalarCall(
+ Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call);
+ if (!Folded)
+ return nullptr;
+ return ConstantVector::getSplat(SVTy->getElementCount(), Folded);
}
static std::pair<Constant *, Constant *>
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index b2087d3651143..fa453309b34ee 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (V->isNullValue())
return ConstantAggregateZero::get(VTy);
- else if (isa<UndefValue>(V))
+ if (isa<PoisonValue>(V))
+ return PoisonValue::get(VTy);
+ if (isa<UndefValue>(V))
return UndefValue::get(VTy);
Type *IdxTy = Type::getInt64Ty(VTy->getContext());
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 94783adfafac7..615ab10248b2a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -46,8 +46,7 @@ define <8 x i8> @vec_const() {
define <vscale x 1 x i8> @scalable_vec_const() {
; CHECK-LABEL: @scalable_vec_const(
-; CHECK-NEXT: [[R:%.*]] = call <vscale x 1 x i8> @llvm.abs.nxv1i8(<vscale x 1 x i8> splat (i8 -42), i1 true)
-; CHECK-NEXT: ret <vscale x 1 x i8> [[R]]
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 42)
;
%r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
ret <vscale x 1 x i8> %r
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index de0e3437e91bc..2f56c2df0ca8f 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -18,8 +18,7 @@ define double @PR20832() {
define <vscale x 1 x double> @scalable_vector() {
; CHECK-LABEL: @scalable_vector(
-; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 1 x double> @llvm.fma.nxv1f64(<vscale x 1 x double> splat (double 7.000000e+00), <vscale x 1 x double> splat (double 8.000000e+00), <vscale x 1 x double> zeroinitializer)
-; CHECK-NEXT: ret <vscale x 1 x double> [[TMP1]]
+; CHECK-NEXT: ret <vscale x 1 x double> splat (double 5.600000e+01)
;
%1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
ret <vscale x 1 x double> %1
diff --git a/llvm/test/Transforms/InstSimplify/exp10.ll b/llvm/test/Transforms/InstSimplify/exp10.ll
index a546bb1255d85..c415c419aad84 100644
--- a/llvm/test/Transforms/InstSimplify/exp10.ll
+++ b/llvm/test/Transforms/InstSimplify/exp10.ll
@@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() {
define <vscale x 2 x float> @exp10_zero_scalable_vector() {
; CHECK-LABEL: define <vscale x 2 x float> @exp10_zero_scalable_vector() {
-; CHECK-NEXT: [[RET:%.*]] = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
-; CHECK-NEXT: ret <vscale x 2 x float> [[RET]]
+; CHECK-NEXT: ret <vscale x 2 x float> splat (float 1.000000e+00)
;
%ret = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
ret <vscale x 2 x float> %ret
More information about the llvm-commits
mailing list