[llvm] [ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction (PR #133207)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 27 10:02:48 PDT 2025
https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/133207
>From a383b6108960c9c82b1f9441f7abc188c9b8d40a Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 21:02:47 +0800
Subject: [PATCH 1/2] Precommit tests
---
.../test/Transforms/InstCombine/scalable-trunc.ll | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index dcf4abe10425b..e07f773d9b600 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -20,6 +20,21 @@ entry:
ret void
}
+define <vscale x 1 x i8> @constant_splat_trunc() {
+; CHECK-LABEL: @constant_splat_trunc(
+; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+;
+ %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
+ ret <vscale x 1 x i8> %1
+}
+
+define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
+; CHECK-LABEL: @constant_splat_trunc_constantexpr(
+; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+;
+ ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+}
+
declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
>From 687f0ea1caf53a8af704e390192448f4b476e86c Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 26 Mar 2025 23:46:11 +0800
Subject: [PATCH 2/2] [ConstantFold] Support scalable constant splats in
ConstantFoldCastInstruction
Previously only fixed vector splats were handled. This adds support for scalable vectors too by allowing ConstantExpr splats.
We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.
By allowing ConstantExprs, this also allows fixed vector ConstantExprs to be folded, which causes the diffs in llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll and llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll. I can remove them from this PR if reviewers would prefer.
Fixes #132922
---
llvm/lib/IR/ConstantFold.cpp | 10 ++++++----
.../known-bits-from-operator-constexpr.ll | 2 +-
.../Transforms/InstCombine/scalable-const-fp-splat.ll | 3 +--
llvm/test/Transforms/InstCombine/scalable-trunc.ll | 4 ++--
.../Transforms/InstSimplify/ConstProp/cast-vector.ll | 4 ++--
.../InstSimplify/ConstProp/vscale-inseltpoison.ll | 2 +-
llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll | 2 +-
.../Transforms/InstSimplify/vscale-inseltpoison.ll | 2 +-
llvm/test/Transforms/InstSimplify/vscale.ll | 2 +-
.../LoopVectorize/AArch64/induction-costs-sve.ll | 6 +++---
.../RISCV/truncate-to-minimal-bitwidth-evl-crash.ll | 2 +-
llvm/test/Transforms/VectorCombine/pr88796.ll | 2 +-
12 files changed, 21 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index b577f69eeaba0..692d15546e70e 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
// If the cast operand is a constant vector, perform the cast by
// operating on each element. In the cast of bitcasts, the element
// count may be mismatched; don't attempt to handle that here.
- if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
- DestTy->isVectorTy() &&
- cast<FixedVectorType>(DestTy)->getNumElements() ==
- cast<FixedVectorType>(V->getType())->getNumElements()) {
+ if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
+ DestTy->isVectorTy() && V->getType()->isVectorTy() &&
+ cast<VectorType>(DestTy)->getElementCount() ==
+ cast<VectorType>(V->getType())->getElementCount()) {
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
// Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return ConstantVector::getSplat(
cast<VectorType>(DestTy)->getElementCount(), Res);
}
+ if (isa<ScalableVectorType>(DestTy))
+ return nullptr;
SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0,
diff --git a/llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll b/llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll
index e3e30e052ee58..4dd9106898390 100644
--- a/llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll
+++ b/llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll
@@ -7,7 +7,7 @@
@g = global [21 x i32] zeroinitializer
define i32 @test1(i32 %a) {
; CHECK-LABEL: @test1(
-; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
+; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], ptrtoint (ptr getelementptr inbounds ([21 x i32], ptr @g, i32 0, i32 17) to i32)
; CHECK-NEXT: ret i32 [[T]]
;
%t = sub i32 %a, extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 595486361d16e..0982ecfbd3ea3 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
-; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
;
%2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index e07f773d9b600..6272ccfe9cdbd 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -22,7 +22,7 @@ entry:
define <vscale x 1 x i8> @constant_splat_trunc() {
; CHECK-LABEL: @constant_splat_trunc(
-; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
;
%1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
ret <vscale x 1 x i8> %1
@@ -30,7 +30,7 @@ define <vscale x 1 x i8> @constant_splat_trunc() {
define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
; CHECK-LABEL: @constant_splat_trunc_constantexpr(
-; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
;
ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll b/llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll
index 3e4504a166366..f42f4071ac239 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll
@@ -8,7 +8,7 @@
define <2 x i16> @test1() {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
-; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr inbounds ([10 x i32], ptr null, <2 x i64> zeroinitializer, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
+; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 5) to i16), i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 7) to i16)>
;
entry:
%gep = getelementptr inbounds [10 x i32], ptr null, i16 0, <2 x i16> <i16 5, i16 7>
@@ -23,7 +23,7 @@ entry:
define <2 x i16> @test2() {
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
-; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr (i32, ptr null, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
+; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 5) to i16), i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 7) to i16)>
;
entry:
%gep = getelementptr i32, ptr null, <2 x i16> <i16 5, i16 7>
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll
index a38dfaf8f5819..edc1260eca821 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
-; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
+; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll
index e24f57445a4d1..8ee6fa6e5f37f 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
-; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
+; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll b/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll
index 70ca39da95310..593f334abac1e 100644
--- a/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll
+++ b/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll
@@ -140,7 +140,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
-; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
+; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/InstSimplify/vscale.ll b/llvm/test/Transforms/InstSimplify/vscale.ll
index 47cd88f4d5e4a..c09a0c201d761 100644
--- a/llvm/test/Transforms/InstSimplify/vscale.ll
+++ b/llvm/test/Transforms/InstSimplify/vscale.ll
@@ -152,7 +152,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
-; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
+; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll b/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll
index d7b9d4eba2462..08fea4bfc9b2e 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll
@@ -51,8 +51,8 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
; DEFAULT-NEXT: [[TMP31:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD4]] to <vscale x 8 x i16>
; DEFAULT-NEXT: [[TMP32:%.*]] = or <vscale x 8 x i16> [[TMP28]], [[TMP30]]
; DEFAULT-NEXT: [[TMP33:%.*]] = or <vscale x 8 x i16> [[TMP29]], [[TMP31]]
-; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
-; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
+; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], splat (i16 1)
+; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], splat (i16 1)
; DEFAULT-NEXT: [[TMP36:%.*]] = trunc <vscale x 8 x i16> [[TMP34]] to <vscale x 8 x i8>
; DEFAULT-NEXT: [[TMP37:%.*]] = trunc <vscale x 8 x i16> [[TMP35]] to <vscale x 8 x i8>
; DEFAULT-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
@@ -131,7 +131,7 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
; PRED-NEXT: [[TMP22:%.*]] = mul <vscale x 16 x i16> [[TMP17]], [[TMP16]]
; PRED-NEXT: [[TMP24:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
; PRED-NEXT: [[TMP20:%.*]] = or <vscale x 16 x i16> [[TMP22]], [[TMP24]]
-; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], trunc (<vscale x 16 x i32> splat (i32 1) to <vscale x 16 x i16>)
+; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], splat (i16 1)
; PRED-NEXT: [[TMP23:%.*]] = trunc <vscale x 16 x i16> [[TMP21]] to <vscale x 16 x i8>
; PRED-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
; PRED-NEXT: [[TMP27:%.*]] = getelementptr i8, ptr [[TMP26]], i32 0
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll b/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll
index f884653a485b0..bfdb7eec752c3 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll
@@ -28,7 +28,7 @@ define void @truncate_to_minimal_bitwidths_widen_cast_recipe(ptr %src) {
; CHECK-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr align 1 [[TMP6]], <vscale x 1 x i1> splat (i1 true), i32 [[TMP3]])
; CHECK-NEXT: [[TMP7:%.*]] = zext <vscale x 1 x i8> [[VP_OP_LOAD]] to <vscale x 1 x i16>
; CHECK-NEXT: [[TMP12:%.*]] = mul <vscale x 1 x i16> zeroinitializer, [[TMP7]]
-; CHECK-NEXT: [[VP_OP1:%.*]] = lshr <vscale x 1 x i16> [[TMP12]], trunc (<vscale x 1 x i32> splat (i32 1) to <vscale x 1 x i16>)
+; CHECK-NEXT: [[VP_OP1:%.*]] = lshr <vscale x 1 x i16> [[TMP12]], splat (i16 1)
; CHECK-NEXT: [[TMP8:%.*]] = trunc <vscale x 1 x i16> [[VP_OP1]] to <vscale x 1 x i8>
; CHECK-NEXT: call void @llvm.vp.scatter.nxv1i8.nxv1p0(<vscale x 1 x i8> [[TMP8]], <vscale x 1 x ptr> align 1 zeroinitializer, <vscale x 1 x i1> splat (i1 true), i32 [[TMP3]])
; CHECK-NEXT: [[TMP9:%.*]] = zext i32 [[TMP3]] to i64
diff --git a/llvm/test/Transforms/VectorCombine/pr88796.ll b/llvm/test/Transforms/VectorCombine/pr88796.ll
index 6f988922f2cc0..3ca0786a6e803 100644
--- a/llvm/test/Transforms/VectorCombine/pr88796.ll
+++ b/llvm/test/Transforms/VectorCombine/pr88796.ll
@@ -4,7 +4,7 @@
define i32 @test() {
; CHECK-LABEL: define i32 @test() {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> trunc (<vscale x 8 x i32> splat (i32 268435456) to <vscale x 8 x i16>))
+; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> zeroinitializer)
; CHECK-NEXT: ret i32 0
;
entry:
More information about the llvm-commits
mailing list