[llvm] [LV] Split checking if tail-folding is possible, collecting masked ops. (PR #77612)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 8 07:27:03 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/77612

>From 9fa94573045383ccf06a123f8dba82242fb34d49 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 10 Jan 2024 14:17:28 +0000
Subject: [PATCH 1/2] [LV] Split checking if tail-folding is possible,
 collecting masked ops.

Introduce new canFoldTailByMasking helper which only checks if
tail-folding is possible, but without modifying MaskedOp.

Just because tail-folding is possible doesn't mean the tail will be
folded; that's up to the cost-model to decide. Separating the check for
whether tail-folding is possible from the preparation for tail-folding
makes sure that MaskedOp is only populated when tail-folding is actually
selected.

After https://github.com/llvm/llvm-project/pull/76635, this allows
creating the header mask only when it is actually needed.
---
 .../Vectorize/LoopVectorizationLegality.h     |  9 ++-
 .../Vectorize/LoopVectorizationLegality.cpp   | 20 ++++++-
 .../Transforms/Vectorize/LoopVectorize.cpp    |  5 +-
 ...ectorize-force-tail-with-evl-interleave.ll | 60 ++++++++++++-------
 4 files changed, 65 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index a509ebf6a7e1b..af23813da569b 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -276,9 +276,12 @@ class LoopVectorizationLegality {
   bool canVectorizeFPMath(bool EnableStrictReductions);
 
   /// Return true if we can vectorize this loop while folding its tail by
-  /// masking, and mark all respective loads/stores for masking.
-  /// This object's state is only modified iff this function returns true.
-  bool prepareToFoldTailByMasking();
+  /// masking.
+  bool canFoldTailByMasking() const;
+
+  /// Mark all respective loads/stores for masking. Must only be called when
+  /// ail-folding is possible.
+  void prepareToFoldTailByMasking();
 
   /// Returns the primary induction variable.
   PHINode *getPrimaryInduction() { return PrimaryInduction; }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index d306524ae5113..c99b3f63c6921 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -1543,7 +1543,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
   return Result;
 }
 
-bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
+bool LoopVectorizationLegality::canFoldTailByMasking() const {
 
   LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n");
 
@@ -1601,8 +1601,24 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
 
   LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n");
 
-  MaskedOp.insert(TmpMaskedOp.begin(), TmpMaskedOp.end());
   return true;
 }
 
+void LoopVectorizationLegality::prepareToFoldTailByMasking() {
+  // The list of pointers that we can safely read and write to remains empty.
+  SmallPtrSet<Value *, 8> SafePointers;
+
+  // Collect masked ops in temporary set first to avoid partially populating
+  // MaskedOp if a block cannot be predicated.
+  SmallPtrSet<const Instruction *, 8> TmpMaskedOp;
+
+  // Check and mark all blocks for predication, including those that ordinarily
+  // do not need predication such as the header block.
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    bool R = blockCanBePredicated(BB, SafePointers, MaskedOp);
+    (void)R;
+    assert(R && "Must be able to predicate block when tail-folding.");
+  }
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 417c577e3c5dc..208246f1a887e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1502,7 +1502,7 @@ class LoopVectorizationCostModel {
   /// \param UserIC User specific interleave count.
   void setTailFoldingStyles(bool IsScalableVF, unsigned UserIC) {
     assert(!ChosenTailFoldingStyle && "Tail folding must not be selected yet.");
-    if (!Legal->prepareToFoldTailByMasking()) {
+    if (!Legal->canFoldTailByMasking()) {
       ChosenTailFoldingStyle =
           std::make_pair(TailFoldingStyle::None, TailFoldingStyle::None);
       return;
@@ -7226,6 +7226,9 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
       CM.invalidateCostModelingDecisions();
   }
 
+  if (CM.foldTailByMasking())
+    Legal->prepareToFoldTailByMasking();
+
   ElementCount MaxUserVF =
       UserVF.isScalable() ? MaxFactors.ScalableVF : MaxFactors.FixedVF;
   bool UserVFIsLegal = ElementCount::isKnownLE(UserVF, MaxUserVF);
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-interleave.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-interleave.ll
index b4427864d4730..895c89b768acb 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-interleave.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-interleave.ll
@@ -96,35 +96,49 @@ define void @interleave(ptr noalias %a, ptr noalias %b, i64 %N) {
 ;
 ; NO-VP-LABEL: @interleave(
 ; NO-VP-NEXT:  entry:
-; NO-VP-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N:%.*]], 16
+; NO-VP-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; NO-VP-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 8
+; NO-VP-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N:%.*]], [[TMP1]]
 ; NO-VP-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; NO-VP:       vector.ph:
-; NO-VP-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[N]], 16
+; NO-VP-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; NO-VP-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 8
+; NO-VP-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[N]], [[TMP3]]
 ; NO-VP-NEXT:    [[N_VEC:%.*]] = sub i64 [[N]], [[N_MOD_VF]]
