[llvm] [LoopVectorize] Vectorize the compact pattern (PR #68980)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 13 04:51:49 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (huhu233)

<details>
<summary>Changes</summary>

This patch tries to vectorize the compact pattern, as shown:
    
for(i=0; i<N; i++){
   x = comp[i];
   if(x<a) Out_ref[n++]=B[i];
}
    
It introduces some changes:
1. Add pattern matching in LoopVectorizationLegality to detect and cache compact-pattern candidates.
2. Introduce two new recipes to handle the compact chain:
    VPCompactPHIRecipe: Handle the entry PHI of compact chain.
    VPWidenCompactInstructionRecipe: Handle other instructions in compact chain.
3. Slightly adapt the cost model for compact pattern.

---

Patch is 56.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68980.diff


20 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+25) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+7) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+4) 
- (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+8) 
- (modified) llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h (+35) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+16) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+12) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+10) 
- (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+29) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+157) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+111-5) 
- (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+5) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+17) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+62) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+135) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+2) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/compact-vplan.ll (+78) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/compact.ll (+153) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..c2851c10e6ff3ef 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1672,6 +1672,11 @@ class TargetTransformInfo {
   /// \return The maximum number of function arguments the target supports.
   unsigned getMaxNumArgs() const;
 
+  InstructionCost getCompactCost() const;
+  bool isTargetSupportedCompactStore() const;
+  unsigned getTargetSupportedCompact() const;
+  unsigned getTargetSupportedCNTP() const;
+
   /// @}
 
 private:
@@ -2041,6 +2046,10 @@ class TargetTransformInfo::Concept {
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
   virtual unsigned getMaxNumArgs() const = 0;
+  virtual bool isTargetSupportedCompactStore() const = 0;
+  virtual unsigned getTargetSupportedCompact() const = 0;
+  virtual unsigned getTargetSupportedCNTP() const = 0;
+  virtual InstructionCost getCompactCost() const = 0;
 };
 
 template <typename T>
@@ -2757,6 +2766,22 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxNumArgs() const override {
     return Impl.getMaxNumArgs();
   }
+
+  bool isTargetSupportedCompactStore() const override {
+    return Impl.isTargetSupportedCompactStore();
+  }
+
+  unsigned getTargetSupportedCompact() const override {
+    return Impl.getTargetSupportedCompact();
+  }
+
+  unsigned getTargetSupportedCNTP() const override {
+    return Impl.getTargetSupportedCNTP();
+  }
+
+  InstructionCost getCompactCost() const override {
+    return Impl.getCompactCost();
+  }
 };
 
 template <typename T>
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..e063f383980a724 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -895,6 +895,13 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxNumArgs() const { return UINT_MAX; }
 
+  bool isTargetSupportedCompactStore() const { return false; }
+  unsigned getTargetSupportedCompact() const { return 0; }
+  unsigned getTargetSupportedCNTP() const { return 0; }
+  InstructionCost getCompactCost() const {
+    return InstructionCost::getInvalid();
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 3dd16dafe3c42a7..737757761eca4ab 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -700,6 +700,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return getST()->getMaxPrefetchIterationsAhead();
   }
 
+  virtual InstructionCost getCompactCost() const {
+    return InstructionCost::getInvalid();
+  }
+
   virtual bool enableWritePrefetching() const {
     return getST()->enableWritePrefetching();
   }
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 0d99249be413762..348b8ad03de4179 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -409,6 +409,14 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
 Value *createTargetReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
                              Value *Src, PHINode *OrigPhi = nullptr);
 
