[llvm] b0473c5 - [InstCombine] Pull extract through broadcast (#143380)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 4 09:19:55 PDT 2025


Author: agorenstein-nvidia
Date: 2025-07-04T18:19:50+02:00
New Revision: b0473c599b0418c71d15150e0ea19d57df3b98e5

URL: https://github.com/llvm/llvm-project/commit/b0473c599b0418c71d15150e0ea19d57df3b98e5
DIFF: https://github.com/llvm/llvm-project/commit/b0473c599b0418c71d15150e0ea19d57df3b98e5.diff

LOG: [InstCombine] Pull extract through broadcast (#143380)

This change adds a new InstCombine pattern, and associated tests, that pulls
an extractelement through a splat (broadcast) shufflevector, for IR like this:

```
  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> zeroinitializer
  %4 = extractelement <4 x float> %3, i64 %idx
```

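The shufflevector has a splat (broadcast) mask, so every in-bounds lane of %3
holds the first element of %1. Whatever the value of %idx, the extractelement
must therefore yield that element. An out-of-bounds %idx produces poison, which
may be refined to any value. We therefore transform this to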

```
  %2 = extractelement <2 x float> %1, i64 0
```

Added: 
    llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
    llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
    llvm/test/Transforms/InstCombine/vec_shuffle.ll
    llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
    llvm/test/Transforms/InstCombine/vscale_extractelement.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index a746a5229fb9a..1a8d4215b5b71 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -542,27 +542,39 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
         }
       }
     } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
-      // If this is extracting an element from a shufflevector, figure out where
-      // it came from and extract from the appropriate input element instead.
-      // Restrict the following transformation to fixed-length vector.
-      if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
-        int SrcIdx =
-            SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
-        Value *Src;
-        unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
-                                ->getNumElements();
-
-        if (SrcIdx < 0)
-          return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
-        if (SrcIdx < (int)LHSWidth)
-          Src = SVI->getOperand(0);
-        else {
-          SrcIdx -= LHSWidth;
-          Src = SVI->getOperand(1);
+      int SplatIndex = getSplatIndex(SVI->getShuffleMask());
+      // We know the all-0 splat must be reading from the first operand, even
+      // in the case of scalable vectors (vscale is always > 0).
+      if (SplatIndex == 0)
+        return ExtractElementInst::Create(SVI->getOperand(0),
+                                          Builder.getInt64(0));
+
+      if (isa<FixedVectorType>(SVI->getType())) {
+        std::optional<int> SrcIdx;
+        // getSplatIndex returns -1 to mean not-found.
+        if (SplatIndex != -1)
+          SrcIdx = SplatIndex;
+        else if (ConstantInt *CI = dyn_cast<ConstantInt>(Index))
+          SrcIdx = SVI->getMaskValue(CI->getZExtValue());
+
+        if (SrcIdx) {
+          Value *Src;
+          unsigned LHSWidth =
+              cast<FixedVectorType>(SVI->getOperand(0)->getType())
+                  ->getNumElements();
+
+          if (*SrcIdx < 0)
+            return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
+          if (*SrcIdx < (int)LHSWidth)
+            Src = SVI->getOperand(0);
+          else {
+            *SrcIdx -= LHSWidth;
+            Src = SVI->getOperand(1);
+          }
+          Type *Int64Ty = Type::getInt64Ty(EI.getContext());
+          return ExtractElementInst::Create(
+              Src, ConstantInt::get(Int64Ty, *SrcIdx, false));
         }
-        Type *Int64Ty = Type::getInt64Ty(EI.getContext());
-        return ExtractElementInst::Create(
-            Src, ConstantInt::get(Int64Ty, SrcIdx, false));
       }
     } else if (auto *CI = dyn_cast<CastInst>(I)) {
       // Canonicalize extractelement(cast) -> cast(extractelement).

diff  --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
new file mode 100644
index 0000000000000..1ec3dc3e4b40a
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -0,0 +1,52 @@
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define float @extract_from_zero_init_shuffle(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_zero_init_shuffle(
+; CHECK-NEXT:    %extract = extractelement <2 x float> %vec, i64 0
+; CHECK-NEXT:    ret float %extract
+;
+  %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> zeroinitializer
+  %extract = extractelement <4 x float> %shuffle, i64 %idx
+  ret float %extract
+}
+
+
+define float @extract_from_general_splat(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_splat(
+; CHECK-NEXT:    %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT:    ret float %extract
+;
+  %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+  %extract = extractelement <4 x float> %shuffle, i64 %idx
+  ret float %extract
+}
+
+define float @extract_from_general_scalable_splat(<vscale x 2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_scalable_splat(
+; CHECK-NEXT:    %extract = extractelement <vscale x 2 x float> %vec, i64 0
+; CHECK-NEXT:    ret float %extract
+;
+  %shuffle = shufflevector <vscale x 2 x float> %vec, <vscale x 2 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %extract = extractelement <vscale x 4 x float> %shuffle, i64 %idx
+  ret float %extract
+}
+
+define float @extract_from_splat_with_poison_0(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_splat_with_poison_0(
+; CHECK-NEXT:    %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT:    ret float %extract
+;
+  %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
+  %extract = extractelement <4 x float> %shuffle, i64 %idx
+  ret float %extract
+}
+
+define float @extract_from_splat_with_poison_1(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_splat_with_poison_1(
+; CHECK-NEXT:    %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT:    ret float %extract
+;
+  %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
+  %extract = extractelement <4 x float> %shuffle, i64 %idx
+  ret float %extract
+}

diff  --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index 0a9c71dba7947..86fc5bbf72e7b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -61,8 +61,7 @@ define float @test6(<4 x float> %X) {
 
 define float @testvscale6(<vscale x 4 x float> %X) {
 ; CHECK-LABEL: @testvscale6(
-; CHECK-NEXT:    [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>

diff  --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
index 003eddf7f121b..39f76f18b13ca 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
@@ -67,8 +67,7 @@ define float @test6(<4 x float> %X) {
 
 define float @testvscale6(<vscale x 4 x float> %X) {
 ; CHECK-LABEL: @testvscale6(
-; CHECK-NEXT:    [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>

diff  --git a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
index 2655c20354607..36ed39a3d3242 100644
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
@@ -89,12 +89,12 @@ define i8 @extractelement_bitcast_insert_extra_use_bitcast(<vscale x 2 x i32> %a
   ret i8 %r
 }
 
+; while it may be that the extract is out-of-bounds, any valid index
+; is going to yield %v (because the mask is all-zeros).
+
 define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
 ; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
-; CHECK-NEXT:    [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 [[V:%.*]]
 ;
   %in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
   %splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
@@ -104,10 +104,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
 
 define i32 @extractelement_shuffle_invalid_index(i32 %v) {
 ; CHECK-LABEL: @extractelement_shuffle_invalid_index(
-; CHECK-NEXT:    [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 [[V:%.*]]
 ;
   %in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
   %splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

diff  --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
index 07090e9099ae1..9ac8a92abb689 100644
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
@@ -53,12 +53,12 @@ define i8 @extractelement_bitcast_useless_insert(<vscale x 2 x i32> %a, i32 %x)
   ret i8 %r
 }
 
+; while in these tests it may be that the extract is out-of-bounds,
+; any valid index is going to yield %v (because the mask is all-zeros).
+
 define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
 ; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
-; CHECK-NEXT:    [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 [[V:%.*]]
 ;
   %in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
   %splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
@@ -68,10 +68,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
 
 define i32 @extractelement_shuffle_invalid_index(i32 %v) {
 ; CHECK-LABEL: @extractelement_shuffle_invalid_index(
-; CHECK-NEXT:    [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 [[V:%.*]]
 ;
   %in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
   %splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer


        


More information about the llvm-commits mailing list