[llvm] [InstCombine] foldVecExtTruncToExtElt - extend to handle trunc(lshr(extractelement(x,c1),c2)) -> extractelement(bitcast(x),c3) patterns. (PR #109689)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 28 03:36:35 PDT 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/109689

>From 122fe3dd998e471a795b3d654c4baf574e274dea Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Sun, 22 Sep 2024 15:55:15 +0100
Subject: [PATCH 1/6] [InstCombine] Move trunc+extractelement ->
 extractelement+bitcast fold into foldVecExtTruncToExtElt helper. NFC.

Minor refactor step for #107404
---
 .../InstCombine/InstCombineCasts.cpp          | 76 +++++++++++--------
 1 file changed, 46 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index ea51d779045718..901d95ce5d7bf7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -436,6 +436,50 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
+// Whenever an element is extracted from a vector, and then truncated,
+// canonicalize by converting it to a bitcast followed by an
+// extractelement.
+//
+// Example (little endian):
+//   trunc (extractelement <4 x i64> %X, 0) to i32
+//   --->
+//   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
+                                            InstCombinerImpl &IC) {
+  Value *Src = Trunc.getOperand(0);
+  Type *SrcType = Src->getType();
+  Type *DstType = Trunc.getType();
+
+  // Only attempt this if we have simple aliasing of the vector elements.
+  // A badly fit destination size would result in an invalid cast.
+  unsigned SrcBits = SrcType->getScalarSizeInBits();
+  unsigned DstBits = DstType->getScalarSizeInBits();
+  unsigned TruncRatio = SrcBits / DstBits;
+  if ((SrcBits % DstBits) != 0)
+    return nullptr;
+
+  Value *VecOp;
+  ConstantInt *Cst;
+  if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))))
+    return nullptr;
+
+  auto *VecOpTy = cast<VectorType>(VecOp->getType());
+  auto VecElts = VecOpTy->getElementCount();
+
+  uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
+  uint64_t VecOpIdx = Cst->getZExtValue();
+  uint64_t NewIdx = IC.getDataLayout().isBigEndian()
+                        ? (VecOpIdx + 1) * TruncRatio - 1
+                        : VecOpIdx * TruncRatio;
+  assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
+         "overflow 32-bits");
+
+  auto *BitCastTo =
+      VectorType::get(DstType, BitCastNumElts, VecElts.isScalable());
+  Value *BitCast = IC.Builder.CreateBitCast(VecOp, BitCastTo);
+  return ExtractElementInst::Create(BitCast, IC.Builder.getInt32(NewIdx));
+}
+
 /// Funnel/Rotate left/right may occur in a wider type than necessary because of
 /// type promotion rules. Try to narrow the inputs and convert to funnel shift.
 Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
@@ -848,36 +892,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
   if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
     return I;
 
-  // Whenever an element is extracted from a vector, and then truncated,
-  // canonicalize by converting it to a bitcast followed by an
-  // extractelement.
-  //
-  // Example (little endian):
-  //   trunc (extractelement <4 x i64> %X, 0) to i32
-  //   --->
-  //   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
-  Value *VecOp;
-  ConstantInt *Cst;
-  if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
-    auto *VecOpTy = cast<VectorType>(VecOp->getType());
-    auto VecElts = VecOpTy->getElementCount();
-
-    // A badly fit destination size would result in an invalid cast.
-    if (SrcWidth % DestWidth == 0) {
-      uint64_t TruncRatio = SrcWidth / DestWidth;
-      uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
-      uint64_t VecOpIdx = Cst->getZExtValue();
-      uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
-                                         : VecOpIdx * TruncRatio;
-      assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
-             "overflow 32-bits");
-
-      auto *BitCastTo =
-          VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
-      Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
-      return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
-    }
-  }
+  if (Instruction *I = foldVecExtTruncToExtElt(Trunc, *this))
+    return I;
 
   // trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C)
   if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),

>From 729626b7b5473aed8e6abb8929a7c174a0f8ffd7 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Mon, 23 Sep 2024 18:10:51 +0100
Subject: [PATCH 2/6] [InstCombine] Add tests for trunc+lshr+extractelement ->
 extractelement+bitcast fold for #107404

---
 .../trunc-extractelement-inseltpoison.ll      | 84 +++++++++++++++++++
 .../InstCombine/trunc-extractelement.ll       | 84 +++++++++++++++++++
 2 files changed, 168 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll b/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
index e9e105b91f3c19..eeb98761d7080f 100644
--- a/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
@@ -18,6 +18,19 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
   ret i32 %t
 }
 
+define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
+; ANY-NEXT:    ret i32 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i32 0
+  %s = lshr i64 %e, 32
+  %t = trunc i64 %s to i32
+  ret i32 %t
+}
+
 define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
 ; LE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
 ; LE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
