[Mlir-commits] [mlir] [mlir][scf] Fix scf.forall to scf.parallel pass walker (PR #95385)
llvmlistbot at llvm.org
Thu Jun 13 03:20:02 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-scf
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
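Returns explicit `WalkResult`s from the pass's walk callback to prevent runtime crashes on transformation failure. An `scf.forall` that `scf::forallToParallelLoop` cannot convert (for example, one with `shared_outs`, as in the new test) is now skipped and left unchanged instead of triggering `signalPassFailure()`.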
---
Full diff: https://github.com/llvm/llvm-project/pull/95385.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp (+2-1)
- (modified) mlir/test/Dialect/SCF/forall-to-parallel.mlir (+18)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 44e6840b03a3d..925d4a3c0a085 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -71,8 +71,9 @@ struct ForallToParallelLoop final
     parentOp->walk([&](scf::ForallOp forallOp) {
       if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
-        return signalPassFailure();
+        return WalkResult::skip();
       }
+      return WalkResult::advance();
     });
   }
 };
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
index acde601d47259..21e816956a094 100644
--- a/mlir/test/Dialect/SCF/forall-to-parallel.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -78,3 +78,21 @@ func.func @mapping_attr() -> () {
   return
 }
+
+// -----
+
+// CHECK-LABEL: @forall_with_outputs
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>
+func.func @forall_with_outputs(%arg0: tensor<32x32xf32>) -> tensor<8x112x32x32xf32> {
+  // CHECK-NOT: scf.parallel
+  // CHECK: %[[RES:.+]] = scf.forall{{.*}}shared_outs
+  // CHECK: return %[[RES]] : tensor<8x112x32x32xf32>
+
+  %0 = tensor.empty() : tensor<8x112x32x32xf32>
+  %1 = scf.forall (%arg1, %arg2) in (8, 112) shared_outs(%arg3 = %0) -> (tensor<8x112x32x32xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg0 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x112x32x32xf32>
+    }
+  }
+  return %1 : tensor<8x112x32x32xf32>
+}
``````````
</details>
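To illustrate the control flow this patch relies on, here is a minimal, self-contained C++ toy model. It does not use MLIR: `WalkResult`, `Op`, `walkPostOrder`, `convertForallToParallel`, and `runForallToParallel` are hypothetical stand-ins, not MLIR's real `mlir::WalkResult` or `Operation::walk`. It also assumes, as the new test suggests, that a forall with `shared_outs` cannot be converted. When a conversion fails, the callback returns `Skip`, so the walk moves on and leaves that op unchanged instead of flagging the whole pass as failed.

```cpp
#include <string>
#include <vector>

// Hypothetical toy model of the pattern in the patch; NOT MLIR's API.
// Mirrors the three outcomes of mlir::WalkResult.
enum class WalkResult { Advance, Skip, Interrupt };

// Hypothetical stand-in for an operation: a name, whether it carries
// shared_outs (tensor outputs), and the ops nested in its body.
struct Op {
  std::string name;
  bool hasSharedOuts = false;
  std::vector<Op> body;
};

// Post-order walk: nested ops are visited before their parent.
// Only Interrupt stops the walk; Skip and Advance both move on.
template <typename Callback>
WalkResult walkPostOrder(Op &op, Callback &callback) {
  for (Op &nested : op.body)
    if (walkPostOrder(nested, callback) == WalkResult::Interrupt)
      return WalkResult::Interrupt;
  return callback(op);
}

// Hypothetical conversion: like the case in the new test, it refuses
// foralls with shared_outs and renames the rest in place.
bool convertForallToParallel(Op &op) {
  if (op.hasSharedOuts)
    return false;
  op.name = "scf.parallel";
  return true;
}

// Runs the toy "pass" and returns how many foralls were converted.
// A failed conversion skips that op instead of failing the whole pass.
int runForallToParallel(Op &root) {
  int converted = 0;
  auto callback = [&](Op &op) -> WalkResult {
    if (op.name != "scf.forall")
      return WalkResult::Advance;  // Not a forall: keep walking.
    if (!convertForallToParallel(op))
      return WalkResult::Skip;     // Leave this forall untouched.
    ++converted;
    return WalkResult::Advance;
  };
  walkPostOrder(root, callback);
  return converted;
}
```

In this post-order sketch `Skip` and `Advance` behave identically, because an op's nested ops have already been visited by the time its callback runs. The point of the model is only that a conversion failure no longer counts as a pass failure: convertible foralls elsewhere in the tree, including ones nested inside an unconvertible forall, are still rewritten.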
https://github.com/llvm/llvm-project/pull/95385