[Mlir-commits] [mlir] [SCF] Fixed epilogue predicates in loop pipelining (PR #108964)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 17 04:47:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir
Author: SJW (sjw36)
<details>
<summary>Changes</summary>
The computed loop iteration is zero based, so only check it is less than zero.
This fixes the case when lower bound is not zero.
---
Full diff: https://github.com/llvm/llvm-project/pull/108964.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+5-2)
- (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
// iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI >= lb
+ // pred = iterI < 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, lb);
+ loc, arith::CmpIPredicate::slt, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108964
More information about the Mlir-commits
mailing list