[llvm] [ConstantFolding] Fold intrinsics of scalable vectors with splatted operands (PR #141845)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 13:14:17 PDT 2025


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/141845

As noted in https://github.com/llvm/llvm-project/pull/141821#issuecomment-2917328924, whilst we currently constant-fold intrinsics of fixed-length vectors via their scalar counterparts, we don't do the same for scalable vectors.

This handles the scalable vector case when the operands are splats.

One weird snag in ConstantVector::getSplat was that it produced an undef vector when passed poison, so this also contains a fix that checks for PoisonValue before UndefValue.

>From 6c7885506c362d60dcba6facfa118be040adb69b Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 28 May 2025 21:09:08 +0100
Subject: [PATCH 1/2] Precommit tests

---
 llvm/test/Transforms/InstSimplify/ConstProp/abs.ll | 9 +++++++++
 llvm/test/Transforms/InstSimplify/ConstProp/fma.ll | 9 +++++++++
 2 files changed, 18 insertions(+)

diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 37233b0f29342..94783adfafac7 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -43,3 +43,12 @@ define <8 x i8> @vec_const() {
   %r = call <8 x i8> @llvm.abs.v8i8(<8 x i8> <i8 -127, i8 -126, i8 -42, i8 -1, i8 0, i8 1, i8 42, i8 127>, i1 1)
   ret <8 x i8> %r
 }
+
+define <vscale x 1 x i8> @scalable_vec_const() {
+; CHECK-LABEL: @scalable_vec_const(
+; CHECK-NEXT:    [[R:%.*]] = call <vscale x 1 x i8> @llvm.abs.nxv1i8(<vscale x 1 x i8> splat (i8 -42), i1 true)
+; CHECK-NEXT:    ret <vscale x 1 x i8> [[R]]
+;
+  %r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
+  ret <vscale x 1 x i8> %r
+}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index d3ade92a6db05..de0e3437e91bc 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -16,6 +16,15 @@ define double @PR20832()  {
   ret double %1
 }
 
+define <vscale x 1 x double> @scalable_vector()  {
+; CHECK-LABEL: @scalable_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 1 x double> @llvm.fma.nxv1f64(<vscale x 1 x double> splat (double 7.000000e+00), <vscale x 1 x double> splat (double 8.000000e+00), <vscale x 1 x double> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 1 x double> [[TMP1]]
+;
+  %1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
+  ret <vscale x 1 x double> %1
+}
+
 ; Test builtin fma with all finite non-zero constants.
 define double @test_all_finite()  {
 ; CHECK-LABEL: @test_all_finite(

>From 5b7d6d594b6470551875a31dd40a8c9cc08adc85 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 28 May 2025 21:10:05 +0100
Subject: [PATCH 2/2] [ConstantFolding] Fold intrinsics of scalable vectors
 with splatted operands

---
 llvm/lib/Analysis/ConstantFolding.cpp         | 30 ++++++++++++++++++-
 llvm/lib/IR/Constants.cpp                     |  4 ++-
 .../Transforms/InstSimplify/ConstProp/abs.ll  |  3 +-
 .../Transforms/InstSimplify/ConstProp/fma.ll  |  3 +-
 llvm/test/Transforms/InstSimplify/exp10.ll    |  3 +-
 5 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..40302fbc8ee52 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3780,7 +3780,35 @@ static Constant *ConstantFoldScalableVectorCall(
   default:
     break;
   }
-  return nullptr;
+
+  // If trivially vectorizable, try folding it via the scalar call if all
+  // operands are splats.
+
+  // TODO: ConstantFoldFixedVectorCall should probably check this too?
+  if (!isTriviallyVectorizable(IntrinsicID))
+    return nullptr;
+
+  SmallVector<Constant *, 4> SplatOps;
+  for (auto [I, Op] : enumerate(Operands)) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, I, /*TTI=*/nullptr)) {
+      SplatOps.push_back(Op);
+      continue;
+    }
+    // TODO: Should getSplatValue return a poison scalar for a poison vector?
+    if (isa<PoisonValue>(Op)) {
+      SplatOps.push_back(PoisonValue::get(Op->getType()->getScalarType()));
+      continue;
+    }
+    Constant *Splat = Op->getSplatValue();
+    if (!Splat)
+      return nullptr;
+    SplatOps.push_back(Splat);
+  }
+  Constant *Folded = ConstantFoldScalarCall(
+      Name, IntrinsicID, SVTy->getElementType(), SplatOps, TLI, Call);
+  if (!Folded)
+    return nullptr;
+  return ConstantVector::getSplat(SVTy->getElementCount(), Folded);
 }
 
 static std::pair<Constant *, Constant *>
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index b2087d3651143..fa453309b34ee 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1507,7 +1507,9 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
 
   if (V->isNullValue())
     return ConstantAggregateZero::get(VTy);
-  else if (isa<UndefValue>(V))
+  if (isa<PoisonValue>(V))
+    return PoisonValue::get(VTy);
+  if (isa<UndefValue>(V))
     return UndefValue::get(VTy);
 
   Type *IdxTy = Type::getInt64Ty(VTy->getContext());
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
index 94783adfafac7..615ab10248b2a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/abs.ll
@@ -46,8 +46,7 @@ define <8 x i8> @vec_const() {
 
 define <vscale x 1 x i8> @scalable_vec_const() {
 ; CHECK-LABEL: @scalable_vec_const(
-; CHECK-NEXT:    [[R:%.*]] = call <vscale x 1 x i8> @llvm.abs.nxv1i8(<vscale x 1 x i8> splat (i8 -42), i1 true)
-; CHECK-NEXT:    ret <vscale x 1 x i8> [[R]]
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 42)
 ;
   %r = call <vscale x 1 x i8> @llvm.abs(<vscale x 1 x i8> splat (i8 -42), i1 1)
   ret <vscale x 1 x i8> %r
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
index de0e3437e91bc..2f56c2df0ca8f 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/fma.ll
@@ -18,8 +18,7 @@ define double @PR20832()  {
 
 define <vscale x 1 x double> @scalable_vector()  {
 ; CHECK-LABEL: @scalable_vector(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 1 x double> @llvm.fma.nxv1f64(<vscale x 1 x double> splat (double 7.000000e+00), <vscale x 1 x double> splat (double 8.000000e+00), <vscale x 1 x double> zeroinitializer)
-; CHECK-NEXT:    ret <vscale x 1 x double> [[TMP1]]
+; CHECK-NEXT:    ret <vscale x 1 x double> splat (double 5.600000e+01)
 ;
   %1 = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> splat (double 7.0), <vscale x 1 x double> splat (double 8.0), <vscale x 1 x double> splat (double 0.0))
   ret <vscale x 1 x double> %1
diff --git a/llvm/test/Transforms/InstSimplify/exp10.ll b/llvm/test/Transforms/InstSimplify/exp10.ll
index a546bb1255d85..c415c419aad84 100644
--- a/llvm/test/Transforms/InstSimplify/exp10.ll
+++ b/llvm/test/Transforms/InstSimplify/exp10.ll
@@ -109,8 +109,7 @@ define <2 x float> @exp10_zero_vector() {
 
 define <vscale x 2 x float> @exp10_zero_scalable_vector() {
 ; CHECK-LABEL: define <vscale x 2 x float> @exp10_zero_scalable_vector() {
-; CHECK-NEXT:    [[RET:%.*]] = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
-; CHECK-NEXT:    ret <vscale x 2 x float> [[RET]]
+; CHECK-NEXT:    ret <vscale x 2 x float> splat (float 1.000000e+00)
 ;
   %ret = call <vscale x 2 x float> @llvm.exp10.nxv2f32(<vscale x 2 x float> zeroinitializer)
   ret <vscale x 2 x float> %ret



More information about the llvm-commits mailing list