[Mlir-commits] [mlir] [mlir] Extend SCF loopUnrollByFactor to return the result loops (PR #114573)

Hongtao Yu llvmlistbot at llvm.org
Sun Nov 3 12:08:08 PST 2024


https://github.com/htyu updated https://github.com/llvm/llvm-project/pull/114573

>From 6e779e649aee2ebcdf7594e469dc94da6d544380 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 09:52:11 -0700
Subject: [PATCH 1/3] [mlir] Extend SCF loopUnrollByFactor to return the result
 loops

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h    | 12 +++++++-----
 .../SCF/TransformOps/SCFTransformOps.cpp       |  6 ++++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp           | 18 ++++++++++++------
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
 void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
   LogicalResult result(failure());
-  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
-    result = loopUnrollByFactor(scfFor, getFactor());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+    auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+    result = resultLoops.empty() ? failure() : success();
+  }
   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
   else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
   loopBodyBlock->getTerminator()->setOperands(lastYielded);
 }
 
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "expected positive unroll factor");
 
   // Return if the loop body is empty.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
-    return success();
+    return {forOp};
 
   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
     if (unrollFactor == 1) {
       if (*constTripCount == 1 &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
-        return failure();
-      return success();
+        return {};
+      return {forOp};
     }
 
     int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
         boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
   }
 
+  SmallVector<scf::ForOp, 2> resultLoops;
+  resultLoops.push_back(forOp);
+
   // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
   if (generateEpilogueLoop) {
     OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                                epilogueForOp.getInitArgs().size(), results);
     (void)epilogueForOp.promoteIfSingleIteration(rewriter);
+    resultLoops.push_back(epilogueForOp);
   }
 
   // Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
       annotateFn, iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
   (void)forOp.promoteIfSingleIteration(rewriter);
