[Mlir-commits] [mlir] [mlir][scf] Fix scf.forall to scf.parallel pass walker (PR #95385)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Jun 13 03:19:30 PDT 2024


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/95385

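Makes the pass's walk callback return explicit WalkResult values. An scf.forall op that cannot be converted to scf.parallel (for example, one with shared_outs, as in the new test) is now skipped and left unchanged instead of triggering signalPassFailure(). This prevents runtime crashes on transformation failure.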
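
For readers less familiar with MLIR walk results, here is a minimal, self-contained C++ sketch of the behavior the fix relies on. It does not use MLIR. WalkResult, Op, makeOp, walkPostOrder, tryConvert, convertForalls and countOps are hypothetical stand-ins that only model mlir::WalkResult, Operation::walk and scf::forallToParallelLoop. As we understand it, walk with a WalkResult-returning callback visits ops in post-order by default. In that order, skip() behaves like advance(), because nested regions have already been visited. Only interrupt() stops the traversal.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::WalkResult (advance / skip / interrupt).
enum class WalkResult { Advance, Skip, Interrupt };

// Hypothetical, heavily simplified IR op: a name, a flag modelling
// scf.forall's shared_outs, and nested ops.
struct Op {
  std::string name;
  bool hasSharedOuts = false;
  std::vector<std::unique_ptr<Op>> children;
};

std::unique_ptr<Op> makeOp(std::string name, bool hasSharedOuts = false) {
  auto op = std::make_unique<Op>();
  op->name = std::move(name);
  op->hasSharedOuts = hasSharedOuts;
  return op;
}

// Post-order walk: nested ops are visited before their parent, so by the
// time the callback sees an op there is nothing left below it to skip.
// Only Interrupt stops the traversal; returns false if it was interrupted.
template <typename Fn>
bool walkPostOrder(Op &op, Fn &&fn) {
  for (auto &child : op.children)
    if (!walkPostOrder(*child, fn))
      return false;
  return fn(op) != WalkResult::Interrupt;
}

// Hypothetical stand-in for scf::forallToParallelLoop: refuses foralls with
// shared_outs and leaves them untouched, otherwise "rewrites" the op.
bool tryConvert(Op &op) {
  assert(op.name == "scf.forall" && "only scf.forall ops are converted");
  if (op.hasSharedOuts)
    return false;
  op.name = "scf.parallel";
  return true;
}

// Mirrors the patched callback: Skip on failure, Advance otherwise.
bool convertForalls(Op &root) {
  return walkPostOrder(root, [](Op &op) {
    if (op.name != "scf.forall")
      return WalkResult::Advance;
    if (!tryConvert(op))
      return WalkResult::Skip;  // leave the op in place and keep walking
    return WalkResult::Advance;
  });
}

// Counts ops with the given name in the tree rooted at `op`.
int countOps(const Op &op, const std::string &name) {
  int n = op.name == name ? 1 : 0;
  for (const auto &child : op.children)
    n += countOps(*child, name);
  return n;
}
```

In this example, the root represents a func.func whose children are two foralls: the first has shared_outs and the second does not. convertForalls(root) completes the walk and leaves the first op as scf.forall, which matches what the new FileCheck test expects. It also turns the second op into scf.parallel. If the callback returned Interrupt instead of Skip, the walk would stop at the first failure, and the later forall would stay unconverted.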

From 9b88b28f92f16b10b670579e668fa02402899a5a Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 13 Jun 2024 12:11:48 +0200
Subject: [PATCH] [mlir][scf] Fix scf.forall to scf.parallel pass walker

Adds proper walk results to the pass body to prevent runtime crashes
on transformation failure.
---
 .../SCF/Transforms/ForallToParallel.cpp        |  3 ++-
 mlir/test/Dialect/SCF/forall-to-parallel.mlir  | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 44e6840b03a3d..925d4a3c0a085 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -71,8 +71,9 @@ struct ForallToParallelLoop final
 
     parentOp->walk([&](scf::ForallOp forallOp) {
       if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
-        return signalPassFailure();
+        return WalkResult::skip();
       }
+      return WalkResult::advance();
     });
   }
 };
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
index acde601d47259..21e816956a094 100644
--- a/mlir/test/Dialect/SCF/forall-to-parallel.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -78,3 +78,21 @@ func.func @mapping_attr() -> () {
   return
 
 }
+
+// -----
+
+// CHECK-LABEL: @forall_with_outputs
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>
+func.func @forall_with_outputs(%arg0: tensor<32x32xf32>) -> tensor<8x112x32x32xf32> {
+  // CHECK-NOT: scf.parallel
+  // CHECK: %[[RES:.+]] = scf.forall{{.*}}shared_outs
+  // CHECK: return %[[RES]] : tensor<8x112x32x32xf32>
+
+  %0 = tensor.empty() : tensor<8x112x32x32xf32>
+  %1 = scf.forall (%arg1, %arg2) in (8, 112) shared_outs(%arg3 = %0) -> (tensor<8x112x32x32xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg0 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x112x32x32xf32>
+    }
+  }
+  return %1 : tensor<8x112x32x32xf32>
+}