[flang-commits] [llvm] [lld] [lldb] [compiler-rt] [mlir] [clang] [flang] [libcxx] [mlir][async] Avoid crash when not using `func.func` (PR #72801)
Rik Huijzer via flang-commits
flang-commits at lists.llvm.org
Mon Nov 20 12:19:53 PST 2023
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/72801
>From 8abbf36f741c8363155e0f3cbf2450ff7f1f0801 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sun, 19 Nov 2023 18:31:38 +0100
Subject: [PATCH 1/3] [mlir][async] Avoid crash when not using `func.func`
---
.../Async/Transforms/AsyncParallelFor.cpp | 4 ++++
.../Async/async-parallel-for-compute-fn.mlir | 19 +++++++++++++++++++
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 ++
3 files changed, 25 insertions(+)
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221a..639bc7f9ec7f112 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,6 +102,10 @@ struct AsyncParallelForPass
: public impl::AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<async::AsyncDialect, func::FuncDialect>();
+ }
+
AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
int32_t minTaskSize) {
this->asyncDispatch = asyncDispatch;
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..fa3b53dd839c6c6 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
// -----
+// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+ %cst = arith.constant 0.0 : f32
+
+ %c0 = arith.constant 0 : index
+ %c22 = arith.constant 22 : index
+ %c1 = arith.constant 1 : index
+ %54 = memref.alloc() : memref<22xf32>
+ %alloc_4 = memref.alloc() : memref<22xf32>
+ scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+ memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+ }
+ llvm.return
+}
+
+// -----
+
// Check that for statically known inner loop bound block size is aligned and
// inner loop uses statically known loop trip counts.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..963c52fd4191657 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
+ assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
+ "op created without having initialized its dialect?");
return name.getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
>From eb09cc895d7d1c08f745df22345cd0fae5432c7a Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 19:23:49 +0100
Subject: [PATCH 2/3] Declare dependentDialects in `Passes.td`
---
mlir/include/mlir/Dialect/Async/Passes.td | 1 +
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp | 4 ----
2 files changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index c7ee4ba39aecdf0..f0ef83ca3fd4f1a 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -36,6 +36,7 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
let dependentDialects = [
"arith::ArithDialect",
"async::AsyncDialect",
+ "func::FuncDialect",
"scf::SCFDialect"
];
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 639bc7f9ec7f112..12a28c2e23b221a 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,10 +102,6 @@ struct AsyncParallelForPass
: public impl::AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<async::AsyncDialect, func::FuncDialect>();
- }
-
AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
int32_t minTaskSize) {
this->asyncDispatch = asyncDispatch;
>From 77ba982eba8f7511543e9e06864a15c839feece8 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 21:19:37 +0100
Subject: [PATCH 3/3] Update assertion
---
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir | 2 +-
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index fa3b53dd839c6c6..6f068c0e8d74cc7 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,7 +69,7 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
// -----
-// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+// Smoke test that parallel for doesn't crash when func dialect is not used.
// CHECK-LABEL: llvm.func @without_func_dialect()
llvm.func @without_func_dialect() {
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 963c52fd4191657..57392434285ff89 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,8 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
- assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
- "op created without having initialized its dialect?");
+ assert(name.isRegistered() && "Operation isn't registered, missing a "
+ "dependent dialect loading?");
return name.getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
More information about the flang-commits
mailing list