[llvm] [IA][RISCV] Support VP intrinsics in InterleavedAccessPass (PR #120490)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 15:11:14 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: Min-Yih Hsu (mshockwave)

<details>
<summary>Changes</summary>

Teach InterleavedAccessPass to recognize the following patterns:
  - vp.store of an interleaved scalable vector
  - Deinterleaving a scalable vector loaded from vp.load
  - Deinterleaving a scalable vector loaded from a vp.strided.load

Upon recognizing these patterns, IA collects the interleaved / deinterleaved operands and delegates them to the respective newly added TLI hooks.
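Besides the operands, the hooks also receive a per-field mask, produced by the `getMask` helper in the diff. That helper returns the shared per-field mask, `nullptr` for an all-true mask, or `std::nullopt` when the wide mask cannot be split. Below is a rough lane-level model of that three-way result. It assumes masks are plain `std::vector<bool>` arrays rather than IR values, and it uses an empty vector to stand in for the `nullptr` (all-true) case. `splitWideMask` is a hypothetical name. An interleave of `Factor` identical masks is exactly a wide mask in which every run of `Factor` consecutive lanes is uniform, so that is the property the loop checks.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical lane-level model of the patch's getMask helper.
// Returns:
//   std::nullopt        -> the wide mask cannot be split per field
//   an empty vector     -> every lane is true (getMask returns nullptr)
//   a non-empty vector  -> the single mask shared by every field
std::optional<std::vector<bool>> splitWideMask(const std::vector<bool> &Wide,
                                               unsigned Factor) {
  assert(Factor > 1 && Wide.size() % Factor == 0);
  bool AllTrue = true;
  std::vector<bool> Field;
  for (std::size_t G = 0; G < Wide.size(); G += Factor) {
    // interleave(M, M, ..., M) repeats M[i] Factor times, so each run of
    // Factor consecutive lanes must be uniform.
    for (unsigned F = 1; F < Factor; ++F)
      if (Wide[G + F] != Wide[G])
        return std::nullopt;
    Field.push_back(Wide[G]);
    AllTrue = AllTrue && Wide[G];
  }
  if (AllTrue)
    return std::vector<bool>{};
  return Field;
}
```

For example, `{true, true, false, false}` with `Factor = 2` splits into the per-field mask `{true, false}`, while `{true, false, true, false}` cannot be split.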

For RISC-V, these patterns are lowered into segmented loads/stores, except when interleaving constant splats, in which case a unit-stride store is generated.
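The constant-splat exception works because interleaving two `i32` splats `X` and `Y` yields the lane pairs `X, Y`. On a little-endian target such as RISC-V, those bytes are the same as an `i64` splat of `(Y << 32) | X`. The sketch below checks that byte identity. It assumes a little-endian host and uses hypothetical helper names. It is not the patch's `foldInterleaved2OfConstSplats`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Packs the even-lane splat value into the low half and the odd-lane splat
// value into the high half of one 64-bit element.
std::uint64_t fuseSplatPair(std::uint32_t Lo, std::uint32_t Hi) {
  return (static_cast<std::uint64_t>(Hi) << 32) | Lo;
}

// Compares the bytes of interleave2(splat(Lo), splat(Hi)) with 2 * N i32
// lanes against N copies of the fused i64. Expected to match on
// little-endian hosts only.
bool fusedStoreMatches(std::uint32_t Lo, std::uint32_t Hi, std::size_t N) {
  if (N == 0)
    return true;
  std::vector<std::uint32_t> Interleaved;
  for (std::size_t I = 0; I < N; ++I) {
    Interleaved.push_back(Lo); // even lane
    Interleaved.push_back(Hi); // odd lane
  }
  std::vector<std::uint64_t> Fused(N, fuseSplatPair(Lo, Hi));
  return std::memcmp(Interleaved.data(), Fused.data(),
                     N * sizeof(std::uint64_t)) == 0;
}
```

With the values from the patch's doc comment, `fuseSplatPair(666, 777)` is `(777 << 32) | 666`. Storing that splat with a plain unit-stride store writes the same memory as the segment store would.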

Right now we only recognize power-of-two (de)interleave factors, where (de)interleave4/8 are synthesized from a tree of (de)interleave2.
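The following sketch shows why the leaves of such a tree need reordering. It models `vector.interleave2` on `std::vector<int>` and checks that `interleave2(interleave2(A, C), interleave2(B, D))` equals a 4-way interleave of `A, B, C, D`. A breadth-first walk therefore sees the leaves as `A C B D`. The fixed index tables in the patch's `interleaveLeafValues`, `{0, 2, 1, 3}` and `{0, 4, 2, 6, 1, 5, 3, 7}`, happen to be bit-reversal permutations, so the sketch computes them that way. This generalization and all helper names are illustrative assumptions, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<int>;

// Alternates lanes of two equal-length vectors (model of vector.interleave2).
Vec interleave2(const Vec &A, const Vec &B) {
  Vec Out;
  for (std::size_t I = 0; I < A.size(); ++I) {
    Out.push_back(A[I]);
    Out.push_back(B[I]);
  }
  return Out;
}

// Reverses the low Bits bits of X.
unsigned reverseBits(unsigned X, unsigned Bits) {
  unsigned R = 0;
  for (unsigned B = 0; B < Bits; ++B)
    R |= ((X >> B) & 1u) << (Bits - 1 - B);
  return R;
}

// Reorders leaves collected in BFS order (e.g. A C B D) into field order
// (A B C D): Out[I] = Leaves[bitreverse(I)].
template <typename T>
std::vector<T> reorderLeaves(const std::vector<T> &Leaves) {
  unsigned Factor = static_cast<unsigned>(Leaves.size());
  unsigned Bits = 0;
  while ((1u << Bits) < Factor)
    ++Bits;
  assert((1u << Bits) == Factor && "factor must be a power of two");
  std::vector<T> Out;
  for (unsigned I = 0; I < Factor; ++I)
    Out.push_back(Leaves[reverseBits(I, Bits)]);
  return Out;
}
```

For a factor of 2, the permutation is the identity, which matches the early return in `interleaveLeafValues`.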

-----

Note that #89276 did something similar: it recognized power-of-two (de)interleave expressions in AArch64's TLI. However, it matched patterns for a single (de)interleave factor, whereas this patch runs a BFS over the tree and supports factors of 2, 4, and 8. We could consolidate this logic in the future.
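During the BFS over a deinterleave tree (the `getVectorDeInterleaveFactor` hunk in the diff), each `deinterleave2` node must have two `extractvalue` users that read distinct indices. The patch checks this by XOR-ing `index + 1` into an accumulator. Indices 0 and 1 contribute `0b01` and `0b10`, so reading the same index twice cancels the accumulator back to zero, while two distinct indices leave `0b11`. The standalone sketch below uses a hypothetical helper name and omits the surrounding IR traversal.

```cpp
#include <cassert>
#include <vector>

// Mirrors the duplicate-index check in getVectorDeInterleaveFactor: each
// extractvalue index (0 or 1) XORs the value Index + 1 (0b01 or 0b10) into
// an accumulator, so reading the same index twice clears it back to zero.
bool hasDistinctIndices(const std::vector<unsigned> &Indices) {
  unsigned Visited = 0;
  for (unsigned Index : Indices) {
    if (Index >= 2)
      return false; // deinterleave2 only has struct fields 0 and 1.
    Visited ^= Index + 1;
    if (Visited == 0)
      return false; // The same index was extracted twice.
  }
  return true;
}
```

The patch also requires exactly two users (`hasNUses(2)`), so the sketch only needs to handle index sequences of length two to match that use.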

---

Patch is 89.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120490.diff


7 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+42) 
- (modified) llvm/lib/CodeGen/InterleavedAccessPass.cpp (+283) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+410) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+12) 
- (added) llvm/test/CodeGen/RISCV/rvv/scalable-vectors-interleaved-access.ll (+512) 
- (added) llvm/test/CodeGen/RISCV/rvv/scalable-vectors-strided-interleave-load-32.ll (+161) 
- (added) llvm/test/CodeGen/RISCV/rvv/scalable-vectors-strided-interleave-load-64.ll (+171) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 3751aac4df8ead..823f8aa8c9a7ef 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -94,6 +94,7 @@ class TargetRegisterClass;
 class TargetRegisterInfo;
 class TargetTransformInfo;
 class Value;
+class VPIntrinsic;
 
 namespace Sched {
 
@@ -3152,6 +3153,47 @@ class TargetLoweringBase {
     return false;
   }
 
+  /// Lower an interleaved load to target specific intrinsics. Return
+  /// true on success.
+  ///
+  /// \p Load is a vp.load instruction.
+  /// \p Mask is a mask value
+  /// \p DeinterleaveIntrin is vector.deinterleave intrinsic
+  /// \p Factor is the interleave factor.
+  /// \p DeinterleaveRes is a list of deinterleaved results.
+  virtual bool lowerInterleavedScalableLoad(
+      VPIntrinsic *Load, Value *Mask, IntrinsicInst *DeinterleaveIntrin,
+      unsigned Factor, ArrayRef<Value *> DeinterleaveRes) const {
+    return false;
+  }
+
+  /// Lower an interleaved store to target specific intrinsics. Return
+  /// true on success.
+  ///
+  /// \p Store is the vp.store instruction.
+  /// \p Mask is a mask value
+  /// \p InterleaveIntrin is vector.interleave intrinsic
+  /// \p Factor is the interleave factor.
+  /// \p InterleaveOps is a list of values being interleaved.
+  virtual bool lowerInterleavedScalableStore(
+      VPIntrinsic *Store, Value *Mask, IntrinsicInst *InterleaveIntrin,
+      unsigned Factor, ArrayRef<Value *> InterleaveOps) const {
+    return false;
+  }
+
+  /// Lower a deinterleave intrinsic to a target specific strided load
+  /// intrinsic. Return true on success.
+  ///
+  /// \p StridedLoad is the vp.strided.load instruction.
+  /// \p DI is the deinterleave intrinsic.
+  /// \p Factor is the interleave factor.
+  /// \p DeinterleaveRes is a list of deinterleaved results.
+  virtual bool lowerDeinterleaveIntrinsicToStridedLoad(
+      VPIntrinsic *StridedLoad, IntrinsicInst *DI, unsigned Factor,
+      ArrayRef<Value *> DeinterleaveRes) const {
+    return false;
+  }
+
   /// Lower a deinterleave intrinsic to a target specific load intrinsic.
   /// Return true on success. Currently only supports
   /// llvm.vector.deinterleave2
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index 8b6e3180986c30..0f3b65b8d9af2f 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -60,6 +60,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Casting.h"
@@ -248,6 +249,186 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
   return false;
 }
 
+// For an (de)interleave tree like this:
+//
+//   A   C B   D
+//   |___| |___|
+//     |_____|
+//        |
+//     A B C D
+//
+//  We get ABCD at the end, while the leaf operands/results
+//  are ACBD, which is also the order in which we initially collect them in
+//  getVectorInterleaveFactor / getVectorDeinterleaveFactor. But TLI
+//  hooks (e.g. lowerInterleavedScalableLoad) expect ABCD, so we need
+//  to reorder them by interleaving these values.
+static void interleaveLeafValues(SmallVectorImpl<Value *> &Leaves) {
+  unsigned Factor = Leaves.size();
+  assert(isPowerOf2_32(Factor) && Factor <= 8 && Factor > 1);
+
+  if (Factor == 2)
+    return;
+
+  SmallVector<Value *, 8> Buffer;
+  if (Factor == 4) {
+    for (unsigned SrcIdx : {0, 2, 1, 3})
+      Buffer.push_back(Leaves[SrcIdx]);
+  } else {
+    // Factor of 8.
+    //
+    //  A E C G B F D H
+    //  |_| |_| |_| |_|
+    //   |___|   |___|
+    //     |_______|
+    //         |
+    //  A B C D E F G H
+    for (unsigned SrcIdx : {0, 4, 2, 6, 1, 5, 3, 7})
+      Buffer.push_back(Leaves[SrcIdx]);
+  }
+
+  llvm::copy(Buffer, Leaves.begin());
+}
+
+static unsigned getVectorInterleaveFactor(IntrinsicInst *II,
+                                          SmallVectorImpl<Value *> &Operands) {
+  if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
+    return 0;
+
+  unsigned Factor = 0;
+
+  // Visit with BFS
+  SmallVector<IntrinsicInst *, 8> Queue;
+  Queue.push_back(II);
+  while (!Queue.empty()) {
+    IntrinsicInst *Current = Queue.front();
+    Queue.erase(Queue.begin());
+
+    for (unsigned I = 0; I < 2; ++I) {
+      Value *Op = Current->getOperand(I);
+      if (auto *OpII = dyn_cast<IntrinsicInst>(Op))
+        if (OpII->getIntrinsicID() == Intrinsic::vector_interleave2) {
+          Queue.push_back(OpII);
+          continue;
+        }
+
+      ++Factor;
+      Operands.push_back(Op);
+    }
+  }
+
+  // Currently we only recognize power-of-two factors.
+  // FIXME: should we assert here instead?
+  if (Factor > 1 && isPowerOf2_32(Factor)) {
+    interleaveLeafValues(Operands);
+    return Factor;
+  }
+  return 0;
+}
+
+/// Check the interleaved mask
+///
+/// - if the value within the optional is non-nullptr, it is the
+///   deinterleaved mask
+/// - if the value within the optional is nullptr, it corresponds to an
+///   all-true mask
+/// - return nullopt if the mask cannot be deinterleaved
+static std::optional<Value *> getMask(Value *WideMask, unsigned Factor) {
+  using namespace llvm::PatternMatch;
+  if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
+    SmallVector<Value *, 8> Operands;
+    if (unsigned MaskFactor = getVectorInterleaveFactor(IMI, Operands)) {
+      assert(!Operands.empty());
+      // All per-field masks must be identical to share a single mask.
+      if (MaskFactor == Factor && llvm::all_equal(Operands))
+        return Operands.front();
+    }
+  }
+  if (match(WideMask, m_AllOnes()))
+    return nullptr;
+  return std::nullopt;
+}
+
+static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
+                                            SmallVectorImpl<Value *> &Results) {
+  using namespace PatternMatch;
+  if (II->getIntrinsicID() != Intrinsic::vector_deinterleave2 ||
+      !II->hasNUses(2))
+    return 0;
+
+  unsigned Factor = 0;
+
+  // Visit with BFS
+  SmallVector<IntrinsicInst *, 8> Queue;
+  Queue.push_back(II);
+  while (!Queue.empty()) {
+    IntrinsicInst *Current = Queue.front();
+    Queue.erase(Queue.begin());
+    assert(Current->hasNUses(2));
+
+    unsigned VisitedIdx = 0;
+    for (User *Usr : Current->users()) {
+      // We play it safe here and only match expressions forming a
+      // perfectly balanced binary tree in which every intermediate
+      // value is used exactly once.
+      if (!Usr->hasOneUse() || !isa<ExtractValueInst>(Usr))
+        return 0;
+
+      auto *EV = cast<ExtractValueInst>(Usr);
+      ArrayRef<unsigned> Indices = EV->getIndices();
+      if (Indices.size() != 1 || Indices[0] >= 2)
+        return 0;
+
+      // The idea is that we don't want two extractvalues on the
+      // same index. So we XOR (index + 1) into VisitedIdx such
+      // that if there is any duplication, VisitedIdx will be
+      // zero.
+      VisitedIdx ^= Indices[0] + 1;
+      if (!VisitedIdx)
+        return 0;
+      // We have a legal index. At this point we're either going
+      // to continue the traversal or push the leaf values into Results.
+      // But in either case we need to follow the order imposed by
+      // ExtractValue's indices and swap with the last element pushed
+      // into Queue/Results if necessary (this is also one of the main
+      // reasons for using BFS instead of DFS here).
+
+      // When VisitedIdx equals 0b11, we're the last visited ExtractValue.
+      // So if the current index is 0, we need to swap. Conversely, when
+      // we're either the first visited ExtractValue or the last operand
+      // in Queue/Results has index 0, there is no need to swap.
+      bool SwapWithLast = VisitedIdx == 0b11 && Indices[0] == 0;
+
+      // Continue the traversal.
+      if (match(EV->user_back(),
+                m_Intrinsic<Intrinsic::vector_deinterleave2>()) &&
+          EV->user_back()->hasNUses(2)) {
+        auto *EVUsr = cast<IntrinsicInst>(EV->user_back());
+        if (SwapWithLast)
+          Queue.insert(Queue.end() - 1, EVUsr);
+        else
+          Queue.push_back(EVUsr);
+        continue;
+      }
+
+      // Save the leaf value.
+      if (SwapWithLast)
+        Results.insert(Results.end() - 1, EV);
+      else
+        Results.push_back(EV);
+
+      ++Factor;
+    }
+  }
+
+  // Currently we only recognize power-of-two factors.
+  // FIXME: should we assert here instead?
+  if (Factor > 1 && isPowerOf2_32(Factor)) {
+    interleaveLeafValues(Results);
+    return Factor;
+  }
+  return 0;
+}
+
 bool InterleavedAccessImpl::lowerInterleavedLoad(
     LoadInst *LI, SmallVectorImpl<Instruction *> &DeadInsts) {
   if (!LI->isSimple() || isa<ScalableVectorType>(LI->getType()))
@@ -480,6 +661,81 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
 
 bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
     IntrinsicInst *DI, SmallVectorImpl<Instruction *> &DeadInsts) {
+  using namespace PatternMatch;
+  SmallVector<Value *, 8> DeInterleaveResults;
+  unsigned Factor = getVectorDeInterleaveFactor(DI, DeInterleaveResults);
+
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(DI->getOperand(0));
+      Factor && VPLoad) {
+    if (!match(VPLoad, m_OneUse(m_Intrinsic<Intrinsic::vp_load>())))
+      return false;
+
+    // Check mask operand. Handle both all-true and interleaved mask.
+    Value *WideMask = VPLoad->getOperand(1);
+    std::optional<Value *> Mask = getMask(WideMask, Factor);
+    if (!Mask)
+      return false;
+
+    LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI << "\n");
+
+    // Since lowerInterleavedLoad expects shuffles and a LoadInst, use a
+    // special TLI hook to emit the target-specific interleaved instruction.
+    if (!TLI->lowerInterleavedScalableLoad(VPLoad, *Mask, DI, Factor,
+                                           DeInterleaveResults))
+      return false;
+
+    DeadInsts.push_back(DI);
+    DeadInsts.push_back(VPLoad);
+    return true;
+  }
+
+  // Match
+  //   %x = vp.strided.load  ;; VPStridedLoad
+  //   %y = bitcast %x       ;; BitCast
+  //   %y' = inttoptr %y     ;; IntToPtrCast (optional)
+  //   %z = deinterleave %y' ;; DI (takes %y if there is no inttoptr)
+  if (Factor && isa<BitCastInst, IntToPtrInst>(DI->getOperand(0))) {
+    auto *BitCast = cast<Instruction>(DI->getOperand(0));
+    if (!BitCast->hasOneUse())
+      return false;
+
+    Instruction *IntToPtrCast = nullptr;
+    if (auto *BC = dyn_cast<BitCastInst>(BitCast->getOperand(0))) {
+      IntToPtrCast = BitCast;
+      BitCast = BC;
+    }
+
+    // Check that the types have the form
+    //   <VF x (factor * elementTy)> bitcast to <(VF * factor) x elementTy>
+    Value *BitCastSrc = BitCast->getOperand(0);
+    auto *BitCastSrcTy = dyn_cast<VectorType>(BitCastSrc->getType());
+    auto *BitCastDstTy = cast<VectorType>(BitCast->getType());
+    if (!BitCastSrcTy || (BitCastSrcTy->getElementCount() * Factor !=
+                          BitCastDstTy->getElementCount()))
+      return false;
+
+    if (auto *VPStridedLoad = dyn_cast<VPIntrinsic>(BitCast->getOperand(0))) {
+      if (VPStridedLoad->getIntrinsicID() !=
+              Intrinsic::experimental_vp_strided_load ||
+          !VPStridedLoad->hasOneUse())
+        return false;
+
+      LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI
+                        << "\n");
+
+      if (!TLI->lowerDeinterleaveIntrinsicToStridedLoad(
+              VPStridedLoad, DI, Factor, DeInterleaveResults))
+        return false;
+
+      DeadInsts.push_back(DI);
+      if (IntToPtrCast)
+        DeadInsts.push_back(IntToPtrCast);
+      DeadInsts.push_back(BitCast);
+      DeadInsts.push_back(VPStridedLoad);
+      return true;
+    }
+  }
+
   LoadInst *LI = dyn_cast<LoadInst>(DI->getOperand(0));
 
   if (!LI || !LI->hasOneUse() || !LI->isSimple())
@@ -502,6 +758,33 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
   if (!II->hasOneUse())
     return false;
 
+  if (auto *VPStore = dyn_cast<VPIntrinsic>(*(II->users().begin()))) {
+    if (VPStore->getIntrinsicID() != Intrinsic::vp_store)
+      return false;
+
+    SmallVector<Value *, 8> InterleaveOperands;
+    unsigned Factor = getVectorInterleaveFactor(II, InterleaveOperands);
+    if (!Factor)
+      return false;
+
+    Value *WideMask = VPStore->getOperand(2);
+    std::optional<Value *> Mask = getMask(WideMask, Factor);
+    if (!Mask)
+      return false;
+
+    LLVM_DEBUG(dbgs() << "IA: Found an interleave intrinsic: " << *II << "\n");
+
+    // Since lowerInterleavedStore expects Shuffle and StoreInst, use special
+    // TLI function to emit target-specific interleaved instruction.
+    if (!TLI->lowerInterleavedScalableStore(VPStore, *Mask, II, Factor,
+                                            InterleaveOperands))
+      return false;
+
+    DeadInsts.push_back(VPStore);
+    DeadInsts.push_back(II);
+    return true;
+  }
+
   StoreInst *SI = dyn_cast<StoreInst>(*(II->users().begin()));
 
   if (!SI || !SI->isSimple())
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b703eb90e8ef30..2dafbf737512a9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -22190,6 +22190,416 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
   return true;
 }
 
+/// Lower an interleaved vp.load into a vlsegN intrinsic.
+///
+/// E.g. Lower an interleaved vp.load (Factor = 2):
+///   %l = call <vscale x 64 x i8> @llvm.vp.load.nxv64i8.p0(ptr %ptr,
+///                                                         %mask,
+///                                                         i32 %wide.rvl)
+///   %dl = tail call { <vscale x 32 x i8>, <vscale x 32 x i8> }
+///             @llvm.vector.deinterleave2.nxv64i8(
+///               <vscale x 64 x i8> %l)
+///   %r0 = extractvalue { <vscale x 32 x i8>, <vscale x 32 x i8> } %dl, 0
+///   %r1 = extractvalue { <vscale x 32 x i8>, <vscale x 32 x i8> } %dl, 1
+///
+/// Into:
+///   %rvl = udiv %wide.rvl, 2
+///   %sl = call { <vscale x 32 x i8>, <vscale x 32 x i8> }
+///             @llvm.riscv.vlseg2.mask.nxv32i8.i64(<vscale x 32 x i8> undef,
+///                                                 <vscale x 32 x i8> undef,
+///                                                 ptr %ptr,
+///                                                 %mask,
+///                                                 i64 %rvl,
+///                                                 i64 1)
+///   %r0 = extractvalue { <vscale x 32 x i8>, <vscale x 32 x i8> } %sl, 0
+///   %r1 = extractvalue { <vscale x 32 x i8>, <vscale x 32 x i8> } %sl, 1
+///
+/// NOTE: the deinterleave2 intrinsic won't be touched and is expected to be
+/// removed by the caller
+bool RISCVTargetLowering::lowerInterleavedScalableLoad(
+    VPIntrinsic *Load, Value *Mask, IntrinsicInst *DeinterleaveIntrin,
+    unsigned Factor, ArrayRef<Value *> DeInterleaveResults) const {
+  assert(Load->getIntrinsicID() == Intrinsic::vp_load &&
+         "Unexpected intrinsic");
+
+  auto *WideVTy = cast<VectorType>(Load->getType());
+  unsigned WideNumElements = WideVTy->getElementCount().getKnownMinValue();
+  assert(WideNumElements % Factor == 0 &&
+         "ElementCount of a wide load must be divisible by interleave factor");
+  auto *VTy =
+      VectorType::get(WideVTy->getScalarType(), WideNumElements / Factor,
+                      WideVTy->isScalableTy());
+  // FIXME: Should pass alignment attribute from pointer, but vectorizer needs
+  // to emit it first.
+  auto &DL = Load->getModule()->getDataLayout();
+  Align Alignment = Align(DL.getTypeStoreSize(WideVTy->getScalarType()));
+  if (!isLegalInterleavedAccessType(
+          VTy, Factor, Alignment,
+          Load->getArgOperand(0)->getType()->getPointerAddressSpace(), DL))
+    return false;
+
+  IRBuilder<> Builder(Load);
+  Value *WideEVL = Load->getArgOperand(2);
+  auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());
+  Value *EVL = Builder.CreateZExtOrTrunc(
+      Builder.CreateUDiv(WideEVL, ConstantInt::get(WideEVL->getType(), Factor)),
+      XLenTy);
+
+  static const Intrinsic::ID IntrMaskIds[] = {
+      Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
+      Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
+      Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
+      Intrinsic::riscv_vlseg8_mask,
+  };
+  static const Intrinsic::ID IntrIds[] = {
+      Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3, Intrinsic::riscv_vlseg4,
+      Intrinsic::riscv_vlseg5, Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
+      Intrinsic::riscv_vlseg8,
+  };
+
+  unsigned SEW = DL.getTypeSizeInBits(VTy->getElementType());
+  unsigned NumElts = VTy->getElementCount().getKnownMinValue();
+  Type *VecTupTy = TargetExtType::get(
+      Load->getContext(), "riscv.vector.tuple",
+      ScalableVectorType::get(Type::getInt8Ty(Load->getContext()),
+                              NumElts * SEW / 8),
+      Factor);
+
+  Value *PoisonVal = PoisonValue::get(VecTupTy);
+  SmallVector<Value *> Operands;
+  Operands.append({PoisonVal, Load->getArgOperand(0)});
+
+  Function *VlsegNFunc;
+  if (Mask) {
+    VlsegNFunc = Intrinsic::getOrInsertDeclaration(
+        Load->getModule(), IntrMaskIds[Factor - 2],
+        {VecTupTy, Mask->getType(), EVL->getType()});
+    Operands.push_back(Mask);
+  } else {
+    VlsegNFunc = Intrinsic::getOrInsertDeclaration(
+        Load->getModule(), IntrIds[Factor - 2], {VecTupTy, EVL->getType()});
+  }
+
+  Operands.push_back(EVL);
+
+  // Tail-policy
+  if (Mask)
+    Operands.push_back(ConstantInt::get(XLenTy, 1));
+
+  Operands.push_back(ConstantInt::get(XLenTy, Log2_64(SEW)));
+
+  CallInst *VlsegN = Builder.CreateCall(VlsegNFunc, Operands);
+
+  SmallVector<Type *, 8> AggrTypes{Factor, VTy};
+  Value *Return =
+      PoisonValue::get(StructType::get(Load->getContext(), AggrTypes));
+  Function *VecExtractFunc = Intrinsic::getOrInsertDeclaration(
+      Load->getModule(), Intrinsic::riscv_tuple_extract, {VTy, VecTupTy});
+  for (unsigned i = 0; i < Factor; ++i) {
+    Value *VecExtract =
+        Builder.CreateCall(VecExtractFunc, {VlsegN, Builder.getInt32(i)});
+    Return = Builder.CreateInsertValue(Return, VecExtract, i);
+  }
+
+  for (auto [Idx, DIO] : enumerate(DeInterleaveResults)) {
+    // We have to create a brand new ExtractValue to replace each
+    // of these old ExtractValue instructions.
+    Value *NewEV =
+        Builder.CreateExtractValue(Return, {static_cast<unsigned>(Idx)});
+    DIO->replaceAllUsesWith(NewEV);
+  }
+  DeinterleaveIntrin->replaceAllUsesWith(
+      UndefValue::get(DeinterleaveIntrin->getType()));
+
+  return true;
+}
+
+/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
+/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can first
+/// create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>`
+/// and then cast it into `<vscale x 16 x i32>`. This results in a simple
+/// unit-stride store rather than a segment store, which is more expensive
+/// in this case.
+static Value *foldInterleaved2OfConstSplats(IntrinsicInst *InterleaveIntrin,
+                                            VectorType *VTy,
+                                            const TargetLowering *TLI,
+                                            Instruction *VPStore) {
+  // We only handle Factor = 2 for now.
+  assert(InterleaveIntrin->arg_size() == 2);
+  auto *SplatVal0 = dyn_cast_or_null<ConstantInt>(
+      getSplatValue(InterleaveIntrin->getArgOperand(0)));
+  auto *SplatVal1 = dyn_cast_or_null<ConstantInt>(
+      getSplatValue(InterleaveIntrin->getArgOperand(1)));
+  if (!SplatVal0 || !SplatVal1)
+    return nullptr;
+
+  auto &Ctx = VPStore->getContext();
+  auto &DL = VPStore->getModule()->getDataLayout();
+
+  auto *NewVTy = VectorType::getExtendedElementVectorType(VTy);
+  if (!TLI->isTypeLegal(TLI->getValueType(DL, NewVTy)))
+    return nullptr;
+
+  // InterleavedAccessPass will remove VPStore after this but we still want to
+  // preserve it, hence clone another one here.
+  auto *ClonedVPStore = VPStore->clone();
+  ClonedVPStore->insertBefore(VPStore);
+  IRBuilder<> Builder(ClonedVPStore);
+
+  ...
[truncated]

``````````
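In the `lowerInterleavedScalableLoad` hunk above, the wide EVL of the `vp.load` is divided by `Factor` to form the EVL of the `vlsegN` call. The scalar sketch below shows why that is equivalent. It assumes an unmasked access, `int` elements, and a wide EVL that is an exact multiple of `Factor`; the patch itself simply emits a `udiv`. Both helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of vlsegN: reads EVL segments of Factor consecutive
// elements starting at Ptr; field F receives element F of every segment.
std::vector<std::vector<int>> segmentLoad(const int *Ptr, unsigned Factor,
                                          std::size_t EVL) {
  std::vector<std::vector<int>> Fields(Factor);
  for (std::size_t Seg = 0; Seg < EVL; ++Seg)
    for (unsigned F = 0; F < Factor; ++F)
      Fields[F].push_back(Ptr[Seg * Factor + F]);
  return Fields;
}

// Reference behaviour: an unmasked vp.load of WideEVL elements followed by a
// Factor-way deinterleave, where lane I is routed to field I % Factor.
std::vector<std::vector<int>> wideLoadThenDeinterleave(const int *Ptr,
                                                       unsigned Factor,
                                                       std::size_t WideEVL) {
  // Assumption of this sketch: the wide EVL divides evenly by Factor.
  assert(WideEVL % Factor == 0);
  std::vector<std::vector<int>> Fields(Factor);
  for (std::size_t I = 0; I < WideEVL; ++I)
    Fields[I % Factor].push_back(Ptr[I]);
  return Fields;
}
```

Both functions visit the same addresses, `Seg * Factor + F`, for every field, so a segment load with `EVL = WideEVL / Factor` reproduces the deinterleaved wide load.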

</details>


https://github.com/llvm/llvm-project/pull/120490

