[Mlir-commits] [mlir] [mlir][scf] add unroll-full option to test-loop-unrolling pass. (PR #127158)

lonely eagle llvmlistbot at llvm.org
Thu Feb 13 18:32:26 PST 2025


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/127158

Some loops cannot be unrolled by the affine-loop-unroll pass, but after running the lower-affine pass they can be unrolled at the SCF level. To enable conversion of vector ops inside scf loops to the LLVM dialect, an unroll-full option was added to the test-loop-unrolling pass.

>From 2d42771ebbd7b4c268aa12f32b18accb26ce80c5 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 14 Feb 2025 10:26:40 +0800
Subject: [PATCH] add unroll-full option to test-loop-unrolling pass.

---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  3 +
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 15 +++++
 mlir/test/Transforms/scf-loop-unroll.mlir     | 57 +++++++++++++++++++
 .../lib/Dialect/SCF/TestLoopUnrolling.cpp     | 14 ++++-
 4 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 02ffa0da7a8b8..c0c11c9e38994 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Fully unrolls `forOp`. Fails if its trip count is not a known constant.
+LogicalResult loopUnrollFull(scf::ForOp forOp);
+
 /// Unrolls and jams this `scf.for` operation by the specified unroll factor.
 /// Returns failure if the loop cannot be unrolled either due to restrictions or
 /// due to invalid unroll factors. In case of unroll factor of 1, the function
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fa82bcb816a2a..0ee325f6c0439 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -498,6 +498,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
   return resultLoops;
 }
 
+/// Unrolls `forOp` completely; fails if its trip count is not constant.
+LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
+  std::optional<uint64_t> tripCountOrNone = getConstantTripCount(forOp);
+  if (!tripCountOrNone)
+    return failure();
+  const uint64_t tripCount = *tripCountOrNone;
+  if (tripCount == 0)
+    return success(); // Zero-trip loop: left in place, nothing to unroll.
+  if (tripCount == 1) {
+    IRRewriter rewriter(forOp.getContext());
+    return forOp.promoteIfSingleIteration(rewriter);
+  }
+  return loopUnrollByFactor(forOp, tripCount);
+}
+
 /// Check if bounds of all inner loops are defined outside of `forOp`
 /// and return false if not.
 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index baf6b2970ac0e..75481863795ae 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
 // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
 
 // CHECK-LABEL: scf_loop_unroll_single
 func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
   // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
   // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
 }
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
+// UNROLL-FULL-SAME:    %[[ARG:.*]]: index) -> index {
+func.func @scf_loop_unroll_full_single(%arg : index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %sum = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %c1) -> index {
+    %next = arith.addi %acc, %arg : index
+    scf.yield %next : index
+  }
+  return %sum : index
+  // UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
+  // UNROLL-FULL: %[[ADD0:.*]] = arith.addi %[[ARG]], %[[C1]] : index
+  // UNROLL-FULL: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG]] : index
+  // UNROLL-FULL: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[ARG]] : index
+  // UNROLL-FULL: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[ARG]] : index
+  // UNROLL-FULL: return %[[ADD3]] : index
+}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outer_loops(
+// UNROLL-FULL-SAME:    %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
+func.func @scf_loop_unroll_full_outer_loops(%arg0: vector<4x4xindex>) -> index {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 4 : index
+  %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
+    %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
+      %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
+      %4 = arith.addi %3, %it1 : index
+      scf.yield %3 : index  // NOTE(review): %4 is unused; intended to yield it?
+    }
+    scf.yield %5 : index
+  }
+  return %6 : index
+  // UNROLL-FULL:   %[[VAL_1:.*]] = arith.constant 0 : index
+  // UNROLL-FULL:   %[[VAL_2:.*]] = arith.constant 1 : index
+  // UNROLL-FULL:   %[[VAL_3:.*]] = arith.constant 4 : index
+  // UNROLL-FULL:   %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_7]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_11]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_15]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_19]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL: return %[[VAL_16]] : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index 8694a7f9bbd62..cb554c3dfb66c 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
   TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
   explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
                                  unsigned loopDepthParam,
-                                 bool annotateLoopParam) {
+                                 bool annotateLoopParam, bool unrollFullParam) {
     unrollFactor = unrollFactorParam;
     loopDepth = loopDepthParam;
     annotateLoop = annotateLoopParam;
+    unrollFull = unrollFactorParam;
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
         op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
       }
     };
-    for (auto loop : loops)
-      (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+    for (auto loop : loops) {
+      if (unrollFull)
+        (void)loopUnrollFull(loop);
+      else
+        (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+    }
   }
   Option<uint64_t> unrollFactor{*this, "unroll-factor",
                                 llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
                                 llvm::cl::init(false)};
   Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
                              llvm::cl::init(0)};
+  Option<bool> unrollFull{*this, "unroll-full",
+                          llvm::cl::desc("Fully unroll loops."),
+                          llvm::cl::init(false)};
 };
 } // namespace
 



More information about the Mlir-commits mailing list