[llvm] 4acdb8e - [VectorCombine] Scalarize extracts of ZExt if profitable. (#142976)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 3 00:49:38 PDT 2025
Author: Florian Hahn
Date: 2025-07-03T08:49:32+01:00
New Revision: 4acdb8e14ef7e72d5f56b45d51de72c7b797be03
URL: https://github.com/llvm/llvm-project/commit/4acdb8e14ef7e72d5f56b45d51de72c7b797be03
DIFF: https://github.com/llvm/llvm-project/commit/4acdb8e14ef7e72d5f56b45d51de72c7b797be03.diff
LOG: [VectorCombine] Scalarize extracts of ZExt if profitable. (#142976)
Add a new scalarization transform that tries to convert extracts of a
vector ZExt to a set of scalar shift and mask operations. This can be
profitable if the cost of extracting an element is the same as or higher than
the cost of the two scalar ops. This is the case on AArch64, for example,
where the transform shows up in a number of workloads, including av1aom,
gmsh, minizinc, and astc-encoder.
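For example (a sketch with made-up value names), extracting lane 1 of a
<4 x i8> zext to <4 x i32> becomes a shift and mask of the bitcast source
(with an extra freeze when the source is not known to be non-poison):
  %src.i32 = bitcast <4 x i8> %src to i32
  %shift.1 = lshr i32 %src.i32, 8
  %ext.1 = and i32 %shift.1, 255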
PR: https://github.com/llvm/llvm-project/pull/142976
Added:
Modified:
llvm/lib/Transforms/Vectorize/VectorCombine.cpp
llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 136a44a95ac8d..b9ce20ebd3e63 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -123,6 +123,7 @@ class VectorCombine {
bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
+ bool scalarizeExtExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
@@ -1777,6 +1778,73 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+ auto *Ext = dyn_cast<ZExtInst>(&I);
+ if (!Ext)
+ return false;
+
+ // Try to convert a vector zext feeding only extracts to a set of scalar
+ //   (Src >> (ExtIdx * SrcEltBitWidth)) & ((1 << SrcEltBitWidth) - 1)
+ // operations, if profitable.
+ auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
+ if (!SrcTy)
+ return false;
+ auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+ Type *ScalarDstTy = DstTy->getElementType();
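+ // The whole source vector must fit in a single scalar of the destination
+ // element type, so one bitcast plus shifts can reach every lane.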
+ if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
+ return false;
+
+ InstructionCost VectorCost =
+ TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
+ TTI::CastContextHint::None, CostKind, Ext);
+ unsigned ExtCnt = 0;
+ bool ExtLane0 = false;
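+ // All users must be extractelement instructions with constant indices;
+ // extracts without uses are skipped and do not count towards the scalar
+ // cost.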
+ for (User *U : Ext->users()) {
+ const APInt *Idx;
+ if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+ return false;
+ if (cast<Instruction>(U)->use_empty())
+ continue;
+ ExtCnt += 1;
+ ExtLane0 |= Idx->isZero();
+ VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+ CostKind, Idx->getZExtValue(), U);
+ }
+
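+ // Each used extract becomes an And; every lane except lane 0 also needs a
+ // LShr, so lane 0 is not charged for a shift.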
+ InstructionCost ScalarCost =
+ ExtCnt * TTI.getArithmeticInstrCost(
+ Instruction::And, ScalarDstTy, CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None},
+ {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+ (ExtCnt - ExtLane0) *
+ TTI.getArithmeticInstrCost(
+ Instruction::LShr, ScalarDstTy, CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None},
+ {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+ if (ScalarCost > VectorCost)
+ return false;
+
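+ // The scalar version merges all lanes into a single integer, so poison in
+ // any lane would leak into every extracted value; freeze the source unless
+ // it is known not to be poison.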
+ Value *ScalarV = Ext->getOperand(0);
+ if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
+ &DT))
+ ScalarV = Builder.CreateFreeze(ScalarV);
+ ScalarV = Builder.CreateBitCast(
+ ScalarV,
+ IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+ uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+ uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1;
+ for (User *U : Ext->users()) {
+ auto *Extract = cast<ExtractElementInst>(U);
+ uint64_t Idx =
+ cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+ Value *LShr = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
+ Value *And = Builder.CreateAnd(LShr, EltBitMask);
+ U->replaceAllUsesWith(And);
+ }
+ return true;
+}
+
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
/// to "(bitcast (concat X, Y))"
/// where X/Y are bitcasted from i1 mask vectors.
@@ -3665,6 +3733,7 @@ bool VectorCombine::run() {
if (IsVectorType) {
MadeChange |= scalarizeOpOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
+ MadeChange |= scalarizeExtExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
}
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 09c03991ad7c3..60700412686ea 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -9,15 +9,23 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
+; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP1]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP9]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -68,13 +76,20 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -93,13 +108,19 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP1]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -118,11 +139,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -139,11 +166,16 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP1]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
; CHECK-NEXT: ret void
;
entry:
@@ -160,15 +192,22 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
+; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255
+; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
+; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255
+; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP0]], 255
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP8]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP6]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP4]])
+; CHECK-NEXT: call void @use.i32(i32 [[TMP1]])
; CHECK-NEXT: ret void
;
entry:
@@ -221,15 +260,23 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535
+; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535
+; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP1]], 65535
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP9]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP7]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -250,11 +297,15 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP1]], 4294967295
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT: call void @use.i64(i64 [[TMP2]])
; CHECK-NEXT: ret void
;
entry:
@@ -266,3 +317,32 @@ entry:
call void @use.i64(i64 %ext.1)
ret void
}
+
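+; Negative test: the transform requires a fixed-length source vector, so
+; scalable vectors are left unchanged.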
+define void @zext_nxv4i8_all_lanes_used(<vscale x 4 x i8> %src) {
+; CHECK-LABEL: define void @zext_nxv4i8_all_lanes_used(
+; CHECK-SAME: <vscale x 4 x i8> [[SRC:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <vscale x 4 x i8> [[SRC]] to <vscale x 4 x i32>
+; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 0
+; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 1
+; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 2
+; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 3
+; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
+; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
+; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %ext9 = zext nneg <vscale x 4 x i8> %src to <vscale x 4 x i32>
+ %ext.0 = extractelement <vscale x 4 x i32> %ext9, i64 0
+ %ext.1 = extractelement <vscale x 4 x i32> %ext9, i64 1
+ %ext.2 = extractelement <vscale x 4 x i32> %ext9, i64 2
+ %ext.3 = extractelement <vscale x 4 x i32> %ext9, i64 3
+
+ call void @use.i32(i32 %ext.0)
+ call void @use.i32(i32 %ext.1)
+ call void @use.i32(i32 %ext.2)
+ call void @use.i32(i32 %ext.3)
+ ret void
+}