[Mlir-commits] [mlir] 87a9be2 - Don't fail if unable to promote loops during unrolling

Mehdi Amini llvmlistbot at llvm.org
Mon Jan 10 14:26:32 PST 2022


Author: Tyler Augustine
Date: 2022-01-10T22:26:21Z
New Revision: 87a9be2a74a98c8baf69009122ce7e75111d530e

URL: https://github.com/llvm/llvm-project/commit/87a9be2a74a98c8baf69009122ce7e75111d530e
DIFF: https://github.com/llvm/llvm-project/commit/87a9be2a74a98c8baf69009122ce7e75111d530e.diff

LOG: Don't fail if unable to promote loops during unrolling

When the unroll factor is 1, "unrolling" should fail only when the trip count is known to be exactly 1 and the loop cannot be promoted; in every other case it should succeed as a no-op.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D115365

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Transforms/scf-loop-unroll.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 5746c4a588eb5..b9d1645692ff6 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1182,8 +1182,14 @@ LogicalResult mlir::loopUnrollByFactor(
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "unroll factor should be positive");
 
-  if (unrollFactor == 1)
-    return promoteIfSingleIteration(forOp);
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  if (unrollFactor == 1) {
+    if (mayBeConstantTripCount.hasValue() &&
+        mayBeConstantTripCount.getValue() == 1 &&
+        failed(promoteIfSingleIteration(forOp)))
+      return failure();
+    return success();
+  }
 
   // Nothing in the loop body other than the terminator.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@@ -1191,7 +1197,6 @@ LogicalResult mlir::loopUnrollByFactor(
 
   // If the trip count is lower than the unroll factor, no unrolled body.
   // TODO: option to specify cleanup loop unrolling.
-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (mayBeConstantTripCount.hasValue() &&
       mayBeConstantTripCount.getValue() < unrollFactor)
     return failure();
@@ -1237,8 +1242,6 @@ LogicalResult mlir::loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "expected positive unroll factor");
-  if (unrollFactor == 1)
-    return promoteIfSingleIteration(forOp);
 
   // Return if the loop body is empty.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
@@ -1264,6 +1267,13 @@ LogicalResult mlir::loopUnrollByFactor(
     assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
            "expected positive loop bounds and step");
     int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
+
+    if (unrollFactor == 1) {
+      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
+        return failure();
+      return success();
+    }
+
     int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     assert(upperBoundUnrolledCst <= ubCst);
@@ -1403,14 +1413,19 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
                                           uint64_t unrollJamFactor) {
   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
 
-  if (unrollJamFactor == 1)
-    return promoteIfSingleIteration(forOp);
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  if (unrollJamFactor == 1) {
+    if (mayBeConstantTripCount.hasValue() &&
+        mayBeConstantTripCount.getValue() == 1 &&
+        failed(promoteIfSingleIteration(forOp)))
+      return failure();
+    return success();
+  }
 
   // Nothing in the loop body other than the terminator.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
     return success();
 
-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   // If the trip count is lower than the unroll jam factor, no unroll jam.
   if (mayBeConstantTripCount.hasValue() &&
       mayBeConstantTripCount.getValue() < unrollJamFactor) {

diff  --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index a211c9e970f5d..fd19d9f29ee33 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
 
 // CHECK-LABEL: scf_loop_unroll_single
 func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -42,3 +43,16 @@ func @scf_loop_unroll_double_symbolic_ub(%arg0 : f32, %arg1 : f32, %n : index) -
   // CHECK:      }
   // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
 }
+
+// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_promote
+func @scf_loop_unroll_factor_1_promote() -> () {
+  %step = arith.constant 1 : index
+  %lo = arith.constant 0 : index
+  %hi = arith.constant 1 : index
+  scf.for %i = %lo to %hi step %step {
+    %x = "test.foo"(%i) : (index) -> i32
+  }
+  return
+  // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
+  // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
+}


        


More information about the Mlir-commits mailing list