[Mlir-commits] [mlir] [MLIR][LoopPipelining] Improve schedule verifier, so it checks also operands of nested operations (PR #88450)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 11 15:37:01 PDT 2024


https://github.com/pawelszczerbuk created https://github.com/llvm/llvm-project/pull/88450

`verifySchedule` did not look at the operands of nested operations, so an incorrect schedule could be accepted in some cases, potentially leading to a crash during expansion.
There is also a minor fix in `cloneAndUpdateOperands` in the pipeline expander that prevents the cloned op from being visited twice. This fix has no functional impact, so there is no test for it.

>From df3c3e26443aa548e6d1b9d3c0f2dae99024534c Mon Sep 17 00:00:00 2001
From: Pawel Szczerbuk <pawel.szczerbuk at openai.com>
Date: Thu, 11 Apr 2024 15:29:30 -0700
Subject: [PATCH] * Fix a bug in verifySchedule that caused nested ops to
 not be verified * Fix double visit when cloning op in the loop pipeliner

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 18 +++++++++++----
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 23 +++++++++++++++++++
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 9eda1a4597ba43..82ec95d31f525f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -203,6 +203,17 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   return true;
 }
 
+/// Find operands of all the nested operations within `op`.
+static SetVector<Value> getNestedOperands(Operation *op) {
+  SetVector<Value> operands;
+  op->walk([&](Operation *nestedOp) {
+    for (Value operand : nestedOp->getOperands()) {
+      operands.insert(operand);
+    }
+  });
+  return operands;
+}
+
 /// Compute unrolled cycles of each op (consumer) and verify that each op is
 /// scheduled after its operands (producers) while adjusting for the distance
 /// between producer and consumer.
@@ -219,7 +230,7 @@ bool LoopPipelinerInternal::verifySchedule() {
   }
   for (Operation *consumer : opOrder) {
     int64_t consumerCycle = unrolledCyles[consumer];
-    for (Value operand : consumer->getOperands()) {
+    for (Value operand : getNestedOperands(consumer)) {
       auto [producer, distance] = getDefiningOpAndDistance(operand);
       if (!producer)
         continue;
@@ -245,9 +256,8 @@ static Operation *
 cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
                        function_ref<void(OpOperand *newOperand)> callback) {
   Operation *clone = rewriter.clone(*op);
-  for (OpOperand &operand : clone->getOpOperands())
-    callback(&operand);
-  clone->walk([&](Operation *nested) {
+  clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+    // 'clone' itself will be visited first.
     for (OpOperand &operand : nested->getOpOperands()) {
       Operation *def = operand.get().getDefiningOp();
       if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8d6f454d187518..46e7feca4329ee 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -846,3 +846,26 @@ func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
   }  { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+func.func @invalid_schedule3(%A: memref<?xf32>, %result: memref<?xf32>, %ext: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+    %cnd = arith.cmpi slt, %ext, %c4 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : index
+    // expected-error at +1 {{operation scheduled before its operands}}
+    %idx1 = scf.if %cnd -> (index) {
+      %idxinc = arith.addi %idx, %c1 : index
+      scf.yield %idxinc : index
+    } else {
+      scf.yield %idx : index
+    } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 }
+    %A_elem = memref.load %A[%idx1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %idx2 = arith.addi %idx1, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 3 } : index
+    memref.store %A_elem, %result[%idx1] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 4 } : memref<?xf32>
+    scf.yield %idx2 : index
+  }  { __test_pipelining_loop__ }
+  return
+}



More information about the Mlir-commits mailing list