@@ -34,6 +47,18 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
   ret i32 %t
 }
 
+define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
+; ANY-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <vscale x 3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
+; ANY-NEXT:    ret i32 [[T]]
+;
+  %e = extractelement <vscale x 3 x i64> %x, i32 0
+  %s = lshr i64 %e, 32
+  %t = trunc i64 %s to i32
+  ret i32 %t
+}
 
 define i32 @shrinkExtractElt_i64_to_i32_1(<3 x i64> %x) {
 ; LE-LABEL: @shrinkExtractElt_i64_to_i32_1(
@@ -83,6 +108,19 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
   ret i16 %t
 }
 
+define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i16 0
+  %s = ashr i64 %e, 48
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 define i16 @shrinkExtractElt_i64_to_i16_1(<3 x i64> %x) {
 ; LE-LABEL: @shrinkExtractElt_i64_to_i16_1(
 ; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
@@ -157,6 +195,20 @@ define i30 @shrinkExtractElt_i40_to_i30_1(<3 x i40> %x) {
   ret i30 %t
 }
 
+; Do not optimize if the shift amount isn't a whole multiple of the truncated type's bit width.
+define i16 @shrinkShiftExtractElt_i64_to_i16_0_badshift(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0_badshift(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 31
+; ANY-NEXT:    [[T:%.*]] = trunc i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i16 0
+  %s = lshr i64 %e, 31
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 ; Do not canonicalize if that would increase the instruction count.
 declare void @use(i64)
 define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
@@ -172,6 +224,38 @@ define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
   ret i16 %t
 }
 
+; Do not canonicalize if that would increase the instruction count.
+define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    call void @use(i64 [[S]])
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i64 2
+  %s = lshr i64 %e, 48
+  call void @use(i64 %s)
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
+; OK to reuse the extract if we remove the shift+trunc.
+define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; ANY-NEXT:    call void @use(i64 [[E]])
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i64 2
+  call void @use(i64 %e)
+  %s = lshr i64 %e, 48
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 ; Check to ensure PR45314 remains fixed.
 define <4 x i64> @PR45314(<4 x i64> %x) {
 ; LE-LABEL: @PR45314(
diff --git a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
index 5e62ca9cd591da..87125d407f98f3 100644
--- a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
@@ -18,6 +18,19 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
   ret i32 %t
 }
 
+define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
+; ANY-NEXT:    ret i32 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i32 0
+  %s = lshr i64 %e, 32
+  %t = trunc i64 %s to i32
+  ret i32 %t
+}
+
 define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
 ; LE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
 ; LE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
@@ -34,6 +47,18 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
   ret i32 %t
 }
 
+define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
+; ANY-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <vscale x 3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
+; ANY-NEXT:    ret i32 [[T]]
+;
+  %e = extractelement <vscale x 3 x i64> %x, i32 0
+  %s = lshr i64 %e, 32
+  %t = trunc i64 %s to i32
+  ret i32 %t
+}
 
 define i32 @shrinkExtractElt_i64_to_i32_1(<3 x i64> %x) {
 ; LE-LABEL: @shrinkExtractElt_i64_to_i32_1(
@@ -83,6 +108,19 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
   ret i16 %t
 }
 
+define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i16 0
+  %s = ashr i64 %e, 48
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 define i16 @shrinkExtractElt_i64_to_i16_1(<3 x i64> %x) {
 ; LE-LABEL: @shrinkExtractElt_i64_to_i16_1(
 ; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
@@ -157,6 +195,20 @@ define i30 @shrinkExtractElt_i40_to_i30_1(<3 x i40> %x) {
   ret i30 %t
 }
 
+; Do not optimize if the shift amount isn't a whole multiple of the truncated type's bit width.
+define i16 @shrinkShiftExtractElt_i64_to_i16_0_badshift(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0_badshift(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 31
+; ANY-NEXT:    [[T:%.*]] = trunc i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i16 0
+  %s = lshr i64 %e, 31
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 ; Do not canonicalize if that would increase the instruction count.
 declare void @use(i64)
 define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
@@ -172,6 +224,38 @@ define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
   ret i16 %t
 }
 
+; Do not canonicalize if that would increase the instruction count.
+define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    call void @use(i64 [[S]])
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i64 2
+  %s = lshr i64 %e, 48
+  call void @use(i64 %s)
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
+; OK to reuse the extract if we remove the shift+trunc.
+define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
+; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; ANY-NEXT:    call void @use(i64 [[E]])
+; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
+; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
+; ANY-NEXT:    ret i16 [[T]]
+;
+  %e = extractelement <3 x i64> %x, i64 2
+  call void @use(i64 %e)
+  %s = lshr i64 %e, 48
+  %t = trunc i64 %s to i16
+  ret i16 %t
+}
+
 ; Check to ensure PR45314 remains fixed.
 define <4 x i64> @PR45314(<4 x i64> %x) {
 ; LE-LABEL: @PR45314(

>From 8c9fcd914d1f04caa991dee22b59006e3101b65a Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Mon, 23 Sep 2024 18:13:36 +0100
Subject: [PATCH 3/6] [InstCombine] foldVecExtTruncToExtElt - extend to handle
 trunc(lshr(extractelement(x,c1),c2)) -> extractelement(bitcast(x),c3)
 patterns.

Fixes #107404
---
 .../InstCombine/InstCombineCasts.cpp          | 23 ++++++-
 .../trunc-extractelement-inseltpoison.ll      | 61 ++++++++++++-------
 .../InstCombine/trunc-extractelement.ll       | 61 ++++++++++++-------
 3 files changed, 100 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 901d95ce5d7bf7..874ea515594a56 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -460,7 +460,11 @@ static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
 
   Value *VecOp;
   ConstantInt *Cst;
-  if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))))
+  const APInt *ShiftAmount = nullptr;
+  if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))) &&
+      !match(Src,
+             m_OneUse(m_LShr(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)),
+                             m_APInt(ShiftAmount)))))
     return nullptr;
 
   auto *VecOpTy = cast<VectorType>(VecOp->getType());
