[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 5 07:21:46 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-transforms
Author: Florian Hahn (fhahn)
Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable when the cost of extracting an element is the same as or higher than the cost of the two scalar ops; this is the case on AArch64, for example.
This pattern shows up in a number of AArch64 workloads, including av1aom, gmsh, minizinc, and astc-encoder.
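To illustrate the rewrite (a minimal sketch, not taken from the patch; the function and value names are hypothetical, and the lane-to-bit mapping assumes a little-endian target such as AArch64): extracting lane 1 of a zext'd `<4 x i8>` turns into a shift and mask of the bitcast source:
```llvm
; Before: lane 1 is extracted from the vector zext.
define i32 @extract_lane_1(<4 x i8> %src) {
  %ext = zext <4 x i8> %src to <4 x i32>
  %ext.1 = extractelement <4 x i32> %ext, i64 1
  ret i32 %ext.1
}

; After: freeze the source (unless it is known not to be poison), bitcast it
; to a single integer, then shift the lane into position and mask it off.
define i32 @extract_lane_1(<4 x i8> %src) {
  %f = freeze <4 x i8> %src
  %s = bitcast <4 x i8> %f to i32
  %shifted = lshr i32 %s, 8    ; lane index 1 * 8-bit element width
  %val = and i32 %shifted, 255 ; mask to the element's 8 bits
  ret i32 %val
}
```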
---
Full diff: https://github.com/llvm/llvm-project/pull/142976.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+68)
- (modified) llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll (+87-24)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 52cb1dbb33b86..85375654fbc19 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -120,6 +120,7 @@ class VectorCombine {
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  //   (Src >> ExtIdx * SrcEltBits) & ((1 << SrcEltBits) - 1)
+  // operations, if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  if (DL->getTypeSizeInBits(SrcTy) !=
+      DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
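+  // Cost of keeping the vector: the zext itself plus one extract per lane
+  // that is actually used.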
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
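+  // Cost of scalarizing: an And per used lane, plus a LShr for every used
+  // lane other than lane 0 (a shift by zero folds away).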
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
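+  // Bitcasting the source to a single integer would let poison in any lane
+  // leak into every extracted value, so freeze the input unless it is known
+  // not to be poison.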
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  Value *EltBitMask =
+      ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    unsigned Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    auto *S = Builder.CreateLShr(
+        ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits));
+    auto *A = Builder.CreateAnd(S, EltBitMask);
+    U->replaceAllUsesWith(A);
+  }
+  return true;
+}
+
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
/// to "(bitcast (concat X, Y))"
/// where X/Y are bitcasted from i1 mask vectors.
@@ -3576,6 +3643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeExtExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
     }
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 09c03991ad7c3..23538589ae32c 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
+; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP9]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
+; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255
+; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
+; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255
+; CHECK-NEXT: [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
+; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP7]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP8]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP6]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP4]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
+; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 65535
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535
+; CHECK-NEXT: [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP8]], 65535
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP9]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP7]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
``````````
https://github.com/llvm/llvm-project/pull/142976