[llvm] [InstCombine] Pull extract through broadcast (PR #143380)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 9 09:59:07 PDT 2025


https://github.com/agorenstein-nvidia updated https://github.com/llvm/llvm-project/pull/143380

>From 34c1222e9119655c5d5286bd01be13752f300bbd Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Fri, 6 Jun 2025 09:55:52 -0400
Subject: [PATCH 1/7] [InstCombine] Fold extractelement of a zero-mask shufflevector

---
 .../Transforms/InstCombine/InstCombineVectorOps.cpp  |  6 ++++++
 .../InstCombine/vec_extract_through_broadcast.ll     | 12 ++++++++++++
 2 files changed, 18 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index f946c3856948b..f2fb26c3ffae8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -542,6 +542,12 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
         }
       }
     } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
+      // extractelt (shufflevector %v1, %v2, zeroinitializer) ->
+      // extractelt %v1, 0
+      if (isa<FixedVectorType>(SVI->getType()))
+        if (all_of(SVI->getShuffleMask(), [](int Elt) { return Elt == 0; }))
+          return ExtractElementInst::Create(SVI->getOperand(0), Builder.getInt64(0));
+
       // If this is extracting an element from a shufflevector, figure out where
       // it came from and extract from the appropriate input element instead.
       // Restrict the following transformation to fixed-length vector.
diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
new file mode 100644
index 0000000000000..ec80fbf01a3c9
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -0,0 +1,12 @@
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define float @extract_from_zero_init_shuffle(<2 x float> %1, i64 %idx) {
+; CHECK-LABEL: @extract_from_zero_init_shuffle(
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[W:%.*]], i64 0
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> zeroinitializer
+  %4 = extractelement <4 x float> %3, i64 %idx
+  ret float %4
+}
+

>From e315c89001f13dabea489e2531621cd4f9eadd75 Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 11:51:26 -0400
Subject: [PATCH 2/7] Generalize to any splat mask and to scalable vectors (PR feedback)

---
 .../InstCombine/InstCombineVectorOps.cpp      | 12 +++---
 .../vec_extract_through_broadcast.ll          | 37 +++++++++++++++++++
 2 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index f2fb26c3ffae8..36f61fad0d887 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -542,11 +542,13 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
         }
       }
     } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
-      // extractelt (shufflevector %v1, %v2, zeroinitializer) ->
-      // extractelt %v1, 0
-      if (isa<FixedVectorType>(SVI->getType()))
-        if (all_of(SVI->getShuffleMask(), [](int Elt) { return Elt == 0; }))
-          return ExtractElementInst::Create(SVI->getOperand(0), Builder.getInt64(0));
+      // extractelt (shufflevector %v1, %v2, splat-mask) idx ->
+      // extractelt %v1, splat-mask[0]
+      if (isa<VectorType>(SVI->getType())) {
+        auto mask = SVI->getShuffleMask();
+        if (mask[0] != -1 && all_equal(mask))
+          return ExtractElementInst::Create(SVI->getOperand(0), Builder.getInt64(mask[0]));
+      }
 
       // If this is extracting an element from a shufflevector, figure out where
       // it came from and extract from the appropriate input element instead.
diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
index ec80fbf01a3c9..7afe6015779dd 100644
--- a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -10,3 +10,40 @@ define float @extract_from_zero_init_shuffle(<2 x float> %1, i64 %idx) {
   ret float %4
 }
 
