[Mlir-commits] [mlir] 87a9be2 - Don't fail if unable to promote loops during unrolling
Mehdi Amini
llvmlistbot at llvm.org
Mon Jan 10 14:26:32 PST 2022
Author: Tyler Augustine
Date: 2022-01-10T22:26:21Z
New Revision: 87a9be2a74a98c8baf69009122ce7e75111d530e
URL: https://github.com/llvm/llvm-project/commit/87a9be2a74a98c8baf69009122ce7e75111d530e
DIFF: https://github.com/llvm/llvm-project/commit/87a9be2a74a98c8baf69009122ce7e75111d530e.diff
LOG: Don't fail if unable to promote loops during unrolling
When the unroll factor is 1, we should only fail "unrolling" when the trip count is also determined to be 1 and the loop cannot be promoted.
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D115365
Added:
Modified:
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/Transforms/scf-loop-unroll.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 5746c4a588eb5..b9d1645692ff6 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1182,8 +1182,14 @@ LogicalResult mlir::loopUnrollByFactor(
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "unroll factor should be positive");
- if (unrollFactor == 1)
- return promoteIfSingleIteration(forOp);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ if (unrollFactor == 1) {
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() == 1 &&
+ failed(promoteIfSingleIteration(forOp)))
+ return failure();
+ return success();
+ }
// Nothing in the loop body other than the terminator.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@@ -1191,7 +1197,6 @@ LogicalResult mlir::loopUnrollByFactor(
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO: option to specify cleanup loop unrolling.
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
return failure();
@@ -1237,8 +1242,6 @@ LogicalResult mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
- if (unrollFactor == 1)
- return promoteIfSingleIteration(forOp);
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@@ -1264,6 +1267,13 @@ LogicalResult mlir::loopUnrollByFactor(
assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
"expected positive loop bounds and step");
int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
+
+ if (unrollFactor == 1) {
+ if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
+ return failure();
+ return success();
+ }
+
int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
assert(upperBoundUnrolledCst <= ubCst);
@@ -1403,14 +1413,19 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
- if (unrollJamFactor == 1)
- return promoteIfSingleIteration(forOp);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ if (unrollJamFactor == 1) {
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() == 1 &&
+ failed(promoteIfSingleIteration(forOp)))
+ return failure();
+ return success();
+ }
// Nothing in the loop body other than the terminator.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
return success();
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
// If the trip count is lower than the unroll jam factor, no unroll jam.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index a211c9e970f5d..fd19d9f29ee33 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
// CHECK-LABEL: scf_loop_unroll_single
func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -42,3 +43,16 @@ func @scf_loop_unroll_double_symbolic_ub(%arg0 : f32, %arg1 : f32, %n : index) -
// CHECK: }
// CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
}
+
+// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_promote
+func @scf_loop_unroll_factor_1_promote() -> () {
+ %step = arith.constant 1 : index
+ %lo = arith.constant 0 : index
+ %hi = arith.constant 1 : index
+ scf.for %i = %lo to %hi step %step {
+ %x = "test.foo"(%i) : (index) -> i32
+ }
+ return
+ // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
+ // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
+}
More information about the Mlir-commits
mailing list