[flang-commits] [llvm] [lld] [lldb] [compiler-rt] [mlir] [clang] [flang] [libcxx] [mlir][async] Avoid crash when not using `func.func` (PR #72801)

Rik Huijzer via flang-commits flang-commits at lists.llvm.org
Mon Nov 20 12:19:53 PST 2023


https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/72801

From 8abbf36f741c8363155e0f3cbf2450ff7f1f0801 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sun, 19 Nov 2023 18:31:38 +0100
Subject: [PATCH 1/3] [mlir][async] Avoid crash when not using `func.func`

---
 .../Async/Transforms/AsyncParallelFor.cpp     |  4 ++++
 .../Async/async-parallel-for-compute-fn.mlir  | 19 +++++++++++++++++++
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   |  2 ++
 3 files changed, 25 insertions(+)

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221a..639bc7f9ec7f112 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,6 +102,10 @@ struct AsyncParallelForPass
     : public impl::AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<async::AsyncDialect, func::FuncDialect>();
+  }
+
   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                        int32_t minTaskSize) {
     this->asyncDispatch = asyncDispatch;
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..fa3b53dd839c6c6 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
 
 // -----
 
+// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+  %cst = arith.constant 0.0 : f32
+
+  %c0 = arith.constant 0 : index
+  %c22 = arith.constant 22 : index
+  %c1 = arith.constant 1 : index
+  %54 = memref.alloc() : memref<22xf32>
+  %alloc_4 = memref.alloc() : memref<22xf32>
+  scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+    memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+  }
+  llvm.return
+}
+
+// -----
+
 // Check that for statically known inner loop bound block size is aligned and
 // inner loop uses statically known loop trip counts.
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..963c52fd4191657 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
   assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
+         "op created without having initialized its dialect?");
   return name.getAttributeNames()[index];
 )";
       method->body() << formatv(getAttrName, attributes.size());

From eb09cc895d7d1c08f745df22345cd0fae5432c7a Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 19:23:49 +0100
Subject: [PATCH 2/3] Declare dependentDialects in `Passes.td`

---
 mlir/include/mlir/Dialect/Async/Passes.td              | 1 +
 mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp | 4 ----
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index c7ee4ba39aecdf0..f0ef83ca3fd4f1a 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -36,6 +36,7 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
   let dependentDialects = [
     "arith::ArithDialect",
     "async::AsyncDialect",
+    "func::FuncDialect",
     "scf::SCFDialect"
   ];
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 639bc7f9ec7f112..12a28c2e23b221a 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,10 +102,6 @@ struct AsyncParallelForPass
     : public impl::AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
 
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<async::AsyncDialect, func::FuncDialect>();
-  }
-
   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                        int32_t minTaskSize) {
     this->asyncDispatch = asyncDispatch;

From 77ba982eba8f7511543e9e06864a15c839feece8 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 21:19:37 +0100
Subject: [PATCH 3/3] Update assertion

---
 mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir | 2 +-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp                | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index fa3b53dd839c6c6..6f068c0e8d74cc7 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,7 +69,7 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
 
 // -----
 
-// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+// Smoke test that parallel for doesn't crash when func dialect is not used.
 
 // CHECK-LABEL: llvm.func @without_func_dialect()
 llvm.func @without_func_dialect() {
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 963c52fd4191657..57392434285ff89 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,8 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
   assert(name.getStringRef() == getOperationName() && "invalid operation name");
-  assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
-         "op created without having initialized its dialect?");
+  assert(name.isRegistered() && "Operation isn't registered, missing a "
+        "dependent dialect loading?");
   return name.getAttributeNames()[index];
 )";
       method->body() << formatv(getAttrName, attributes.size());



More information about the flang-commits mailing list