+
+define float @extract_from_general_splat(<2 x float> %1, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_splat(
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[W:%.*]], i64 1
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+  %4 = extractelement <4 x float> %3, i64 %idx
+  ret float %4
+}
+
+define float @extract_from_general_scalable_splat(<vscale x 2 x float> %1, i64 %idx) {
+; CHECK-LABEL: @extract_from_general_scalable_splat(
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <vscale x 2 x float> [[W:%.*]], i64 0
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %3 = shufflevector <vscale x 2 x float> %1, <vscale x 2 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %4 = extractelement <vscale x 4 x float> %3, i64 %idx
+  ret float %4
+}
+
+define float @no_extract_from_general_no_splat_0(<2 x float> %1, i64 %idx) {
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 undef, i32 1, i32 1, i32 1>
+  %4 = extractelement <4 x float> %3, i64 %idx
+  ret float %4
+}
+
+define float @no_extract_from_general_no_splat_1(<2 x float> %1, i64 %idx) {
+; CHECK-LABEL: @no_extract_from_general_no_splat_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 undef, i32 1, i32 1>
+  %4 = extractelement <4 x float> %3, i64 %idx
+  ret float %4
+}

>From 274cec9375c0220abaea446243f64ef9f73027f8 Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 12:02:57 -0400
Subject: [PATCH 3/7] clang-format

---
 llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 36f61fad0d887..1f2aef985524f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -547,7 +547,8 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
       if (isa<VectorType>(SVI->getType())) {
         auto mask = SVI->getShuffleMask();
         if (mask[0] != -1 && all_equal(mask))
-          return ExtractElementInst::Create(SVI->getOperand(0), Builder.getInt64(mask[0]));
+          return ExtractElementInst::Create(SVI->getOperand(0),
+                                            Builder.getInt64(mask[0]));
       }
 
       // If this is extracting an element from a shufflevector, figure out where

>From 4fda162dc315044fe9a0aa38d5db89be7c1a4a9e Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 12:09:09 -0400
Subject: [PATCH 4/7] Fix mislabeled CHECK lines in the test file

---
 .../InstCombine/vec_extract_through_broadcast.ll         | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
index 7afe6015779dd..e14bf14cfa387 100644
--- a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -32,14 +32,19 @@ define float @extract_from_general_scalable_splat(<vscale x 2 x float> %1, i64 %
 }
 
 define float @no_extract_from_general_no_splat_0(<2 x float> %1, i64 %idx) {
+; CHECK-LABEL: @no_extract_from_general_no_splat_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
+; CHECK-NEXT:    ret float [[TMP2]]
+;
   %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 undef, i32 1, i32 1, i32 1>
   %4 = extractelement <4 x float> %3, i64 %idx
   ret float %4
 }
 
 define float @no_extract_from_general_no_splat_1(<2 x float> %1, i64 %idx) {
-; CHECK-LABEL: @no_extract_from_general_no_splat_0(
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
+; CHECK-LABEL: @no_extract_from_general_no_splat_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
 ; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
 ; CHECK-NEXT:    ret float [[TMP2]]
 ;

>From 5d9e4d98bb5b116c0323e5229f54fd98bb44671e Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 12:14:51 -0400
Subject: [PATCH 5/7] Remove redundant isa<VectorType> check

---
 .../Transforms/InstCombine/InstCombineVectorOps.cpp    | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 1f2aef985524f..789ceaa0be0f4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -544,12 +544,10 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
     } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
       // extractelt (shufflevector %v1, %v2, splat-mask) idx ->
       // extractelt %v1, splat-mask[0]
-      if (isa<VectorType>(SVI->getType())) {
-        auto mask = SVI->getShuffleMask();
-        if (mask[0] != -1 && all_equal(mask))
-          return ExtractElementInst::Create(SVI->getOperand(0),
-                                            Builder.getInt64(mask[0]));
-      }
+      auto mask = SVI->getShuffleMask();
+      if (mask[0] != -1 && all_equal(mask))
+        return ExtractElementInst::Create(SVI->getOperand(0),
+                                          Builder.getInt64(mask[0]));
 
       // If this is extracting an element from a shufflevector, figure out where
       // it came from and extract from the appropriate input element instead.

>From a1aec2eb38245a397917c9daf1fc02f6ab8ff504 Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 12:35:06 -0400
Subject: [PATCH 6/7] Use poison instead of undef in test shuffle masks (per PR
 automation failure), and use the named constant PoisonMaskElem instead of the
 magic value -1 for the poison shuffle-mask element

---
 llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp      | 2 +-
 .../Transforms/InstCombine/vec_extract_through_broadcast.ll   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 789ceaa0be0f4..946c8cd09bb6f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -545,7 +545,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
       // extractelt (shufflevector %v1, %v2, splat-mask) idx ->
       // extractelt %v1, splat-mask[0]
       auto mask = SVI->getShuffleMask();
-      if (mask[0] != -1 && all_equal(mask))
+      if (mask[0] != PoisonMaskElem && all_equal(mask))
         return ExtractElementInst::Create(SVI->getOperand(0),
                                           Builder.getInt64(mask[0]));
 
diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
index e14bf14cfa387..b93b4d017918a 100644
--- a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -37,7 +37,7 @@ define float @no_extract_from_general_no_splat_0(<2 x float> %1, i64 %idx) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
 ; CHECK-NEXT:    ret float [[TMP2]]
 ;
-  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 undef, i32 1, i32 1, i32 1>
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
   %4 = extractelement <4 x float> %3, i64 %idx
   ret float %4
 }
@@ -48,7 +48,7 @@ define float @no_extract_from_general_no_splat_1(<2 x float> %1, i64 %idx) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
 ; CHECK-NEXT:    ret float [[TMP2]]
 ;
-  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 undef, i32 1, i32 1>
+  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
   %4 = extractelement <4 x float> %3, i64 %idx
   ret float %4
 }

