[Mlir-commits] [mlir] [MLIR] Removes incorrect assertion in loop unroller (PR #69028)

Stephen Chou llvmlistbot at llvm.org
Fri Oct 13 13:46:26 PDT 2023


https://github.com/stephenchouca created https://github.com/llvm/llvm-project/pull/69028

The removed assertion, `upperBoundUnrolledCst <= ubCst`, does not always hold. In particular, `upperBoundUnrolledCst` may be larger than `ubCst` when:

1. the step size is greater than 1;
2. `ub - lb` is not evenly divisible by the step size; and
3. the loop's trip count is evenly divisible by the unroll factor.

This is okay: the overshoot past `ubCst` is smaller than one step, so the unrolled loop still covers exactly the iterations of the original loop and no epilogue loop is needed. Added a test case for this.
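
To make the three conditions concrete, here is a minimal standalone sketch using the values from the new test case (`lb = 0`, `ub = 19`, `step = 2`, unroll factor 2). It copies the two arithmetic lines visible in the diff. The helper names (`tripCount`, `unrolledUpperBound`, `originalIvs`, `unrolledIvs`) are hypothetical and not part of MLIR. The ceiling-division trip count is an assumption about how `tripCount` is computed upstream, and bounding the unrolled loop by the original `ub` follows the CHECK lines in the test rather than anything shown in `Utils.cpp`.

```cpp
#include <cstdint>
#include <vector>

// Assumed trip count of `for (i = lb; i < ub; i += step)` with lb < ub, step > 0.
int64_t tripCount(int64_t lb, int64_t ub, int64_t step) {
  return (ub - lb + step - 1) / step;  // ceiling division
}

// Same arithmetic as the two lines preceding the removed assertion.
int64_t unrolledUpperBound(int64_t lb, int64_t ub, int64_t step,
                           int64_t unrollFactor) {
  int64_t tc = tripCount(lb, ub, step);
  int64_t tripCountEvenMultiple = tc - (tc % unrollFactor);
  return lb + tripCountEvenMultiple * step;
}

// Induction-variable values visited by the original loop.
std::vector<int64_t> originalIvs(int64_t lb, int64_t ub, int64_t step) {
  std::vector<int64_t> ivs;
  for (int64_t i = lb; i < ub; i += step)
    ivs.push_back(i);
  return ivs;
}

// Values visited by a loop unrolled by `unrollFactor` with no epilogue.
// Only meaningful when the trip count is a multiple of `unrollFactor`.
std::vector<int64_t> unrolledIvs(int64_t lb, int64_t ub, int64_t step,
                                 int64_t unrollFactor) {
  std::vector<int64_t> ivs;
  for (int64_t i = lb; i < ub; i += step * unrollFactor)
    for (int64_t k = 0; k < unrollFactor; ++k)
      ivs.push_back(i + k * step);  // k-th replicated body
  return ivs;
}
```

Under these assumptions, `tripCount(0, 19, 2)` is 10 and `unrolledUpperBound(0, 19, 2, 2)` is 20, which exceeds `ub = 19` and would have tripped the assertion. Both `originalIvs(0, 19, 2)` and `unrolledIvs(0, 19, 2, 2)` visit 0, 2, ..., 18.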

This also fixes #61832.

From f2c9a76cc95f8047fffa6ec44179cdaa84d7c8ab Mon Sep 17 00:00:00 2001
From: Stephen Chou <stephenchou at google.com>
Date: Fri, 13 Oct 2023 20:12:25 +0000
Subject: [PATCH] [MLIR] Removes incorrect assertion in loop unroller.

In particular, `upperBoundUnrolledCst` may be larger than `ubCst` when:

1. the step size is greater than 1;
2. `ub - lb` is not evenly divisible by the step size; and
3. the loop's trip count is evenly divisible by the unroll factor.

This is okay since the non-unit step size ensures that the unrolled loop maintains the same trip count as the original loop. Added a test case for this.
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp   |  1 -
 mlir/test/Dialect/SCF/loop-unroll.mlir | 30 ++++++++++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5360c493f8f8d71..e85825595e3c1ee 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -391,7 +391,6 @@ LogicalResult mlir::loopUnrollByFactor(
 
     int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
-    assert(upperBoundUnrolledCst <= ubCst);
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
     // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index c83e33d7fbc9c6e..e28efbb6ec2b911 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -186,6 +186,36 @@ func.func @static_loop_unroll_by_2(%arg0 : memref<?xf32>) {
 // UNROLL-BY-2-ANNOTATE:    memref.store %{{.*}}, %[[MEM:.*0]][%{{.*}}] {unrolled_iteration = 0 : ui32} : memref<?xf32>
 // UNROLL-BY-2-ANNOTATE:    memref.store %{{.*}}, %[[MEM]][%{{.*}}] {unrolled_iteration = 1 : ui32} : memref<?xf32>
 
+// Test that no epilogue clean-up loop is generated because the trip count
+// (taking into account the non-unit step size) is a multiple of the unroll
+// factor.
+func.func @static_loop_step_2_unroll_by_2(%arg0 : memref<?xf32>) {
+  %0 = arith.constant 7.0 : f32
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 19 : index
+  %step = arith.constant 2 : index
+  scf.for %i0 = %lb to %ub step %step {
+    memref.store %0, %arg0[%i0] : memref<?xf32>
+  }
+  return
+}
+
+// UNROLL-BY-2-LABEL: func @static_loop_step_2_unroll_by_2
+//  UNROLL-BY-2-SAME:  %[[MEM:.*0]]: memref<?xf32>
+//
+//   UNROLL-BY-2-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//   UNROLL-BY-2-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//   UNROLL-BY-2-DAG:  %[[C19:.*]] = arith.constant 19 : index
+//   UNROLL-BY-2-DAG:  %[[C4:.*]] = arith.constant 4 : index
+//   UNROLL-BY-2:  scf.for %[[IV:.*]] = %[[C0]] to %[[C19]] step %[[C4]] {
+//  UNROLL-BY-2-NEXT:    memref.store %{{.*}}, %[[MEM]][%[[IV]]] : memref<?xf32>
+//  UNROLL-BY-2-NEXT:    %[[C1_IV:.*]] = arith.constant 1 : index
+//  UNROLL-BY-2-NEXT:    %[[V0:.*]] = arith.muli %[[C2]], %[[C1_IV]] : index
+//  UNROLL-BY-2-NEXT:    %[[V1:.*]] = arith.addi %[[IV]], %[[V0]] : index
+//  UNROLL-BY-2-NEXT:    memref.store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
+//  UNROLL-BY-2-NEXT:  }
+//  UNROLL-BY-2-NEXT:  return
+
 // Test that epilogue clean up loop is generated (trip count is not
 // a multiple of unroll factor).
 func.func @static_loop_unroll_by_3(%arg0 : memref<?xf32>) {


