[Mlir-commits] [mlir] [SCF] Fixed epilogue predicates in loop pipelining (PR #108964)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 17 08:33:16 PDT 2024
https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/108964
>From a9343ff2abb51c4febce6b317b7c7f7db16a01eb Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Tue, 17 Sep 2024 11:43:02 +0000
Subject: [PATCH 1/2] [SCF] Fixed epilogue predicates in loop pipelining
The computed loop iteration index (iterI) is zero-based, so the epilogue predicate must compare it against zero rather than against the loop lower bound.
This fixes the case where the lower bound is non-zero.
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 7 +++++--
mlir/test/Dialect/SCF/loop-pipelining.mlir | 6 +++---
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
// iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI >= lb
+ // pred = iterI < 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, lb);
+ loc, arith::CmpIPredicate::slt, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
>From 86ccc0bb7e403c267765c2b4975f0df56ae50e45 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Tue, 17 Sep 2024 15:31:53 +0000
Subject: [PATCH 2/2] Fix inverted epilogue cmp predicate (iterI >= 0, i.e. sge rather than slt)
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 4 ++--
mlir/test/Dialect/SCF/loop-pipelining.mlir | 6 +++---
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index ad6f790a5ba02c..3d6da066875f99 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -674,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI < 0
+ // pred = iterI >= 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, iterI, zero);
+ loc, arith::CmpIPredicate::sge, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 048786bad5d447..4a1406faabce1b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
More information about the Mlir-commits
mailing list