[llvm] [ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction (PR #133207)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 26 22:07:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-ir
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
Stacked on #<!-- -->132960 to prevent a regression
Previously, only fixed vector splats were handled. This adds support for scalable vectors too by allowing ConstantExpr splats.
We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar-to-vector bitcast.
I believe this will also allow casts of fixed vector ConstantExprs to be folded, but I couldn't come up with a test case for this; the ConstantExprs seem to be folded away before reaching InstCombine.
Fixes #<!-- -->132922
---
Full diff: https://github.com/llvm/llvm-project/pull/133207.diff
5 Files Affected:
- (modified) llvm/lib/IR/ConstantFold.cpp (+6-4)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+6-6)
- (modified) llvm/test/Transforms/InstCombine/fpextend.ll (+11)
- (modified) llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll (+12)
- (modified) llvm/test/Transforms/InstCombine/scalable-trunc.ll (+15)
``````````diff
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index b577f69eeaba0..692d15546e70e 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
// If the cast operand is a constant vector, perform the cast by
// operating on each element. In the cast of bitcasts, the element
// count may be mismatched; don't attempt to handle that here.
- if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
- DestTy->isVectorTy() &&
- cast<FixedVectorType>(DestTy)->getNumElements() ==
- cast<FixedVectorType>(V->getType())->getNumElements()) {
+ if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
+ DestTy->isVectorTy() && V->getType()->isVectorTy() &&
+ cast<VectorType>(DestTy)->getElementCount() ==
+ cast<VectorType>(V->getType())->getElementCount()) {
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
// Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return ConstantVector::getSplat(
cast<VectorType>(DestTy)->getElementCount(), Res);
}
+ if (isa<ScalableVectorType>(DestTy))
+ return nullptr;
SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4ec1af394464b..1a95636f37ed7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1684,12 +1684,12 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
return T;
- // We can only correctly find a minimum type for a scalable vector when it is
- // a splat. For splats of constant values the fpext is wrapped up as a
- // ConstantExpr.
- if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
- if (FPCExt->getOpcode() == Instruction::FPExt)
- return FPCExt->getOperand(0)->getType();
+ // Try to shrink scalable and fixed splat vectors.
+ if (auto *FPC = dyn_cast<Constant>(V))
+ if (isa<VectorType>(V->getType()))
+ if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
+ if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
+ return T;
// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db4..9125339c00ecf 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,14 @@ define bfloat @bf16_frem(bfloat %x) {
%t3 = fptrunc float %t2 to bfloat
ret bfloat %t3
}
+
+define <4 x float> @v4f32_fadd(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd(
+; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
+; CHECK-NEXT: ret <4 x float> [[TMP1]]
+;
+ %2 = fpext <4 x float> %a to <4 x double>
+ %4 = fadd <4 x double> %2, splat (double -1.000000e+00)
+ %5 = fptrunc <4 x double> %4 to <4 x float>
+ ret <4 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 731b079881f08..0982ecfbd3ea3 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -13,3 +13,15 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
%5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
ret <vscale x 2 x float> %5
}
+
+define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
+; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
+; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
+; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
+; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
+;
+ %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+ %4 = fadd <vscale x 2 x double> %2, splat (double -1.000000e+00)
+ %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
+ ret <vscale x 2 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index dcf4abe10425b..6272ccfe9cdbd 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -20,6 +20,21 @@ entry:
ret void
}
+define <vscale x 1 x i8> @constant_splat_trunc() {
+; CHECK-LABEL: @constant_splat_trunc(
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
+;
+ %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
+ ret <vscale x 1 x i8> %1
+}
+
+define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
+; CHECK-LABEL: @constant_splat_trunc_constantexpr(
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
+;
+ ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+}
+
declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
``````````
</details>
https://github.com/llvm/llvm-project/pull/133207
More information about the llvm-commits
mailing list