>From 70d3c0124b9593202eb494be0fd46d422b50ef3b Mon Sep 17 00:00:00 2001
From: Aaron Gorenstein <agorenstein at nvidia.com>
Date: Mon, 9 Jun 2025 12:58:46 -0400
Subject: [PATCH 7/7] Fix variable naming convention per PR feedback; after
 further PR feedback, remove the negative tests

---
 .../InstCombine/InstCombineVectorOps.cpp      |  6 ++---
 .../vec_extract_through_broadcast.ll          | 22 -------------------
 2 files changed, 3 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 946c8cd09bb6f..5519855a85054 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -544,10 +544,10 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
     } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
       // extractelt (shufflevector %v1, %v2, splat-mask) idx ->
       // extractelt %v1, splat-mask[0]
-      auto mask = SVI->getShuffleMask();
-      if (mask[0] != PoisonMaskElem && all_equal(mask))
+      auto Mask = SVI->getShuffleMask();
+      if (Mask[0] != PoisonMaskElem && all_equal(Mask))
         return ExtractElementInst::Create(SVI->getOperand(0),
-                                          Builder.getInt64(mask[0]));
+                                          Builder.getInt64(Mask[0]));
 
       // If this is extracting an element from a shufflevector, figure out where
       // it came from and extract from the appropriate input element instead.
diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
index b93b4d017918a..5ed837eb22760 100644
--- a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
+++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
@@ -30,25 +30,3 @@ define float @extract_from_general_scalable_splat(<vscale x 2 x float> %1, i64 %
   %4 = extractelement <vscale x 4 x float> %3, i64 %idx
   ret float %4
 }
-
-define float @no_extract_from_general_no_splat_0(<2 x float> %1, i64 %idx) {
-; CHECK-LABEL: @no_extract_from_general_no_splat_0(
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
-; CHECK-NEXT:    ret float [[TMP2]]
-;
-  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
-  %4 = extractelement <4 x float> %3, i64 %idx
-  ret float %4
-}
-
-define float @no_extract_from_general_no_splat_1(<2 x float> %1, i64 %idx) {
-; CHECK-LABEL: @no_extract_from_general_no_splat_1(
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[W:%.*]], <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[TMP1]], i64 %idx
-; CHECK-NEXT:    ret float [[TMP2]]
-;
-  %3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
-  %4 = extractelement <4 x float> %3, i64 %idx
-  ret float %4
-}



More information about the llvm-commits mailing list