[Mlir-commits] [mlir] 56954a5 - [MLIR][LoopPipelining] Improve schedule verifier, so it checks also operands of nested operations (#88450)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 11 17:09:47 PDT 2024
Author: pawelszczerbuk
Date: 2024-04-11T17:09:44-07:00
New Revision: 56954a53e58282d7584e31ec14a2b1052cd861e8
URL: https://github.com/llvm/llvm-project/commit/56954a53e58282d7584e31ec14a2b1052cd861e8
DIFF: https://github.com/llvm/llvm-project/commit/56954a53e58282d7584e31ec14a2b1052cd861e8.diff
LOG: [MLIR][LoopPipelining] Improve schedule verifier, so it checks also operands of nested operations (#88450)
`verifySchedule` was not looking at the operands of nested operations,
which allowed an incorrect schedule to pass verification in some cases,
potentially leading to a crash during expansion.
There is also a minor fix in `cloneAndUpdateOperands` in the pipeline
expander that prevents the cloned op from being visited twice. This one has
no functional impact, so there is no test for it.
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 9eda1a4597ba43..82ec95d31f525f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -203,6 +203,17 @@ bool LoopPipelinerInternal::initializeLoopInfo(
return true;
}
+/// Collect the operands of `op` and of all operations nested within it.
+static SetVector<Value> getNestedOperands(Operation *op) {
+  SetVector<Value> operands;
+  op->walk([&](Operation *nestedOp) { // Also visits `op` itself.
+    for (Value operand : nestedOp->getOperands()) {
+      operands.insert(operand); // SetVector drops duplicates.
+    }
+  });
+  return operands;
+}
+
/// Compute unrolled cycles of each op (consumer) and verify that each op is
/// scheduled after its operands (producers) while adjusting for the distance
/// between producer and consumer.
@@ -219,7 +230,7 @@ bool LoopPipelinerInternal::verifySchedule() {
}
for (Operation *consumer : opOrder) {
int64_t consumerCycle = unrolledCyles[consumer];
- for (Value operand : consumer->getOperands()) {
+ for (Value operand : getNestedOperands(consumer)) {
auto [producer, distance] = getDefiningOpAndDistance(operand);
if (!producer)
continue;
@@ -245,9 +256,8 @@ static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
function_ref<void(OpOperand *newOperand)> callback) {
Operation *clone = rewriter.clone(*op);
- for (OpOperand &operand : clone->getOpOperands())
- callback(&operand);
- clone->walk([&](Operation *nested) {
+ clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+ // 'clone' itself will be visited first.
for (OpOperand &operand : nested->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8d6f454d187518..46e7feca4329ee 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -846,3 +846,26 @@ func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
} { __test_pipelining_loop__ }
return
}
+
+// -----
+
+func.func @invalid_schedule3(%A: memref<?xf32>, %result: memref<?xf32>, %ext: index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+ %cnd = arith.cmpi slt, %ext, %c4 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : index
+ // expected-error@+1 {{operation scheduled before its operands}}
+ %idx1 = scf.if %cnd -> (index) { // Stage 0; uses %idx only inside its regions.
+ %idxinc = arith.addi %idx, %c1 : index // %idx is carried from %idx2 (stage 1).
+ scf.yield %idxinc : index
+ } else {
+ scf.yield %idx : index
+ } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 }
+ %A_elem = memref.load %A[%idx1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %idx2 = arith.addi %idx1, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 3 } : index
+ memref.store %A_elem, %result[%idx1] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 4 } : memref<?xf32>
+ scf.yield %idx2 : index
+ } { __test_pipelining_loop__ }
+ return
+}
More information about the Mlir-commits
mailing list