[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 07:21:09 PDT 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/142976

Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable if the cost of extracting is the same or higher than the cost of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh, minizinc and astc-encoder.

>From 681cd152ab9e4478602c60c04d7610d726d37b63 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 5 Jun 2025 15:06:23 +0100
Subject: [PATCH] [VectorCombine] Scalarize extracts of ZExt if profitable.

Add a new scalarization transform that tries to convert extracts of
a vector ZExt to a set of scalar shift and mask operations. This can be
profitable if the cost of extracting is the same or higher than the cost
of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh,
minizinc and astc-encoder.
---
 .../Transforms/Vectorize/VectorCombine.cpp    |  68 +++++++++++
 .../VectorCombine/AArch64/ext-extract.ll      | 111 ++++++++++++++----
 2 files changed, 155 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 52cb1dbb33b86..85375654fbc19 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -120,6 +120,7 @@ class VectorCombine {
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to replace extracts of a fixed-width vector zext with shift-and-mask
+/// operations on the source vector bitcast to a single integer.
+///
+/// \param I candidate instruction; anything but a vector zext is rejected.
+/// \returns true if the extract users of \p I were rewritten.
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  auto *Ext = dyn_cast<ZExtInst>(&I);
+  if (!Ext)
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  // (Src >> (ExtIdx * Size)) & ((1 << Size) - 1), if profitable. Scalable
+  // vectors cannot be packed into one integer, so only fixed ones qualify.
+  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = dyn_cast<FixedVectorType>(Ext->getType());
+  if (!SrcTy || !DstTy)
+    return false;
+
+  // The whole source vector must be exactly as wide as one widened element,
+  // so that the bitcast integer has the type the extracts produce.
+  uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
+  if (TotalBits != DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
+  // The lane that ends up in the low bits of the bitcast integer needs only a
+  // shift by zero: lane 0 on little-endian, the last lane on big-endian.
+  uint64_t ZeroShiftLane = DL->isBigEndian() ? SrcTy->getNumElements() - 1 : 0;
+
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool HasZeroShift = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    // Give up on non-extract users and on out-of-range (poison) lanes.
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))) ||
+        Idx->uge(SrcTy->getNumElements()))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    HasZeroShift |= *Idx == ZeroShiftLane;
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+  if (ExtCnt == 0) // No live extract: rewriting would only add dead code.
+    return false;
+
+  // One mask per live extract plus one shift each; a shift by zero is
+  // counted as free (at most once).
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - HasZeroShift) *
+          TTI.getArithmeticInstrCost(
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  // Freeze first: after the bitcast a single poison lane would otherwise
+  // poison the whole integer and with it every extracted value.
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  Type *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);
+  ScalarV = Builder.CreateBitCast(ScalarV, PackedTy);
+  uint64_t EltBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  // Use APInt for the mask: `1ull << EltBits` is undefined for 64-bit lanes
+  // and cannot represent masks wider than 64 bits.
+  Value *EltBitMask =
+      ConstantInt::get(PackedTy, APInt::getLowBitsSet(TotalBits, EltBits));
+  // Iterate over a snapshot of the users; every user was verified above to
+  // be an extract with an in-range constant index.
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    uint64_t Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    // Lane 0 is in the low bits on little-endian, the high bits otherwise.
+    uint64_t Shift = DL->isBigEndian() ? TotalBits - (Idx + 1) * EltBits
+                                       : Idx * EltBits;
+    auto *S = Builder.CreateLShr(ScalarV, ConstantInt::get(PackedTy, Shift));
+    auto *A = Builder.CreateAnd(S, EltBitMask);
+    U->replaceAllUsesWith(A);
+  }
+  return true;
+}
+
 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
 /// to "(bitcast (concat X, Y))"
 /// where X/Y are bitcasted from i1 mask vectors.
@@ -3576,6 +3643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeExtExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
     }
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 09c03991ad7c3..23538589ae32c 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i32 [[TMP8]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP9]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
 ; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 255
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP4:%.*]] = and i32 [[TMP3]], 255
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
+; CHECK-NEXT:    [[TMP6:%.*]] = and i32 [[TMP5]], 255
+; CHECK-NEXT:    [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = and i32 [[TMP7]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP8]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP6]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP4]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP2]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
 ; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
 ; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 65535
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 65535
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = and i64 [[TMP6]], 65535
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i64 [[TMP8]], 65535
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP9]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP7]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
 ; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
 ; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:



More information about the llvm-commits mailing list