@@ -469,10 +473,23 @@ static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
   uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
   uint64_t VecOpIdx = Cst->getZExtValue();
   uint64_t NewIdx = IC.getDataLayout().isBigEndian()
-                        ? (VecOpIdx + 1) * TruncRatio - 1
+                        ? (VecOpIdx * TruncRatio) + (TruncRatio - 1)
                         : VecOpIdx * TruncRatio;
+
+  // Adjust index by the whole number of truncated elements.
+  if (ShiftAmount) {
+    // Check shift amount is in range and shifts a whole number of truncated
+    // elements.
+    if (ShiftAmount->uge(SrcBits) || ShiftAmount->urem(DstBits) != 0)
+      return nullptr;
+
+    uint64_t IdxOfs = ShiftAmount->udiv(DstBits).getZExtValue();
+    NewIdx = IC.getDataLayout().isBigEndian() ? (NewIdx - IdxOfs)
+                                              : (NewIdx + IdxOfs);
+  }
+
   assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
-         "overflow 32-bits");
+         NewIdx <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
 
   auto *BitCastTo =
       VectorType::get(DstType, BitCastNumElts, VecElts.isScalable());
diff --git a/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll b/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
index eeb98761d7080f..5d32158e61715b 100644
--- a/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
@@ -19,11 +19,15 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
 }
 
 define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
-; ANY-NEXT:    ret i32 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
+; LE-NEXT:    [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 1
+; LE-NEXT:    ret i32 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
+; BE-NEXT:    [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 0
+; BE-NEXT:    ret i32 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i32 0
   %s = lshr i64 %e, 32
@@ -48,11 +52,15 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
 }
 
 define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
-; ANY-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <vscale x 3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
-; ANY-NEXT:    ret i32 [[T]]
+; LE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; LE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 1
+; LE-NEXT:    ret i32 [[T]]
+;
+; BE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; BE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 0
+; BE-NEXT:    ret i32 [[T]]
 ;
   %e = extractelement <vscale x 3 x i64> %x, i32 0
   %s = lshr i64 %e, 32
@@ -109,11 +117,15 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
 }
 
 define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
-; ANY-NEXT:    ret i16 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
+; LE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 3
+; LE-NEXT:    ret i16 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
+; BE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 0
+; BE-NEXT:    ret i16 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i16 0
   %s = ashr i64 %e, 48
@@ -242,12 +254,19 @@ define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
 
 ; OK to reuse the extract if we remove the shift+trunc.
 define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
-; ANY-NEXT:    call void @use(i64 [[E]])
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
-; ANY-NEXT:    ret i16 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; LE-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; LE-NEXT:    call void @use(i64 [[E]])
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
+; LE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 11
+; LE-NEXT:    ret i16 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; BE-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; BE-NEXT:    call void @use(i64 [[E]])
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
+; BE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 8
+; BE-NEXT:    ret i16 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i64 2
   call void @use(i64 %e)
diff --git a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
index 87125d407f98f3..ba2d07009d9c78 100644
--- a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
@@ -19,11 +19,15 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
 }
 
 define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
