[llvm] [ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction (PR #133207)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 26 22:07:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

Stacked on #132960 to prevent a regression.

Previously, only fixed-vector splats were handled. This adds support for scalable vectors too, by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar-to-vector bitcast.

I believe this will also allow casts of fixed-vector ConstantExprs to be folded, but I couldn't come up with a test case for this; the ConstantExprs seem to be folded away before reaching InstCombine.

Fixes #132922




---
Full diff: https://github.com/llvm/llvm-project/pull/133207.diff


5 Files Affected:

- (modified) llvm/lib/IR/ConstantFold.cpp (+6-4) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+6-6) 
- (modified) llvm/test/Transforms/InstCombine/fpextend.ll (+11) 
- (modified) llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll (+12) 
- (modified) llvm/test/Transforms/InstCombine/scalable-trunc.ll (+15) 


``````````diff
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index b577f69eeaba0..692d15546e70e 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   // If the cast operand is a constant vector, perform the cast by
   // operating on each element. In the cast of bitcasts, the element
   // count may be mismatched; don't attempt to handle that here.
-  if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
-      DestTy->isVectorTy() &&
-      cast<FixedVectorType>(DestTy)->getNumElements() ==
-          cast<FixedVectorType>(V->getType())->getNumElements()) {
+  if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
+      DestTy->isVectorTy() && V->getType()->isVectorTy() &&
+      cast<VectorType>(DestTy)->getElementCount() ==
+          cast<VectorType>(V->getType())->getElementCount()) {
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
     // Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       return ConstantVector::getSplat(
           cast<VectorType>(DestTy)->getElementCount(), Res);
     }
+    if (isa<ScalableVectorType>(DestTy))
+      return nullptr;
     SmallVector<Constant *, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
     for (unsigned i = 0,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4ec1af394464b..1a95636f37ed7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1684,12 +1684,12 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
     if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
       return T;
 
-  // We can only correctly find a minimum type for a scalable vector when it is
-  // a splat. For splats of constant values the fpext is wrapped up as a
-  // ConstantExpr.
-  if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
-    if (FPCExt->getOpcode() == Instruction::FPExt)
-      return FPCExt->getOperand(0)->getType();
+  // Try to shrink scalable and fixed splat vectors.
+  if (auto *FPC = dyn_cast<Constant>(V))
+    if (isa<VectorType>(V->getType()))
+      if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
+        if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
+          return T;
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db4..9125339c00ecf 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,14 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define <4 x float> @v4f32_fadd(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd(
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %2 = fpext <4 x float> %a to <4 x double>
+  %4 = fadd <4 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <4 x double> %4 to <4 x float>
+  ret <4 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 731b079881f08..0982ecfbd3ea3 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -13,3 +13,15 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
   %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
   ret <vscale x 2 x float> %5
 }
+
+define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
+; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
+; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
+; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP3]]
+;
+  %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %4 = fadd <vscale x 2 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
+  ret <vscale x 2 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index dcf4abe10425b..6272ccfe9cdbd 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -20,6 +20,21 @@ entry:
   ret void
 }
 
+define <vscale x 1 x i8> @constant_splat_trunc() {
+; CHECK-LABEL: @constant_splat_trunc(
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
+;
+  %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %1
+}
+
+define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
+; CHECK-LABEL: @constant_splat_trunc_constantexpr(
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
+;
+  ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+}
+
 declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
 declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)

``````````

</details>


https://github.com/llvm/llvm-project/pull/133207


More information about the llvm-commits mailing list