+Value *createTargetCompact(IRBuilderBase &B, Module *M,
+                           const TargetTransformInfo *TTI, Value *Mask,
+                           Value *Val);
+
+Value *createTargetCNTP(IRBuilderBase &B, Module *M,
+                        const TargetTransformInfo *TTI, Value *Mask,
+                        Value *Val);
+
 /// Create an ordered reduction intrinsic using the given recurrence
 /// descriptor \p Desc.
 Value *createOrderedReduction(IRBuilderBase &B,
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 20cfc680e8f90b3..7f82154699e5174 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -224,6 +224,26 @@ class LoopVectorizationRequirements {
   Instruction *ExactFPMathInst = nullptr;
 };
 
+class CompactDescriptor {
+  PHINode *LiveOutPhi;
+  bool IsCompactSign;
+  SmallPtrSet<Value *, 8> Chain;
+
+public:
+  CompactDescriptor() = default;
+  CompactDescriptor(SmallPtrSetImpl<Value *> &CompactChain, PHINode *LiveOut,
+                    bool IsSign)
+      : LiveOutPhi(LiveOut), IsCompactSign(IsSign) {
+    Chain.insert(CompactChain.begin(), CompactChain.end());
+  }
+
+  bool isInCompactChain(Value *V) const { return Chain.find(V) != Chain.end(); }
+
+  PHINode *getLiveOutPhi() const { return LiveOutPhi; }
+
+  bool isSign() const { return IsCompactSign; }
+};
+
 /// LoopVectorizationLegality checks if it is legal to vectorize a loop, and
 /// to what vectorization factor.
 /// This class does not look at the profitability of vectorization, only the
@@ -261,6 +281,8 @@ class LoopVectorizationLegality {
   /// inductions and reductions.
   using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
 
+  using CompactList = MapVector<PHINode *, CompactDescriptor>;
+
   /// Returns true if it is legal to vectorize this loop.
   /// This does not mean that it is profitable to vectorize this
   /// loop, only that it is legal to do so.
@@ -397,6 +419,14 @@ class LoopVectorizationLegality {
 
   DominatorTree *getDominatorTree() const { return DT; }
 
+  const CompactList &getCompactList() const { return CpList; }
+
+  bool hasCompactChain() const { return CpList.size() > 0; }
+
+  PHINode *getCompactChainStart(Instruction *I) const;
+
+  bool isSign(PHINode *Phi) { return CpList[Phi].isSign(); };
+
 private:
   /// Return true if the pre-header, exiting and latch blocks of \p Lp and all
   /// its nested loops are considered legal for vectorization. These legal
@@ -425,6 +455,8 @@ class LoopVectorizationLegality {
   /// and we only need to check individual instructions.
   bool canVectorizeInstrs();
 
+  bool isMatchCompact(PHINode *Phi, Loop *TheLoop, CompactDescriptor &CpDesc);
+
   /// When we vectorize loops we may change the order in which
   /// we read and write from memory. This method checks if it is
   /// legal to vectorize the code, considering only memory constrains.
@@ -538,6 +570,9 @@ class LoopVectorizationLegality {
   /// BFI and PSI are used to check for profile guided size optimizations.
   BlockFrequencyInfo *BFI;
   ProfileSummaryInfo *PSI;
+
+  // Record compact chain in the loop.
+  CompactList CpList;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..b7596bb2e0dfc92 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1248,6 +1248,22 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::isTargetSupportedCompactStore() const {
+  return TTIImpl->isTargetSupportedCompactStore();
+}
+
+unsigned TargetTransformInfo::getTargetSupportedCompact() const {
+  return TTIImpl->getTargetSupportedCompact();
+}
+
+unsigned TargetTransformInfo::getTargetSupportedCNTP() const {
+  return TTIImpl->getTargetSupportedCNTP();
+}
+
+InstructionCost TargetTransformInfo::getCompactCost() const {
+  return TTIImpl->getCompactCost();
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fc9e3ff3734989d..d30d0b57b5d47b0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/StackMaps.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/raw_ostream.h"
@@ -301,6 +302,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FFREXP:
     Res = PromoteIntRes_FFREXP(N);
     break;
+  case ISD::INTRINSIC_WO_CHAIN:
+    if (N->getConstantOperandVal(0) == Intrinsic::aarch64_sve_compact) {
+      Res = PromoteIntRes_COMPACT(N);
+      break;
+    }
   }
 
   // If the result is null then the sub-method took care of registering it.
@@ -5942,6 +5948,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_CONCAT_VECTORS(SDNode *N) {
   return DAG.getBuildVector(N->getValueType(0), dl, NewOps);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_COMPACT(SDNode *N) {
+  SDValue OpExt = SExtOrZExtPromotedInteger(N->getOperand(2));
+  return DAG.getNode(N->getOpcode(), SDLoc(N), OpExt.getValueType(),
+                     N->getOperand(0), N->getOperand(1), OpExt);
+}
+
 SDValue DAGTypeLegalizer::ExpandIntOp_STACKMAP(SDNode *N, unsigned OpNo) {
   assert(OpNo > 1);
   SDValue Op = N->getOperand(OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index c802604a3470e13..d204169ed2327f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -364,6 +364,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_FunnelShift(SDNode *N);
   SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
   SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
+  SDValue PromoteIntRes_COMPACT(SDNode *N);
 
   // Integer Operand Promotion.
   bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d8a0e68d7123759..5ca5f22525d3dd3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3889,3 +3889,5 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
     return AM.Scale != 0 && AM.Scale != 1;
   return -1;
 }
+
+InstructionCost AArch64TTIImpl::getCompactCost() const { return 6; }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..28bd48e8ed76c4a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include <cstdint>
 #include <optional>
 
@@ -412,6 +413,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
     return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
   }
+
+  bool isTargetSupportedCompactStore() const { return ST->hasSVE(); }
+  unsigned getTargetSupportedCompact() const {
+    return Intrinsic::aarch64_sve_compact;
+  }
+  unsigned getTargetSupportedCNTP() const {
+    return Intrinsic::aarch64_sve_cntp;
+  }
+  InstructionCost getCompactCost() const override;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 21affe7bdce406e..1373fb7931f0a7e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
@@ -1119,6 +1120,34 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
   return createSimpleTargetReduction(B, Src, RK);
 }
 
+Value *llvm::createTargetCompact(IRBuilderBase &B, Module *M,
+                                 const TargetTransformInfo *TTI, Value *Mask,
+                                 Value *Val) {
+  Intrinsic::ID IID = TTI->getTargetSupportedCompact();
+  switch (IID) {
+  default:
+    return nullptr;
+  case Intrinsic::aarch64_sve_compact:
+    Function *CompactMaskDecl = Intrinsic::getDeclaration(
+        M, Intrinsic::aarch64_sve_compact, Val->getType());
+    return B.CreateCall(CompactMaskDecl, {Mask, Val});
+  }
+}
+
+Value *llvm::createTargetCNTP(IRBuilderBase &B, Module *M,
+                              const TargetTransformInfo *TTI, Value *Mask,
+                              Value *Val) {
+  Intrinsic::ID IID = TTI->getTargetSupportedCNTP();
+  switch (IID) {
+  default:
+    return nullptr;
+  case Intrinsic::aarch64_sve_cntp:
+    Function *CNTPDecl = Intrinsic::getDeclaration(
+        M, Intrinsic::aarch64_sve_cntp, Val->getType());
+    return B.CreateCall(CNTPDecl, {Mask, Val});
+  }
+}
+
 Value *llvm::createOrderedReduction(IRBuilderBase &B,
                                     const RecurrenceDescriptor &Desc,
                                     Value *Src, Value *Start) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 35d69df56dc7220..dbab8af159a1621 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Utils/SizeOpts.h"
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 
@@ -78,6 +79,11 @@ static cl::opt<LoopVectorizeHints::ScalableForceKind>
                 "Scalable vectorization is available and favored when the "
                 "cost is inconclusive.")));
 
+static cl::opt<bool>
+    EnableCompactVectorization("enable-compact-vectorization", cl::init(true),
+                               cl::Hidden,
+                               cl::desc("Enable vectorizing compact pattern."));
+
 /// Maximum vectorization interleave count.
 static const unsigned MaxInterleaveFactor = 16;
 
@@ -785,6 +791,143 @@ static bool isTLIScalarize(const TargetLibraryInfo &TLI, const CallInst &CI) {
   return Scalarize;
 }
 
+static bool isUserOfCompactPHI(BasicBlock *BB, PHINode *Phi, Instruction *I) {
+  if (I->getParent() != BB)
+    return false;
+
+  // Operations on PHI should be affine.
+  if (I->getOpcode() != Instruction::Add &&
+      I->getOpcode() != Instruction::Sub &&
+      I->getOpcode() != Instruction::SExt &&
+      I->getOpcode() != Instruction::ZExt)
+    return false;
+
+  if (I == Phi)
+    return true;
+
+  for (unsigned i = 0; i < I->getNumOperands(); i++) {
+    if (auto *Instr = dyn_cast<Instruction>(I->getOperand(i)))
+      if (isUserOfCompactPHI(BB, Phi, Instr))
+        return true;
+  }
+  return false;
+}
+
+// Match the basic compact pattern:
+// for.body:
+//    %src.phi = phi i64 [ 0, %preheader ], [ %target.phi, %for.inc ]
+//    ...
+// if.then:
+//    ...
+//    %data = load i32, ptr %In
+//    (there may be additional sext/zext if %src.phi types i32)
+//    %addr = getelementptr i32, ptr %Out, i64 %src.phi
+//    store i32 %data, ptr %addr
+//    %inc = add i64 %src.phi, 1
+// for.inc
+//    %target.phi = phi i64 [ %inc, if.then ], [ %src.phi, %for.body ]
+bool LoopVectorizationLegality::isMatchCompact(PHINode *Phi, Loop *TheLoop,
+                                               CompactDescriptor &CpDesc) {
+  if (Phi->getNumIncomingValues() > 2)
+    return false;
+
+  // Don't support phis who is used as mask.
+  for (User *U : Phi->users()) {
+    if (isa<CmpInst>(U))
+      return false;
+  }
+
+  SmallPtrSet<Value *, 8> CompactChain;
+  CompactChain.insert(Phi);
+
+  BasicBlock *LoopPreHeader = TheLoop->getLoopPreheader();
+  int ExitIndex = Phi->getIncomingBlock(0) == LoopPreHeader ? 1 : 0;
+  BasicBlock *ExitBlock = Phi->getIncomingBlock(ExitIndex);
+  PHINode *CompactLiveOut = nullptr;
+  Value *IncValue = nullptr;
+  BasicBlock *IncBlock = nullptr;
+  bool IsCycle = false;
+  for (auto &CandPhi : ExitBlock->phis()) {
+    if (llvm::is_contained(CandPhi.incoming_values(), Phi) &&
+        CandPhi.getNumIncomingValues() == 2) {
+      IsCycle = true;
+      CompactLiveOut = &CandPhi;
+      int IncIndex = CandPhi.getIncomingBlock(0) == Phi->getParent() ? 1 : 0;
+      IncBlock = CandPhi.getIncomingBlock(IncIndex);
+      IncValue = CandPhi.getIncomingValueForBlock(IncBlock);
+      break;
+    }
+  }
+  // Similar with reduction PHI.
+  if (!IsCycle)
+    return false;
+  CompactChain.insert(CompactLiveOut);
+
+  // Match the pattern %inc = add i32 %src.phi, 1.
+  Value *Index = nullptr, *Step = nullptr;
+  if (!match(IncValue, m_Add(m_Value(Index), m_Value(Step))))
+    return false;
+  if (Index != Phi) {
+    std::swap(Index, Step);
+  }
+  if (Step != ConstantInt::get(Step->getType(), 1))
+    return false;
+  CompactChain.insert(IncValue);
+
+  const DataLayout &DL = Phi->getModule()->getDataLayout();
+  int CntCandStores = 0;
+  GetElementPtrInst *GEP = nullptr;
+  for (auto &Inst : *IncBlock) {
+    if (auto *SI = dyn_cast<StoreInst>(&Inst)) {
+      // TODO: Support llvm.sve.compact.nxv8i16, llvm.sve.compact.nxv16i18 in
+      // the future.
+      unsigned TySize = DL.getTypeSizeInBits(SI->getValueOperand()->getType());
+      if (TySize < 32)
+        return false;
+
+      GEP = dyn_cast<GetElementPtrInst>(SI->getPointerOperand());
+      if (GEP == nullptr)
+        continue;
+
+      // Only handle single pointer.
+      if (GEP->getNumOperands() != 2)
+        continue;
+
+      // Get the index of GEP, index could be phi or sext/zext (if phi types
+      // i32).
+      Value *Op1 = GEP->getOperand(1);
+      Value *X = nullptr;
+      SmallSet<Value *, 16> CandiInstrs;
+      if (match(Op1, m_SExt(m_Value(X))) || match(Op1, m_ZExt(m_Value(X)))) {
+        Op1 = X;
+      }
+      Instruction *Op1Instr = dyn_cast<Instruction>(Op1);
+      if (!Op1Instr || isUserOfCompactPHI(IncBlock, Phi, Op1Instr))
+        continue;
+      CompactChain.insert(GEP);
+      CompactChain.insert(SI);
+      CntCandStores++;
+    }
+  }
+  if (!CntCandStores)
+    return false;
+
+  KnownBits Bits = computeKnownBits(Phi, DL);
+  bool IsSign = !Bits.isNonNegative();
+  CompactDescriptor CompactDesc(CompactChain, CompactLiveOut, IsSign);
+  CpDesc = CompactDesc;
+  LLVM_DEBUG(dbgs() << "LV: Found a compact chain.\n");
+  return true;
+}
+
+PHINode *LoopVectorizationLegality::getCompactChainStart(Instruction *I) const {
+  for (auto &CpDesc : CpList) {
+    if (CpDesc.second.isInCompactChain(I))
+      return CpDesc.first;
+  }
+  return nullptr;
+}
+
 bool LoopVectorizationLegality::canVectorizeInstrs() {
   BasicBlock *Header = TheLoop->getHeader();
 
@@ -881,6 +1024,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
           continue;
         }
 
+        CompactDescriptor CpDesc;
+        if (EnableCompactVectorization &&
+            TTI->isTargetSupportedCompactStore() &&
+            isMatchCompact(Phi, TheLoop, CpDesc)) {
+          CpList[Phi] = CpDesc;
+          continue;
+        }
+
         reportVectorizationFailure("Found an unidentified PHI",
             "value that could not be identified as "
             "reduction is used outside the loop",
@@ -1525,16 +1676,22 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
   LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n");
 
   SmallPtrSet<const Value *, 8> ReductionLiveOuts;
+  SmallPtrSet<const Value *, 8> CompactLiveOuts;
 
   for (const auto &Reduction : getReductionVars())
     ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
 
+  for (const auto &Compact : getCompactList())
+    CompactLiveOuts.insert(Compact.second.getLiveOutPhi());
+
   // TODO: handle non-reduction outside users when tail is folded by masking.
   for (auto *AE : AllowedExit) {
    ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/68980


More information about the llvm-commits mailing list