[Mlir-commits] [mlir] [SCF] Fixed epilogue predicates in loop pipelining (PR #108964)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 18 11:03:46 PDT 2024
https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/108964
>From 2f6eac44e893c31da13163e1e9b815485be6dd90 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Tue, 17 Sep 2024 11:43:02 +0000
Subject: [PATCH 1/3] [SCF] Fixed epilogue predicates in loop pipelining
The computed loop iteration is zero-based, so only check whether it is less than zero.
This fixes the case when the lower bound is not zero.
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 7 +++++--
mlir/test/Dialect/SCF/loop-pipelining.mlir | 6 +++---
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
// iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI >= lb
+ // pred = iterI < 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, lb);
+ loc, arith::CmpIPredicate::slt, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
>From 995e3dd535dee01aaf8f007187b0647cf28479c8 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Tue, 17 Sep 2024 15:31:53 +0000
Subject: [PATCH 2/3] * fixed cmp predicate
---
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 4 ++--
mlir/test/Dialect/SCF/loop-pipelining.mlir | 6 +++---
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index ad6f790a5ba02c..3d6da066875f99 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -674,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI < 0
+ // pred = iterI >= 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, iterI, zero);
+ loc, arith::CmpIPredicate::sge, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 048786bad5d447..4a1406faabce1b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
>From e7c55588f6b79a044228b1b72fed8ece16d26d53 Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Wed, 18 Sep 2024 14:40:05 +0000
Subject: [PATCH 3/3] * updated test for non-zero lower_bound
---
mlir/test/Dialect/SCF/loop-pipelining.mlir | 41 +++++++++++-----------
1 file changed, 21 insertions(+), 20 deletions(-)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..4eb6a9d3bf3473 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -735,42 +735,41 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// -----
// NOEPILOGUE-LABEL: dynamic_loop(
-// NOEPILOGUE-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>, %[[LB:.+]]: index, %[[UB:.+]]: index, %[[STEP:.+]]: index) {
-// NOEPILOGUE-DAG: %[[C2:.+]] = arith.constant 2 : index
+// NOEPILOGUE-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>, %[[UB:.+]]: index) {
+// NOEPILOGUE-DAG: %[[C4:.+]] = arith.constant 4 : index
+// NOEPILOGUE-DAG: %[[C5:.+]] = arith.constant 5 : index
// NOEPILOGUE-DAG: %[[CSTF:.+]] = arith.constant 1.000000e+00 : f32
+// NOEPILOGUE-DAG: %[[LB:.+]] = arith.constant 3 : index
+// NOEPILOGUE-DAG: %[[STEP:.+]] = arith.constant 2 : index
// Prologue:
-// NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi slt, %[[LB]], %[[UB]] : index
+// NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi sgt, %[[UB]], %[[LB]] : index
// NOEPILOGUE: %[[L0:.+]] = scf.if %[[P_I0]] -> (f32) {
// NOEPILOGUE-NEXT: memref.load %[[A]][%[[LB]]] : memref<?xf32>
-// NOEPILOGUE: %[[IV1:.+]] = arith.addi %[[LB]], %[[STEP]] : index
-// NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi slt, %[[IV1]], %[[UB]] : index
-// NOEPILOGUE: %[[IV1_2:.+]] = arith.addi %[[LB]], %[[STEP]] : index
+// NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi sgt, %[[UB]], %[[C5]] : index
// NOEPILOGUE: %[[V0:.+]] = scf.if %[[P_I0]] -> (f32) {
// NOEPILOGUE-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32
// NOEPILOGUE: %[[L1:.+]] = scf.if %[[P_I1]] -> (f32) {
-// NOEPILOGUE-NEXT: memref.load %[[A]][%[[IV1_2]]] : memref<?xf32>
+// NOEPILOGUE-NEXT: memref.load %[[A]][%[[C5]]] : memref<?xf32>
// NOEPILOGUE: scf.for %[[IV2:.+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[V1:.+]] = %[[V0]], %[[L2:.+]] = %[[L1]]) -> (f32, f32) {
-// NOEPILOGUE-DAG: %[[S2:.+]] = arith.muli %[[STEP]], %[[C2]] : index
-// NOEPILOGUE-DAG: %[[IT2:.+]] = arith.subi %[[UB]], %[[S2]] : index
+// NOEPILOGUE-DAG: %[[IT2:.+]] = arith.subi %[[UB]], %[[C4]] : index
// NOEPILOGUE-DAG: %[[P_I2:.+]] = arith.cmpi slt, %[[IV2]], %[[IT2]] : index
// NOEPILOGUE-DAG: %[[IT3:.+]] = arith.subi %[[UB]], %[[STEP]] : index
// NOEPILOGUE-DAG: %[[P_I3:.+]] = arith.cmpi slt, %[[IV2]], %[[IT3]] : index
// NOEPILOGUE: memref.store %[[V1]], %[[R]][%[[IV2]]] : memref<?xf32>
// NOEPILOGUE: %[[V2:.+]] = scf.if %[[P_I3]] -> (f32) {
// NOEPILOGUE: arith.addf %[[L2]], %[[CSTF]] : f32
-// NOEPILOGUE: %[[IT4:.+]] = arith.muli %[[STEP]], %[[C2]] : index
-// NOEPILOGUE: %[[IV3:.+]] = arith.addi %[[IV2]], %[[IT4]] : index
+// NOEPILOGUE: %[[IV3:.+]] = arith.addi %[[IV2]], %[[C4]] : index
// NOEPILOGUE: %[[L3:.+]] = scf.if %[[P_I2]] -> (f32) {
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
-// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
-// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
+// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %{{.*}}
// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
@@ -779,14 +778,14 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
-// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
-// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[MULI_15:.*]] = arith.muli %[[ADDI_14]], %{{.*}}
+// CHECK: %[[ADDI_16:.*]] = arith.addi %[[MULI_15]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
-// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
-// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[MULI_20:.*]] = arith.muli %[[ADDI_19]], %{{.*}}
+// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MULI_20]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -802,8 +801,10 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: } else {
// CHECK: }
// CHECK: return
-func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
+func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
%cf = arith.constant 1.0 : f32
+ %lb = arith.constant 3 : index
+ %step = arith.constant 2 : index
scf.for %i0 = %lb to %ub step %step {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
More information about the Mlir-commits mailing list