[llvm] b0473c5 - [InstCombine] Pull extract through broadcast (#143380)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 4 09:19:55 PDT 2025
Author: agorenstein-nvidia
Date: 2025-07-04T18:19:50+02:00
New Revision: b0473c599b0418c71d15150e0ea19d57df3b98e5
URL: https://github.com/llvm/llvm-project/commit/b0473c599b0418c71d15150e0ea19d57df3b98e5
DIFF: https://github.com/llvm/llvm-project/commit/b0473c599b0418c71d15150e0ea19d57df3b98e5.diff
LOG: [InstCombine] Pull extract through broadcast (#143380)
This change adds a new InstCombine fold, and an associated test, for
instruction sequences like this:
```
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> zeroinitializer
%4 = extractelement <4 x float> %3, i64 %idx
```
The shufflevector has a splat (broadcast) mask, so every lane of %3 holds
element 0 of %1. Whatever in-bounds index is used, the extractelement must
therefore yield the first element of %1, so we transform this to
```
%2 = extractelement <2 x float> %1, i64 0
```
Added:
llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
Modified:
llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
llvm/test/Transforms/InstCombine/vec_shuffle.ll
llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
llvm/test/Transforms/InstCombine/vscale_extractelement.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index a746a5229fb9a..1a8d4215b5b71 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -542,27 +542,39 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
}
}
} else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
- // If this is extracting an element from a shufflevector, figure out where
- // it came from and extract from the appropriate input element instead.
- // Restrict the following transformation to fixed-length vector.
- if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
- int SrcIdx =
- SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
- Value *Src;
- unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
- ->getNumElements();
-
- if (SrcIdx < 0)
- return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
- if (SrcIdx < (int)LHSWidth)
- Src = SVI->getOperand(0);
- else {
- SrcIdx -= LHSWidth;
- Src = SVI->getOperand(1);
+ int SplatIndex = getSplatIndex(SVI->getShuffleMask());
+ // We know the all-0 splat must be reading from the first operand, even
+ // in the case of scalable vectors (vscale is always > 0).
+ if (SplatIndex == 0)
+ return ExtractElementInst::Create(SVI->getOperand(0),
+ Builder.getInt64(0));
+
+ if (isa<FixedVectorType>(SVI->getType())) {
+ std::optional<int> SrcIdx;
+ // getSplatIndex returns -1 to mean not-found.
+ if (SplatIndex != -1)
+ SrcIdx = SplatIndex;
+ else if (ConstantInt *CI = dyn_cast<ConstantInt>(Index))
+ SrcIdx = SVI->getMaskValue(CI->getZExtValue());
+
+ if (SrcIdx) {
+ Value *Src;
+ unsigned LHSWidth =
+ cast<FixedVectorType>(SVI->getOperand(0)->getType())
+ ->getNumElements();
+
+ if (*SrcIdx < 0)
+ return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
+ if (*SrcIdx < (int)LHSWidth)
+ Src = SVI->getOperand(0);
+ else {
+ *SrcIdx -= LHSWidth;
+ Src = SVI->getOperand(1);
+ }
+ Type *Int64Ty = Type::getInt64Ty(EI.getContext());
+ return ExtractElementInst::Create(
+ Src, ConstantInt::get(Int64Ty, *SrcIdx, false));
}
- Type *Int64Ty = Type::getInt64Ty(EI.getContext());
- return ExtractElementInst::Create(
- Src, ConstantInt::get(Int64Ty, SrcIdx, false));
}
} else if (auto *CI = dyn_cast<CastInst>(I)) {
// Canonicalize extractelement(cast) -> cast(extractelement).
diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
new file mode 100644
index 0000000000000..1ec3dc3e4b40a
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -0,0 +1,52 @@
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define float @extract_from_zero_init_shuffle(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_zero_init_shuffle(
+; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 0
+; CHECK-NEXT: ret float %extract
+;
+ %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> zeroinitializer
+ %extract = extractelement <4 x float> %shuffle, i64 %idx
+ ret float %extract
+}
+
+
+define float @extract_from_general_splat(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_splat(
+; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT: ret float %extract
+;
+ %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+ %extract = extractelement <4 x float> %shuffle, i64 %idx
+ ret float %extract
+}
+
+define float @extract_from_general_scalable_splat(<vscale x 2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_scalable_splat(
+; CHECK-NEXT: %extract = extractelement <vscale x 2 x float> %vec, i64 0
+; CHECK-NEXT: ret float %extract
+;
+ %shuffle = shufflevector <vscale x 2 x float> %vec, <vscale x 2 x float> poison, <vscale x 4 x i32> zeroinitializer
+ %extract = extractelement <vscale x 4 x float> %shuffle, i64 %idx
+ ret float %extract
+}
+
+define float @extract_from_splat_with_poison_0(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_splat_with_poison_0(
+; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT: ret float %extract
+;
+ %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
+ %extract = extractelement <4 x float> %shuffle, i64 %idx
+ ret float %extract
+}
+
+define float @extract_from_splat_with_poison_1(<2 x float> %vec, i64 %idx) {
+; CHECK-LABEL: @extract_from_splat_with_poison_1(
+; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
+; CHECK-NEXT: ret float %extract
+;
+ %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
+ %extract = extractelement <4 x float> %shuffle, i64 %idx
+ ret float %extract
+}
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index 0a9c71dba7947..86fc5bbf72e7b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -61,8 +61,7 @@ define float @test6(<4 x float> %X) {
define float @testvscale6(<vscale x 4 x float> %X) {
; CHECK-LABEL: @testvscale6(
-; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
; CHECK-NEXT: ret float [[R]]
;
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
index 003eddf7f121b..39f76f18b13ca 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
@@ -67,8 +67,7 @@ define float @test6(<4 x float> %X) {
define float @testvscale6(<vscale x 4 x float> %X) {
; CHECK-LABEL: @testvscale6(
-; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
+; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
; CHECK-NEXT: ret float [[R]]
;
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>
diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
index 2655c20354607..36ed39a3d3242 100644
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll
@@ -89,12 +89,12 @@ define i8 @extractelement_bitcast_insert_extra_use_bitcast(<vscale x 2 x i32> %a
ret i8 %r
}
+; while it may be that the extract is out-of-bounds, any valid index
+; is going to yield %v (because the mask is all-zeros).
+
define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
-; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
@@ -104,10 +104,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
define i32 @extractelement_shuffle_invalid_index(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
-; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
index 07090e9099ae1..9ac8a92abb689 100644
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
@@ -53,12 +53,12 @@ define i8 @extractelement_bitcast_useless_insert(<vscale x 2 x i32> %a, i32 %x)
ret i8 %r
}
+; while in these tests it may be that the extract is out-of-bounds,
+; any valid index is going to yield %v (because the mask is all-zeros).
+
define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
-; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
@@ -68,10 +68,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
define i32 @extractelement_shuffle_invalid_index(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
-; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
More information about the llvm-commits mailing list