[Mlir-commits] [mlir] [mlir][scf] Fix scf.forall to scf.parallel pass walker (PR #95385)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 13 03:20:02 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-scf

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

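Returns explicit `WalkResult` values from the walk callback in the pass body: `WalkResult::skip()` when `scf::forallToParallelLoop` fails and `WalkResult::advance()` otherwise. This prevents runtime crashes on transformation failure; an `scf.forall` that cannot be converted, such as one with `shared_outs`, is left unchanged.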

---
Full diff: https://github.com/llvm/llvm-project/pull/95385.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp (+2-1) 
- (modified) mlir/test/Dialect/SCF/forall-to-parallel.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 44e6840b03a3d..925d4a3c0a085 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -71,8 +71,9 @@ struct ForallToParallelLoop final
 
     parentOp->walk([&](scf::ForallOp forallOp) {
       if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
-        return signalPassFailure();
+        return WalkResult::skip();
       }
+      return WalkResult::advance();
     });
   }
 };
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
index acde601d47259..21e816956a094 100644
--- a/mlir/test/Dialect/SCF/forall-to-parallel.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -78,3 +78,21 @@ func.func @mapping_attr() -> () {
   return
 
 }
+
+// -----
+
+// CHECK-LABEL: @forall_with_outputs
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>
+func.func @forall_with_outputs(%arg0: tensor<32x32xf32>) -> tensor<8x112x32x32xf32> {
+  // CHECK-NOT: scf.parallel
+  // CHECK: %[[RES:.+]] = scf.forall{{.*}}shared_outs
+  // CHECK: return %[[RES]] : tensor<8x112x32x32xf32>
+
+  %0 = tensor.empty() : tensor<8x112x32x32xf32>
+  %1 = scf.forall (%arg1, %arg2) in (8, 112) shared_outs(%arg3 = %0) -> (tensor<8x112x32x32xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg0 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x112x32x32xf32>
+    }
+  }
+  return %1 : tensor<8x112x32x32xf32>
+}

``````````

</details>
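
The skip-or-advance pattern in the diff can be illustrated outside MLIR with a small, self-contained sketch. Everything below is a toy stand-in and an assumption made for illustration: `ToyOp`, `walkPreOrder`, `tryConvert`, `convertAll`, and the local `WalkResult` enum are hypothetical names, not the MLIR classes or the pass's actual code. The toy walker is pre-order so that the effect of a skip is visible; that choice says nothing about the traversal order the real pass uses.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy stand-in for a walk result; NOT mlir::WalkResult itself.
enum class WalkResult { Advance, Skip, Interrupt };

// Toy operation: a name, a flag for "has outputs", and nested operations.
struct ToyOp {
  std::string name;
  bool hasOutputs = false;
  std::vector<ToyOp> nested;
};

// Pre-order walk. The callback decides what happens next:
//   Advance   -> continue into the nested operations,
//   Skip      -> do not visit this operation's nested operations,
//   Interrupt -> stop the whole walk.
WalkResult walkPreOrder(ToyOp &op,
                        const std::function<WalkResult(ToyOp &)> &fn) {
  assert(fn && "callback must be set");
  WalkResult result = fn(op);
  if (result == WalkResult::Interrupt)
    return result;
  if (result == WalkResult::Skip)
    return WalkResult::Advance; // Siblings are still visited.
  for (ToyOp &child : op.nested)
    if (walkPreOrder(child, fn) == WalkResult::Interrupt)
      return WalkResult::Interrupt;
  return WalkResult::Advance;
}

// Hypothetical conversion: refuses operations with outputs, loosely
// mirroring the new test where a forall with shared_outs stays as it is.
bool tryConvert(ToyOp &op) {
  if (op.name != "forall" || op.hasOutputs)
    return false;
  op.name = "parallel";
  return true;
}

// Converts every convertible "forall" under root and returns the count.
// Every path through the callback returns a WalkResult, as in the diff.
int convertAll(ToyOp &root) {
  int converted = 0;
  walkPreOrder(root, [&](ToyOp &op) {
    if (op.name != "forall")
      return WalkResult::Advance;
    if (!tryConvert(op))
      return WalkResult::Skip; // Leave the op untouched and move on.
    ++converted;
    return WalkResult::Advance;
  });
  return converted;
}
```

In this sketch a failed conversion does not abort the traversal: the operation that could not be converted keeps its name, and the walk continues with its siblings.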


https://github.com/llvm/llvm-project/pull/95385

