[Mlir-commits] [mlir] b227c25 - [mlir][scf] add unroll-full option to test-loop-unrolling pass (#127158)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 17 23:50:02 PST 2025
Author: lonely eagle
Date: 2025-02-18T15:49:58+08:00
New Revision: b227c2578cc77e2113846b270d0b0a08f53e8019
URL: https://github.com/llvm/llvm-project/commit/b227c2578cc77e2113846b270d0b0a08f53e8019
DIFF: https://github.com/llvm/llvm-project/commit/b227c2578cc77e2113846b270d0b0a08f53e8019.diff
LOG: [mlir][scf] add unroll-full option to test-loop-unrolling pass (#127158)
Some loops cannot be unrolled by the affine-loop-unroll pass. After running
the lower-affine pass, they can be unrolled as scf loops. To enable conversion
of vector ops inside scf loops to the LLVM dialect, an unroll-full option was added.
---------
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Transforms/scf-loop-unroll.mlir
mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 02ffa0da7a8b8..c0c11c9e38994 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls `forOp` completely, i.e. by a factor equal to its trip count.
+/// Returns failure if no constant trip count can be determined. A loop with a
+/// trip count of 0 is left untouched and success is returned; a loop with a
+/// trip count of 1 is handed to `promoteIfSingleIteration` instead of being
+/// unrolled.
+LogicalResult loopUnrollFull(scf::ForOp forOp);
+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
/// Returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. In case of unroll factor of 1, the function
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fa82bcb816a2a..bc1cb24303ad2 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -498,6 +498,20 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
return resultLoops;
}
+/// Unrolls this loop completely, i.e. by a factor equal to its trip count.
+/// Returns failure when no constant trip count can be determined; otherwise
+/// returns the result of whichever strategy below applies.
+LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
+ // Only used by the single-iteration path below.
+ IRRewriter rewriter(forOp.getContext());
+ std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ // Full unrolling needs a statically known iteration count.
+ if (!mayBeConstantTripCount.has_value())
+ return failure();
+ uint64_t tripCount = *mayBeConstantTripCount;
+ // Zero iterations: nothing to unroll; the loop op is left in place as-is.
+ if (tripCount == 0)
+ return success();
+ // One iteration: no body copies are needed, so delegate to
+ // promoteIfSingleIteration instead of unrolling.
+ if (tripCount == 1)
+ return forOp.promoteIfSingleIteration(rewriter);
+ // Unroll by the whole trip count; the FailureOr<UnrolledLoopInfo> result is
+ // returned as this function's LogicalResult.
+ return loopUnrollByFactor(forOp, tripCount);
+}
+
/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index baf6b2970ac0e..0ef6ad15d4eb0 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
// CHECK-LABEL: scf_loop_unroll_single
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
}
+
+// Full unrolling of a constant 4-iteration loop: the expected output below is
+// four chained arith.addi ops feeding the return, with no loop left.
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single
+// UNROLL-FULL-SAME: %[[ARG:.*]]: index)
+func.func @scf_loop_unroll_full_single(%arg : index) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
+ %3 = arith.addi %arg1, %arg : index
+ scf.yield %3 : index
+ }
+ return %4 : index
+ // UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[V0:.*]] = arith.addi %[[ARG]], %[[C1]] : index
+ // UNROLL-FULL: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG]] : index
+ // UNROLL-FULL: %[[V2:.*]] = arith.addi %[[V1]], %[[ARG]] : index
+ // UNROLL-FULL: %[[V3:.*]] = arith.addi %[[V2]], %[[ARG]] : index
+ // UNROLL-FULL: return %[[V3]] : index
+}
+
+// Nested loops: in the expected output below only the outer loop is fully
+// unrolled; four inner scf.for loops remain, each using a constant outer
+// index 0..3 in its vector.extract.
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops
+// UNROLL-FULL-SAME: %[[ARG:.*]]: vector<4x4xindex>)
+func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
+ %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
+ %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
+ %4 = arith.addi %3, %it1 : index
+ // NOTE(review): this yields %3 rather than %4, so %4 is dead and the
+ // iter_arg is never accumulated; the expected output below relies on
+ // this. Confirm whether yielding %4 was intended.
+ scf.yield %3 : index
+ }
+ scf.yield %5 : index
+ }
+ return %6 : index
+ // UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index
+ // UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[C4:.*]] = arith.constant 4 : index
+ // UNROLL-FULL: %[[SUM0:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[C0]])
+ // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][0, %[[IV]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[SUM1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM0]])
+ // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][1, %[[IV]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[SUM2:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM1]])
+ // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][2, %[[IV]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[SUM3:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM2]])
+ // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][3, %[[IV]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: return %[[SUM3]] : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index 8694a7f9bbd62..ced003305a7b8 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
+ /// Constructs the pass with every option given explicitly; each `*Param`
+ /// argument initializes the option of the same name.
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
- bool annotateLoopParam) {
+ bool annotateLoopParam, bool unrollFullParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
+ // Take the flag from the bool parameter: assigning the uint64_t unroll
+ // factor here would turn full unrolling on for any nonzero factor.
+ unrollFull = unrollFullParam;
}
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
}
};
- for (auto loop : loops)
- (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ for (auto loop : loops) {
+ if (unrollFull)
+ (void)loopUnrollFull(loop);
+ else
+ (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ }
}
Option<uint64_t> unrollFactor{*this, "unroll-factor",
llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
+ // When true, each loop is unrolled with loopUnrollFull; the unroll factor
+ // and the annotation callback are then unused (see runOnOperation).
+ Option<bool> unrollFull{*this, "unroll-full",
+ llvm::cl::desc("Full unroll loops."),
+ llvm::cl::init(false)};
};
} // namespace
More information about the Mlir-commits
mailing list