[Mlir-commits] [mlir] [mlir][async] Avoid crash when not using `func.func` (PR #72801)
Rik Huijzer
llvmlistbot at llvm.org
Sun Nov 19 09:38:14 PST 2023
https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/72801
`createParallelComputeFunction` crashed when calling `getFunctionTypeAttrName` while creating a new `FuncOp` inside the pass. `getFunctionTypeAttrName` looks up the name of the function-type attribute registered for the operation, which in this case is `func.func`. However, `name.getAttributeNames()` was empty when clients used `llvm.func` instead of `func.func`, because the `func` dialect had not been loaded.
To fix this, the `func` dialect is now registered as a dependent dialect. Also, I've added an assertion which could save other people some time.
From 8abbf36f741c8363155e0f3cbf2450ff7f1f0801 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sun, 19 Nov 2023 18:31:38 +0100
Subject: [PATCH] [mlir][async] Avoid crash when not using `func.func`
---
.../Async/Transforms/AsyncParallelFor.cpp | 4 ++++
.../Async/async-parallel-for-compute-fn.mlir | 19 +++++++++++++++++++
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 ++
3 files changed, 25 insertions(+)
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221a..639bc7f9ec7f112 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,6 +102,10 @@ struct AsyncParallelForPass
: public impl::AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<async::AsyncDialect, func::FuncDialect>();
+ }
+
AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
int32_t minTaskSize) {
this->asyncDispatch = asyncDispatch;
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..fa3b53dd839c6c6 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
// -----
+// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+ %cst = arith.constant 0.0 : f32
+
+ %c0 = arith.constant 0 : index
+ %c22 = arith.constant 22 : index
+ %c1 = arith.constant 1 : index
+ %54 = memref.alloc() : memref<22xf32>
+ %alloc_4 = memref.alloc() : memref<22xf32>
+ scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+ memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+ }
+ llvm.return
+}
+
+// -----
+
// Check that for statically known inner loop bound block size is aligned and
// inner loop uses statically known loop trip counts.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..963c52fd4191657 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
+ assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
+ "op created without having initialized its dialect?");
return name.getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
More information about the Mlir-commits
mailing list