[llvm] [ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction (PR #133207)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 26 22:06:48 PDT 2025


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/133207

Stacked on #132960 to prevent a regression

Previously only fixed vector splats were handled. This adds support for scalable vectors too, by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.

I believe this will also allow casts of fixed vector ConstantExprs to be folded, but I couldn't come up with a test case for this: the ConstantExprs seem to be folded away before reaching InstCombine.

Fixes #132922




>From 4b9c2437f5b666cc1962c3957a3884ce2246b6cc Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 01:13:34 +0800
Subject: [PATCH 1/6] Precommit tests

---
 .../InstCombine/scalable-const-fp-splat.ll         | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 731b079881f08..3648a868fe1cb 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -13,3 +13,17 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
   %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
   ret <vscale x 2 x float> %5
 }
+
+define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
+; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
+; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fpext <vscale x 2 x float> [[A]] to <vscale x 2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <vscale x 2 x double> [[TMP1]], splat (double -1.000000e+00)
+; CHECK-NEXT:    [[TMP3:%.*]] = fptrunc <vscale x 2 x double> [[TMP2]] to <vscale x 2 x float>
+; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP3]]
+;
+  %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %4 = fadd <vscale x 2 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
+  ret <vscale x 2 x float> %5
+}

>From 01be82ce028b51e54f247379459d3d994165699e Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 01:16:49 +0800
Subject: [PATCH 2/6] [InstCombine] Handle scalable splats of constants in
 getMinimumFPType

We previously handled ConstantExpr scalable splats in 5d929794a87602cfd873381e11cc99149196bb49, but only fpexts.

ConstantExpr fpexts have since been removed, and simultaneously we didn't handle splats of constants that weren't extended.

This updates it to remove the fpext check and instead see if we can shrink the result of getSplatValue.

Note that the test case doesn't get completely folded away due to #132922
---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp  | 11 ++++++-----
 .../Transforms/InstCombine/scalable-const-fp-splat.ll |  5 ++---
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4ec1af394464b..3faaf1e52db26 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1685,11 +1685,12 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
       return T;
 
   // We can only correctly find a minimum type for a scalable vector when it is
-  // a splat. For splats of constant values the fpext is wrapped up as a
-  // ConstantExpr.
-  if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
-    if (FPCExt->getOpcode() == Instruction::FPExt)
-      return FPCExt->getOperand(0)->getType();
+  // a splat.
+  if (auto *FPCE = dyn_cast<ConstantExpr>(V))
+    if (isa<ScalableVectorType>(V->getType()))
+      if (auto *Splat = dyn_cast<ConstantFP>(FPCE->getSplatValue()))
+        if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
+          return T;
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 3648a868fe1cb..595486361d16e 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -17,9 +17,8 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
 define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
 ; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
 ; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = fpext <vscale x 2 x float> [[A]] to <vscale x 2 x double>
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <vscale x 2 x double> [[TMP1]], splat (double -1.000000e+00)
-; CHECK-NEXT:    [[TMP3:%.*]] = fptrunc <vscale x 2 x double> [[TMP2]] to <vscale x 2 x float>
+; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
 ; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP3]]
 ;
   %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>

>From d26e6721df2ce5d6aa5f581c7a230ac4d4e45719 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 27 Mar 2025 06:41:12 +0200
Subject: [PATCH 3/6] Add fixed vector test

