[Mlir-commits] [mlir] [MLIR][LoopPipelining] Improve schedule verifier, so it checks also operands of nested operations (PR #88450)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 11 15:37:43 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (pawelszczerbuk)
<details>
<summary>Changes</summary>
`verifySchedule` was not looking at the operands of nested operations, which allowed incorrect schedules to pass verification in some cases, potentially leading to a crash during expansion.
There is also a minor fix in `cloneAndUpdateOperands` in the pipeline expander that prevents the cloned op from being visited twice. This fix has no functional impact, so there is no test for it.
---
Full diff: https://github.com/llvm/llvm-project/pull/88450.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+14-4)
- (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+23)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 9eda1a4597ba43..82ec95d31f525f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -203,6 +203,17 @@ bool LoopPipelinerInternal::initializeLoopInfo(
return true;
}
+/// Find operands of all the nested operations within `op`.
+static SetVector<Value> getNestedOperands(Operation *op) {
+ SetVector<Value> operands;
+ op->walk([&](Operation *nestedOp) {
+ for (Value operand : nestedOp->getOperands()) {
+ operands.insert(operand);
+ }
+ });
+ return operands;
+}
+
/// Compute unrolled cycles of each op (consumer) and verify that each op is
/// scheduled after its operands (producers) while adjusting for the distance
/// between producer and consumer.
@@ -219,7 +230,7 @@ bool LoopPipelinerInternal::verifySchedule() {
}
for (Operation *consumer : opOrder) {
int64_t consumerCycle = unrolledCyles[consumer];
- for (Value operand : consumer->getOperands()) {
+ for (Value operand : getNestedOperands(consumer)) {
auto [producer, distance] = getDefiningOpAndDistance(operand);
if (!producer)
continue;
@@ -245,9 +256,8 @@ static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
function_ref<void(OpOperand *newOperand)> callback) {
Operation *clone = rewriter.clone(*op);
- for (OpOperand &operand : clone->getOpOperands())
- callback(&operand);
- clone->walk([&](Operation *nested) {
+ clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+ // 'clone' itself will be visited first.
for (OpOperand &operand : nested->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8d6f454d187518..46e7feca4329ee 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -846,3 +846,26 @@ func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
} { __test_pipelining_loop__ }
return
}
+
+// -----
+
+func.func @invalid_schedule3(%A: memref<?xf32>, %result: memref<?xf32>, %ext: index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+ %cnd = arith.cmpi slt, %ext, %c4 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : index
+ // expected-error at +1 {{operation scheduled before its operands}}
+ %idx1 = scf.if %cnd -> (index) {
+ %idxinc = arith.addi %idx, %c1 : index
+ scf.yield %idxinc : index
+ } else {
+ scf.yield %idx : index
+ } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 }
+ %A_elem = memref.load %A[%idx1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %idx2 = arith.addi %idx1, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 3 } : index
+ memref.store %A_elem, %result[%idx1] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 4 } : memref<?xf32>
+ scf.yield %idx2 : index
+ } { __test_pipelining_loop__ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/88450
More information about the Mlir-commits
mailing list