[llvm] [VectorCombine] Scalarize extracts of ZExt if profitable. (PR #142976)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 11:03:02 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/142976

>From 44f87e62acc69c7658ec3546d92d6241a693b374 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 5 Jun 2025 15:06:23 +0100
Subject: [PATCH 1/2] [VectorCombine] Scalarize extracts of ZExt if profitable.

Add a new scalarization transform that tries to convert extracts of
a vector ZExt to a set of scalar shift and mask operations. This can be
profitable if the cost of extracting is the same or higher than the cost
of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh,
minizinc and astc-encoder.
---
 .../Transforms/Vectorize/VectorCombine.cpp    |  68 +++++++++++
 .../VectorCombine/AArch64/ext-extract.ll      | 111 ++++++++++++++----
 2 files changed, 155 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 52cb1dbb33b86..85375654fbc19 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -120,6 +120,7 @@ class VectorCombine {
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar (Src
+  // << ExtIdx *Size) & (Size -1), if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  if (DL->getTypeSizeInBits(SrcTy) !=
+      DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  Value *EltBitMask =
+      ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    unsigned Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    auto *S = Builder.CreateLShr(
+        ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits));
+    auto *A = Builder.CreateAnd(S, EltBitMask);
+    U->replaceAllUsesWith(A);
+  }
+  return true;
+}
+
 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
 /// to "(bitcast (concat X, Y))"
 /// where X/Y are bitcasted from i1 mask vectors.
@@ -3576,6 +3643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeExtExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
     }
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 09c03991ad7c3..23538589ae32c 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i32 [[TMP8]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP9]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
 ; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 255
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP4:%.*]] = and i32 [[TMP3]], 255
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
+; CHECK-NEXT:    [[TMP6:%.*]] = and i32 [[TMP5]], 255
+; CHECK-NEXT:    [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = and i32 [[TMP7]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP8]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP6]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP4]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP2]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
 ; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
 ; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 65535
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 65535
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = and i64 [[TMP6]], 65535
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i64 [[TMP8]], 65535
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP9]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP7]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
 ; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
 ; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:

>From 5daa841022d69033e1b93e119372aeba99572fa6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 5 Jun 2025 19:02:26 +0100
Subject: [PATCH 2/2] !fixup address comments, thanks!

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 15 +++++-----
 .../VectorCombine/AArch64/ext-extract.ll      | 29 +++++++++++++++++++
 2 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 85375654fbc19..3f299748b3583 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1718,15 +1718,18 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   // Try to convert a vector zext feeding only extracts to a set of scalar (Src
   // << ExtIdx *Size) & (Size -1), if profitable.
   auto *Ext = cast<ZExtInst>(&I);
-  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  if (!SrcTy)
+    return false;
   auto *DstTy = cast<FixedVectorType>(Ext->getType());
 
-  if (DL->getTypeSizeInBits(SrcTy) !=
-      DL->getTypeSizeInBits(DstTy->getElementType()))
+  Type *ScalarDstTy = DstTy->getElementType();
+  if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
     return false;
 
-  InstructionCost VectorCost = TTI.getCastInstrCost(
-      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  InstructionCost VectorCost =
+      TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
+                           TTI::CastContextHint::None, CostKind, Ext);
   unsigned ExtCnt = 0;
   bool ExtLane0 = false;
   for (User *U : Ext->users()) {
@@ -1741,7 +1744,6 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
                                          CostKind, Idx->getZExtValue(), U);
   }
 
-  Type *ScalarDstTy = DstTy->getElementType();
   InstructionCost ScalarCost =
       ExtCnt * TTI.getArithmeticInstrCost(
                    Instruction::And, ScalarDstTy, CostKind,
@@ -1749,7 +1751,6 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
                    {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
       (ExtCnt - ExtLane0) *
           TTI.getArithmeticInstrCost(
-
               Instruction::LShr, ScalarDstTy, CostKind,
               {TTI::OK_AnyValue, TTI::OP_None},
               {TTI::OK_NonUniformConstantValue, TTI::OP_None});
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 23538589ae32c..97e250542f96d 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -329,3 +329,32 @@ entry:
   call void @use.i64(i64 %ext.1)
   ret void
 }
+
+define void @zext_nv4i8_all_lanes_used(<vscale x 4 x i8> %src) {
+; CHECK-LABEL: define void @zext_nv4i8_all_lanes_used(
+; CHECK-SAME: <vscale x 4 x i8> [[SRC:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <vscale x 4 x i8> [[SRC]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 0
+; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 1
+; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 2
+; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 3
+; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
+; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
+; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %ext9 = zext nneg <vscale x 4 x i8> %src to <vscale x 4 x i32>
+  %ext.0 = extractelement <vscale x 4 x i32> %ext9, i64 0
+  %ext.1 = extractelement <vscale x 4 x i32> %ext9, i64 1
+  %ext.2 = extractelement <vscale x 4 x i32> %ext9, i64 2
+  %ext.3 = extractelement <vscale x 4 x i32> %ext9, i64 3
+
+  call void @use.i32(i32 %ext.0)
+  call void @use.i32(i32 %ext.1)
+  call void @use.i32(i32 %ext.2)
+  call void @use.i32(i32 %ext.3)
+  ret void
+}



More information about the llvm-commits mailing list