[llvm] [InstCombine] Fold bitcast (extelt (bitcast X), Idx) into bitcast+shufflevector. (PR #136998)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 08:04:22 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Ricardo Jesus (rj-jesus)
<details>
<summary>Changes</summary>
Fold sequences such as:
```llvm
%bc = bitcast <8 x i32> %v to <2 x i128>
%ext = extractelement <2 x i128> %bc, i64 0
%res = bitcast i128 %ext to <2 x i64>
```
Into:
```llvm
%bc = bitcast <8 x i32> %v to <4 x i64>
%res = shufflevector <4 x i64> %bc, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
```
The motivation for this is a long-standing regression affecting SIMDe on
AArch64 introduced indirectly by the AlwaysInliner (1a2e77c). Some
reproducers:
* https://godbolt.org/z/53qx18s6M
* https://godbolt.org/z/o5e43h5M7
This is an alternative approach to #135769 to fix the regressions above.
---
Full diff: https://github.com/llvm/llvm-project/pull/136998.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+48)
- (modified) llvm/test/Transforms/InstCombine/bitcast.ll (+28)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 1a95636f37ed7..d656dcc21ae1e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -2380,6 +2380,51 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
return Result;
}
+/// If the input is (extractelement (bitcast X), Idx) and the source and
+/// destination types are vectors, we are performing a vector extract from X. We
+/// can replace the extractelement+bitcast with a shufflevector, avoiding the
+/// final scalar->vector bitcast. This pattern is usually handled better by the
+/// backend.
+///
+/// Example:
+/// %bc = bitcast <8 x i32> %X to <2 x i128>
+/// %ext = extractelement <2 x i128> %bc, i64 1
+/// bitcast i128 %ext to <2 x i64>
+/// --->
+/// %bc = bitcast <8 x i32> %X to <4 x i64>
+/// shufflevector <4 x i64> %bc, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+static Instruction *foldBitCastExtElt(BitCastInst &BitCast,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X;
+ uint64_t Index;
+ if (!match(
+ BitCast.getOperand(0),
+ m_OneUse(m_ExtractElt(m_BitCast(m_Value(X)), m_ConstantInt(Index)))))
+ return nullptr;
+
+ auto *SrcTy = dyn_cast<FixedVectorType>(X->getType());
+ auto *DstTy = dyn_cast<FixedVectorType>(BitCast.getType());
+ if (!SrcTy || !DstTy)
+ return nullptr;
+
+ // Check if the mask indices would overflow.
+ unsigned NumElts = DstTy->getNumElements();
+ if (Index > INT32_MAX || NumElts > INT32_MAX ||
+ (Index + 1) * NumElts - 1 > INT32_MAX)
+ return nullptr;
+
+ unsigned DstEltWidth = DstTy->getScalarSizeInBits();
+ unsigned SrcVecWidth = SrcTy->getPrimitiveSizeInBits();
+ assert((SrcVecWidth % DstEltWidth == 0) && "Invalid types.");
+ auto *NewVecTy =
+ FixedVectorType::get(DstTy->getElementType(), SrcVecWidth / DstEltWidth);
+ auto *NewBC = Builder.CreateBitCast(X, NewVecTy, "bc");
+
+ unsigned StartIdx = Index * NumElts;
+ auto Mask = llvm::to_vector<16>(llvm::seq<int>(StartIdx, StartIdx + NumElts));
+ return new ShuffleVectorInst(NewBC, Mask);
+}
+
/// Canonicalize scalar bitcasts of extracted elements into a bitcast of the
/// vector followed by extract element. The backend tends to handle bitcasts of
/// vectors better than bitcasts of scalars because vector registers are
@@ -2866,6 +2911,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
if (Instruction *I = canonicalizeBitCastExtElt(CI, *this))
return I;
+ if (Instruction *I = foldBitCastExtElt(CI, Builder))
+ return I;
+
if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder))
return I;
diff --git a/llvm/test/Transforms/InstCombine/bitcast.ll b/llvm/test/Transforms/InstCombine/bitcast.ll
index 37d41de3e9991..cade44412341d 100644
--- a/llvm/test/Transforms/InstCombine/bitcast.ll
+++ b/llvm/test/Transforms/InstCombine/bitcast.ll
@@ -480,6 +480,34 @@ define double @bitcast_extelt8(<1 x i64> %A) {
ret double %bc
}
+; Extract a subvector from a vector, extracted element wider than source.
+
+define <2 x i64> @bitcast_extelt9(<8 x i32> %A) {
+; CHECK-LABEL: @bitcast_extelt9(
+; CHECK-NEXT: [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <4 x i64>
+; CHECK-NEXT: [[BC2:%.*]] = shufflevector <4 x i64> [[BC]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: ret <2 x i64> [[BC2]]
+;
+ %bc1 = bitcast <8 x i32> %A to <2 x i128>
+ %ext = extractelement <2 x i128> %bc1, i64 1
+ %bc2 = bitcast i128 %ext to <2 x i64>
+ ret <2 x i64> %bc2
+}
+
+; Extract a subvector from a vector, extracted element narrower than source.
+
+define <2 x i8> @bitcast_extelt10(<8 x i32> %A) {
+; CHECK-LABEL: @bitcast_extelt10(
+; CHECK-NEXT: [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <32 x i8>
+; CHECK-NEXT: [[BC2:%.*]] = shufflevector <32 x i8> [[BC]], <32 x i8> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT: ret <2 x i8> [[BC2]]
+;
+ %bc1 = bitcast <8 x i32> %A to <16 x i16>
+ %ext = extractelement <16 x i16> %bc1, i64 3
+ %bc2 = bitcast i16 %ext to <2 x i8>
+ ret <2 x i8> %bc2
+}
+
define <2 x i32> @test4(i32 %A, i32 %B){
; CHECK-LABEL: @test4(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[A:%.*]], i64 0
``````````
</details>
https://github.com/llvm/llvm-project/pull/136998
More information about the llvm-commits
mailing list