[Mlir-commits] [mlir] c3c1c5c - [mlir][scf] Fix bug in pipelining prologue emission
Thomas Raoux
llvmlistbot at llvm.org
Thu Feb 3 13:13:32 PST 2022
Author: Thomas Raoux
Date: 2022-02-03T13:12:50-08:00
New Revision: c3c1c5c6953fcf9320a0cae5121ce55839169790
URL: https://github.com/llvm/llvm-project/commit/c3c1c5c6953fcf9320a0cae5121ce55839169790
DIFF: https://github.com/llvm/llvm-project/commit/c3c1c5c6953fcf9320a0cae5121ce55839169790.diff
LOG: [mlir][scf] Fix bug in pipelining prologue emission
The induction variable calculation ignored the scf.for step value. Fix it so that
the prologue uses the correct induction variable values.
Differential Revision: https://reviews.llvm.org/D118932
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 795c342e9026..8e20906e7251 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -138,7 +138,8 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (int64_t i = 0; i < maxStage; i++) {
// special handling for induction variable as the increment is implicit.
- Value iv = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i);
+ Value iv =
+ rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
setValueMapping(forOp.getInductionVar(), iv, i);
for (Operation *op : opOrder) {
if (stages[op] > i)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9424af25bd12..545a0ce981f4 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -34,6 +34,45 @@ func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
// -----
+// CHECK-LABEL: simple_pipeline_step(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK: %[[L1:.*]] = memref.load %[[A]][%[[C3]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[L2:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C5]]
+// CHECK-SAME: step %[[C3]] iter_args(%[[LARG0:.*]] = %[[L0]], %[[LARG1:.*]] = %[[L1]]) -> (f32, f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG0]], %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C6]] : index
+// CHECK-NEXT: %[[LR:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[LARG1]], %[[LR]] : f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L2]]#0, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C6]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[L2]]#1, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD2]], %[[R]][%[[C9]]] : memref<?xf32>
+func @simple_pipeline_step(%A: memref<?xf32>, %result: memref<?xf32>) {  // Regression test: pipelining a loop whose step is 3, not 1.
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %c11 = arith.constant 11 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c11 step %c3 {  // %i0 takes 0, 3, 6, 9; the two prologue loads must therefore index 0 and 3, not 0 and 1.
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>  // stage 0: the op emitted in the prologue
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
// CHECK-LABEL: three_stage(
// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
More information about the Mlir-commits mailing list