-  return success();
+  return resultLoops;
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`

>From ec4600241bf8f1acac2c403b9c4b2a43a68802f7 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 11:06:51 -0700
Subject: [PATCH 2/3] make return value more structured.

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h     | 17 +++++++++++------
 .../SCF/TransformOps/SCFTransformOps.cpp        |  5 ++---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp            | 17 ++++++++---------
 3 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index eda64ea69f81d1..c3bd6d86864186 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,13 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
 void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
+struct UnrolledLoopInfo {
+  scf::ForOp mainLoopOp;
+  scf::ForOp epilogueLoopOp;
+};
+
 /// Unrolls this for operation by the specified unroll factor. Returns the
-/// unrolled main loop and the eplilog loop in sequence, if the loop is
-/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
-/// either due to restrictions or due to invalid unroll factors. Requires
-/// positive loop bounds and step. If specified, annotates the Ops in each
-/// unrolled iteration by applying `annotateFn`.
-SmallVector<scf::ForOp> loopUnrollByFactor(
+/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise
+/// returns a structure of null fields if the loop cannot be unrolled either due
+/// to restrictions or due to invalid unroll factors. Requires positive loop
+/// bounds and step. If specified, annotates the Ops in each unrolled iteration
+/// by applying `annotateFn`.
+UnrolledLoopInfo loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c84cb13f8b6bb2..cefd023c40d96c 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -355,9 +355,8 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   LogicalResult result(failure());
   if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
     auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
-    result = resultLoops.empty() ? failure() : success();
-  }
-  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = resultLoops.mainLoopOp ? success() : failure();
+  } else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
   else
     return emitSilenceableError()
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 8394ac47888100..a50e90af3af658 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -373,16 +373,15 @@ static void generateUnrolledLoop(
 }
 
 /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
-/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
-/// vector.
-SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
+/// eplilog loop, if the loop is unrolled. Otherwise return null.
+UnrolledLoopInfo mlir::loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "expected positive unroll factor");
 
   // Return if the loop body is empty.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
-    return {forOp};
+    return {forOp, nullptr};
 
   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -403,8 +402,8 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
     if (unrollFactor == 1) {
       if (*constTripCount == 1 &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
-        return {};
-      return {forOp};
+        return {nullptr, nullptr};
+      return {forOp, nullptr};
     }
 
     int64_t tripCountEvenMultiple =
@@ -452,8 +451,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
         boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
   }
 
-  SmallVector<scf::ForOp, 2> resultLoops;
-  resultLoops.push_back(forOp);
+  UnrolledLoopInfo resultLoops;
 
   // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
   if (generateEpilogueLoop) {
@@ -473,7 +471,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                                epilogueForOp.getInitArgs().size(), results);
     (void)epilogueForOp.promoteIfSingleIteration(rewriter);
-    resultLoops.push_back(epilogueForOp);
+    resultLoops.epilogueLoopOp = epilogueForOp;
   }
 
   // Create unrolled loop.
@@ -496,6 +494,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
       annotateFn, iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
   (void)forOp.promoteIfSingleIteration(rewriter);
+  resultLoops.mainLoopOp = forOp;
   return resultLoops;
 }
 

>From 69e5cbe2623c6cbe6e0a4ec3be2d6cc753baf4fc Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Sun, 3 Nov 2024 12:06:24 -0800
Subject: [PATCH 3/3] Changed return type to FailureOr<UnrolledLoopInfo>

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h    |  6 +++---
 .../SCF/TransformOps/SCFTransformOps.cpp       |  3 +--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp           | 18 +++++++++---------
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index c3bd6d86864186..9c41fc6a30b809 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -112,8 +112,8 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
 struct UnrolledLoopInfo {
-  scf::ForOp mainLoopOp;
-  scf::ForOp epilogueLoopOp;
+  scf::ForOp mainLoopOp = nullptr;
+  scf::ForOp epilogueLoopOp = nullptr;
 };
 
 /// Unrolls this for operation by the specified unroll factor. Returns the
@@ -122,7 +122,7 @@ struct UnrolledLoopInfo {
 /// to restrictions or due to invalid unroll factors. Requires positive loop
 /// bounds and step. If specified, annotates the Ops in each unrolled iteration
 /// by applying `annotateFn`.
-UnrolledLoopInfo loopUnrollByFactor(
+FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index cefd023c40d96c..21455bd7251130 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -354,8 +354,7 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
                                     transform::TransformState &state) {
   LogicalResult result(failure());
   if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
-    auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
-    result = resultLoops.mainLoopOp ? success() : failure();
+    result = loopUnrollByFactor(scfFor, getFactor());
   } else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
   else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a50e90af3af658..e591ca49bccb8f 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -373,15 +373,15 @@ static void generateUnrolledLoop(
 }
 
 /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
-/// eplilog loop, if the loop is unrolled. Otherwise return null.
-UnrolledLoopInfo mlir::loopUnrollByFactor(
+/// epilogue loop, if the loop is unrolled; otherwise returns failure.
+FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "expected positive unroll factor");
 
   // Return if the loop body is empty.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
-    return {forOp, nullptr};
+    return UnrolledLoopInfo{forOp, nullptr};
 
   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,8 +402,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
     if (unrollFactor == 1) {
       if (*constTripCount == 1 &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
-        return {nullptr, nullptr};
-      return {forOp, nullptr};
+        return failure();
+      return UnrolledLoopInfo{forOp, nullptr};
     }
 
     int64_t tripCountEvenMultiple =
@@ -470,8 +470,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
     }
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                                epilogueForOp.getInitArgs().size(), results);
-    (void)epilogueForOp.promoteIfSingleIteration(rewriter);
-    resultLoops.epilogueLoopOp = epilogueForOp;
+    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
+      resultLoops.epilogueLoopOp = epilogueForOp;
   }
 
   // Create unrolled loop.
@@ -493,8 +493,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
       },
       annotateFn, iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)forOp.promoteIfSingleIteration(rewriter);
-  resultLoops.mainLoopOp = forOp;
+  if (forOp.promoteIfSingleIteration(rewriter).failed())
+    resultLoops.mainLoopOp = forOp;
   return resultLoops;
 }
 



More information about the Mlir-commits mailing list