+; NO-VP-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; NO-VP-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
 ; NO-VP-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; NO-VP:       vector.body:
 ; NO-VP-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; NO-VP-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 0
-; NO-VP-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 8
-; NO-VP-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B:%.*]], i64 [[TMP10]], i32 0
-; NO-VP-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i64 [[TMP1]], i32 0
-; NO-VP-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
-; NO-VP-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[TMP3]], i32 0
-; NO-VP-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i32>, ptr [[TMP4]], align 4
-; NO-VP-NEXT:    [[WIDE_VEC1:%.*]] = load <16 x i32>, ptr [[TMP5]], align 4
-; NO-VP-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; NO-VP-NEXT:    [[STRIDED_VEC2:%.*]] = shufflevector <16 x i32> [[WIDE_VEC1]], <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
-; NO-VP-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-; NO-VP-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i32> [[WIDE_VEC1]], <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
-; NO-VP-NEXT:    [[TMP6:%.*]] = add nsw <8 x i32> [[STRIDED_VEC3]], [[STRIDED_VEC]]
-; NO-VP-NEXT:    [[TMP7:%.*]] = add nsw <8 x i32> [[STRIDED_VEC4]], [[STRIDED_VEC2]]
-; NO-VP-NEXT:    [[TMP24:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP10]]
-; NO-VP-NEXT:    [[TMP13:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP1]]
-; NO-VP-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[TMP24]], i32 0
-; NO-VP-NEXT:    [[TMP11:%.*]] = getelementptr inbounds i32, ptr [[TMP24]], i32 8
-; NO-VP-NEXT:    store <8 x i32> [[TMP6]], ptr [[TMP12]], align 4
-; NO-VP-NEXT:    store <8 x i32> [[TMP7]], ptr [[TMP11]], align 4
-; NO-VP-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; NO-VP-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; NO-VP-NEXT:    [[TMP7:%.*]] = call i64 @llvm.vscale.i64()
+; NO-VP-NEXT:    [[TMP8:%.*]] = mul i64 [[TMP7]], 4
+; NO-VP-NEXT:    [[TMP9:%.*]] = add i64 [[TMP8]], 0
+; NO-VP-NEXT:    [[TMP10:%.*]] = mul i64 [[TMP9]], 1
+; NO-VP-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], [[TMP10]]
+; NO-VP-NEXT:    [[TMP12:%.*]] = getelementptr inbounds [2 x i32], ptr [[B:%.*]], i64 [[TMP6]], i32 0
+; NO-VP-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i64 [[TMP11]], i32 0
+; NO-VP-NEXT:    [[TMP14:%.*]] = getelementptr inbounds i32, ptr [[TMP12]], i32 0
+; NO-VP-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[TMP13]], i32 0
+; NO-VP-NEXT:    [[WIDE_VEC:%.*]] = load <vscale x 8 x i32>, ptr [[TMP14]], align 4
+; NO-VP-NEXT:    [[WIDE_VEC1:%.*]] = load <vscale x 8 x i32>, ptr [[TMP15]], align 4
+; NO-VP-NEXT:    [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[WIDE_VEC]])
+; NO-VP-NEXT:    [[TMP16:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC]], 0
+; NO-VP-NEXT:    [[TMP17:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC]], 1
+; NO-VP-NEXT:    [[STRIDED_VEC2:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[WIDE_VEC1]])
+; NO-VP-NEXT:    [[TMP18:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC2]], 0
+; NO-VP-NEXT:    [[TMP19:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC2]], 1
+; NO-VP-NEXT:    [[TMP20:%.*]] = add nsw <vscale x 4 x i32> [[TMP17]], [[TMP16]]
+; NO-VP-NEXT:    [[TMP21:%.*]] = add nsw <vscale x 4 x i32> [[TMP19]], [[TMP18]]
+; NO-VP-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP6]]
+; NO-VP-NEXT:    [[TMP23:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[TMP11]]
+; NO-VP-NEXT:    [[TMP24:%.*]] = getelementptr inbounds i32, ptr [[TMP22]], i32 0
+; NO-VP-NEXT:    [[TMP25:%.*]] = call i64 @llvm.vscale.i64()
+; NO-VP-NEXT:    [[TMP26:%.*]] = mul i64 [[TMP25]], 4
+; NO-VP-NEXT:    [[TMP27:%.*]] = getelementptr inbounds i32, ptr [[TMP22]], i64 [[TMP26]]
+; NO-VP-NEXT:    store <vscale x 4 x i32> [[TMP20]], ptr [[TMP24]], align 4
+; NO-VP-NEXT:    store <vscale x 4 x i32> [[TMP21]], ptr [[TMP27]], align 4
+; NO-VP-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
 ; NO-VP-NEXT:    [[TMP28:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; NO-VP-NEXT:    br i1 [[TMP28]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; NO-VP:       middle.block:

>From 48311bde8f06253092bccf7c0790f33ad8b84f36 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 8 Jul 2024 15:22:11 +0100
Subject: [PATCH 2/2] !fixup address latest comments, thanks!

---
 .../Vectorize/LoopVectorizationLegality.h     |  2 +-
 .../Vectorize/LoopVectorizationLegality.cpp   | 20 ++++++-------------
 2 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index af23813da569b..2ff17bd2f7a71 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -280,7 +280,7 @@ class LoopVectorizationLegality {
   bool canFoldTailByMasking() const;
 
   /// Mark all respective loads/stores for masking. Must only be called when
-  /// ail-folding is possible.
+  /// tail-folding is possible.
   void prepareToFoldTailByMasking();
 
   /// Returns the primary induction variable.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index c99b3f63c6921..f54eebb2874ab 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -1586,15 +1586,12 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
   // The list of pointers that we can safely read and write to remains empty.
   SmallPtrSet<Value *, 8> SafePointers;
 
-  // Collect masked ops in temporary set first to avoid partially populating
-  // MaskedOp if a block cannot be predicated.
+  // Check all blocks for predication, including those that ordinarily do not
+  // need predication such as the header block.
   SmallPtrSet<const Instruction *, 8> TmpMaskedOp;
-
-  // Check and mark all blocks for predication, including those that ordinarily
-  // do not need predication such as the header block.
   for (BasicBlock *BB : TheLoop->blocks()) {
     if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp)) {
-      LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as requested.\n");
+      LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking.\n");
       return false;
     }
   }
@@ -1608,15 +1605,10 @@ void LoopVectorizationLegality::prepareToFoldTailByMasking() {
   // The list of pointers that we can safely read and write to remains empty.
   SmallPtrSet<Value *, 8> SafePointers;
 
-  // Collect masked ops in temporary set first to avoid partially populating
-  // MaskedOp if a block cannot be predicated.
-  SmallPtrSet<const Instruction *, 8> TmpMaskedOp;
-
-  // Check and mark all blocks for predication, including those that ordinarily
-  // do not need predication such as the header block.
+  // Mark all blocks for predication, including those that ordinarily do not
+  // need predication such as the header block.
   for (BasicBlock *BB : TheLoop->blocks()) {
-    bool R = blockCanBePredicated(BB, SafePointers, MaskedOp);
-    (void)R;
+    [[maybe_unused]] bool R = blockCanBePredicated(BB, SafePointers, MaskedOp);
     assert(R && "Must be able to predicate block when tail-folding.");
   }
 }



More information about the llvm-commits mailing list