-; ANY-NEXT:    ret i32 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
+; LE-NEXT:    [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 1
+; LE-NEXT:    ret i32 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
+; BE-NEXT:    [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 0
+; BE-NEXT:    ret i32 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i32 0
   %s = lshr i64 %e, 32
@@ -48,11 +52,15 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
 }
 
 define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
-; ANY-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <vscale x 3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 32
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i32
-; ANY-NEXT:    ret i32 [[T]]
+; LE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; LE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 1
+; LE-NEXT:    ret i32 [[T]]
+;
+; BE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; BE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 0
+; BE-NEXT:    ret i32 [[T]]
 ;
   %e = extractelement <vscale x 3 x i64> %x, i32 0
   %s = lshr i64 %e, 32
@@ -109,11 +117,15 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
 }
 
 define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
-; ANY-NEXT:    ret i16 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
+; LE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 3
+; LE-NEXT:    ret i16 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
+; BE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 0
+; BE-NEXT:    ret i16 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i16 0
   %s = ashr i64 %e, 48
@@ -242,12 +254,19 @@ define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
 
 ; OK to reuse the extract if we remove the shift+trunc.
 define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
-; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
-; ANY-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
-; ANY-NEXT:    call void @use(i64 [[E]])
-; ANY-NEXT:    [[S:%.*]] = lshr i64 [[E]], 48
-; ANY-NEXT:    [[T:%.*]] = trunc nuw i64 [[S]] to i16
-; ANY-NEXT:    ret i16 [[T]]
+; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; LE-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; LE-NEXT:    call void @use(i64 [[E]])
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
+; LE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 11
+; LE-NEXT:    ret i16 [[T]]
+;
+; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
+; BE-NEXT:    [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
+; BE-NEXT:    call void @use(i64 [[E]])
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
+; BE-NEXT:    [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 8
+; BE-NEXT:    ret i16 [[T]]
 ;
   %e = extractelement <3 x i64> %x, i64 2
   call void @use(i64 %e)

>From e86ca2b93e0347b1b27eb9e5c23d4c684b48a706 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Sat, 28 Sep 2024 11:29:39 +0100
Subject: [PATCH 4/6] Fix header comment to appease doxygen

---
 .../Transforms/InstCombine/InstCombineCasts.cpp  | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 874ea515594a56..c270e7fa905cc2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -436,14 +436,14 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
-// Whenever an element is extracted from a vector, and then truncated,
-// canonicalize by converting it to a bitcast followed by an
-// extractelement.
-//
-// Example (little endian):
-//   trunc (extractelement <4 x i64> %X, 0) to i32
-//   --->
-//   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+/// Whenever an element is extracted from a vector, and then truncated,
+/// canonicalize by converting it to a bitcast followed by an
+/// extractelement.
+///
+/// Example (little endian):
+///   trunc (extractelement <4 x i64> %X, 0) to i32
+///   --->
+///   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
 static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
                                             InstCombinerImpl &IC) {
   Value *Src = Trunc.getOperand(0);

>From c7169ebc42844c439cac0dc8ed9742b59ed280fa Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Sat, 28 Sep 2024 11:34:07 +0100
Subject: [PATCH 5/6] Update comment to describe
 trunc(lshr(extractelement(x,c1),c2)) pattern as well

---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index c270e7fa905cc2..e80c4ecbc04569 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -436,14 +436,18 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
-/// Whenever an element is extracted from a vector, and then truncated,
-/// canonicalize by converting it to a bitcast followed by an
+/// Whenever an element is extracted from a vector, optionally shifted down, and
+/// then truncated, canonicalize by converting it to a bitcast followed by an
 /// extractelement.
 ///
-/// Example (little endian):
+/// Examples (little endian):
 ///   trunc (extractelement <4 x i64> %X, 0) to i32
 ///   --->
 ///   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+///
+///   trunc (lshr (extractelement <4 x i32> %X, 0), 8) to i8
+///   --->
+///   extractelement <16 x i8> (bitcast <4 x i32> %X to <16 x i8>), i32 1
 static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
                                             InstCombinerImpl &IC) {
   Value *Src = Trunc.getOperand(0);

>From e8811333434c941f531ed82fdfc27751e3a3d207 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Sat, 28 Sep 2024 11:36:11 +0100
Subject: [PATCH 6/6] Remove NFC refactor of BE element scaling

---
 llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e80c4ecbc04569..9934c065ebf85f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -477,7 +477,7 @@ static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
   uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
   uint64_t VecOpIdx = Cst->getZExtValue();
   uint64_t NewIdx = IC.getDataLayout().isBigEndian()
-                        ? (VecOpIdx * TruncRatio) + (TruncRatio - 1)
+                        ? (VecOpIdx + 1) * TruncRatio - 1
                         : VecOpIdx * TruncRatio;
 
   // Adjust index by the whole number of truncated elements.



More information about the llvm-commits mailing list