---
 llvm/test/Transforms/InstCombine/fpextend.ll | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db4..9125339c00ecf 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,14 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define <4 x float> @v4f32_fadd(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd(
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %2 = fpext <4 x float> %a to <4 x double>
+  %4 = fadd <4 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <4 x double> %4 to <4 x float>
+  ret <4 x float> %5
+}

>From 8c7f8c69782f4a2fc416d07fc54e8a546f065fba Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 27 Mar 2025 06:41:52 +0200
Subject: [PATCH 4/6] Also handle fixed-length splats, as a sort of fast-path.
 Should be NFC?

---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 3faaf1e52db26..1a95636f37ed7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1684,11 +1684,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
     if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
       return T;
 
-  // We can only correctly find a minimum type for a scalable vector when it is
-  // a splat.
-  if (auto *FPCE = dyn_cast<ConstantExpr>(V))
-    if (isa<ScalableVectorType>(V->getType()))
-      if (auto *Splat = dyn_cast<ConstantFP>(FPCE->getSplatValue()))
+  // Try to shrink scalable and fixed splat vectors.
+  if (auto *FPC = dyn_cast<Constant>(V))
+    if (isa<VectorType>(V->getType()))
+      if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
         if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
           return T;
 

>From 3758a218e8bafd305440ad77c102eaecaab77933 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 21:02:47 +0800
Subject: [PATCH 5/6] Precommit tests

---
 .../test/Transforms/InstCombine/scalable-trunc.ll | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index dcf4abe10425b..e07f773d9b600 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -20,6 +20,21 @@ entry:
   ret void
 }
 
+define <vscale x 1 x i8> @constant_splat_trunc() {
+; CHECK-LABEL: @constant_splat_trunc(
+; CHECK-NEXT:    ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+;
+  %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %1
+}
+
+define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
+; CHECK-LABEL: @constant_splat_trunc_constantexpr(
+; CHECK-NEXT:    ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+;
+  ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+}
+
 declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
 declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)

>From 41f958688e620bd6b508c10fbcfd9232012e9142 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 23:46:11 +0800
Subject: [PATCH 6/6] [ConstantFold] Support scalable constant splats in
 ConstantFoldCastInstruction

Stacked on #132960 to prevent a regression

Previously only fixed vector splats were handled. This adds support for scalable vectors too, by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.

I believe this will also allow casts of fixed vector ConstantExprs to be folded, but I couldn't come up with a test case for this: the ConstantExprs seem to be folded away before reaching InstCombine.

Fixes #132922
---
 llvm/lib/IR/ConstantFold.cpp                           | 10 ++++++----
 .../Transforms/InstCombine/scalable-const-fp-splat.ll  |  3 +--
 llvm/test/Transforms/InstCombine/scalable-trunc.ll     |  4 ++--
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index b577f69eeaba0..692d15546e70e 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   // If the cast operand is a constant vector, perform the cast by
   // operating on each element. In the cast of bitcasts, the element
   // count may be mismatched; don't attempt to handle that here.
-  if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
-      DestTy->isVectorTy() &&
-      cast<FixedVectorType>(DestTy)->getNumElements() ==
-          cast<FixedVectorType>(V->getType())->getNumElements()) {
+  if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
+      DestTy->isVectorTy() && V->getType()->isVectorTy() &&
+      cast<VectorType>(DestTy)->getElementCount() ==
+          cast<VectorType>(V->getType())->getElementCount()) {
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
     // Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       return ConstantVector::getSplat(
           cast<VectorType>(DestTy)->getElementCount(), Res);
     }
+    if (isa<ScalableVectorType>(DestTy))
+      return nullptr;
     SmallVector<Constant *, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
     for (unsigned i = 0,
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 595486361d16e..0982ecfbd3ea3 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
 define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
 ; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
 ; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
 ; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP3]]
 ;
   %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index e07f773d9b600..6272ccfe9cdbd 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -22,7 +22,7 @@ entry:
 
 define <vscale x 1 x i8> @constant_splat_trunc() {
 ; CHECK-LABEL: @constant_splat_trunc(
-; CHECK-NEXT:    ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
 ;
   %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
   ret <vscale x 1 x i8> %1
@@ -30,7 +30,7 @@ define <vscale x 1 x i8> @constant_splat_trunc() {
 
 define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
 ; CHECK-LABEL: @constant_splat_trunc_constantexpr(
-; CHECK-NEXT:    ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
 ;
   ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
 }



More information about the llvm-commits mailing list