[Mlir-commits] [mlir] [SCF] Fixed epilogue predicates in loop pipelining (PR #108964)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 17 04:47:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: SJW (sjw36)

<details>
<summary>Changes</summary>

The computed epilogue iteration index is zero-based, so the predicate should compare it against zero rather than against the loop's lower bound.
This fixes the case where the lower bound is nonzero.

---
Full diff: https://github.com/llvm/llvm-project/pull/108964.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+5-2) 
- (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
   Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
 
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
     // iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
 
     if (dynamicLoop) {
-      // pred = iterI >= lb
+      // pred = iterI < 0
       predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, lb);
+          loc, arith::CmpIPredicate::slt, iterI, zero);
     }
   }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
 //        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
 //        CHECK:   scf.if %[[CMPI_17]] {
 //        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
 //        CHECK:   } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
 //       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
 //       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
 //       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
 //       CHECK:       scf.yield %[[ADDF_13]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/108964


More information about the Mlir-commits mailing list