[Mlir-commits] [mlir] [mlir][async] Avoid crash when not using `func.func` (PR #72801)

Rik Huijzer llvmlistbot at llvm.org
Sun Nov 19 09:38:14 PST 2023


https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/72801

`createParallelComputeFunction` crashed when calling `getFunctionTypeAttrName` while creating a new `FuncOp` inside the pass. `getFunctionTypeAttrName` looks up the name of the function-type attribute registered for the `func.func` operation. However, when the input used `llvm.func` instead of `func.func`, the `func` dialect had never been loaded. As a result, `name.getAttributeNames()` was empty, and indexing into it failed.

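To fix this, the pass now registers the `func` dialect as a dependent dialect. I've also added an assertion to the generated attribute-name getters. It reports this situation explicitly: an op being created before its dialect has been initialized. This should save others some debugging time.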

>From 8abbf36f741c8363155e0f3cbf2450ff7f1f0801 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sun, 19 Nov 2023 18:31:38 +0100
Subject: [PATCH] [mlir][async] Avoid crash when not using `func.func`

---
 .../Async/Transforms/AsyncParallelFor.cpp     |  4 ++++
 .../Async/async-parallel-for-compute-fn.mlir  | 19 +++++++++++++++++++
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   |  2 ++
 3 files changed, 25 insertions(+)

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221a..639bc7f9ec7f112 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,6 +102,10 @@ struct AsyncParallelForPass
     : public impl::AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<async::AsyncDialect, func::FuncDialect>();
+  }
+
   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                        int32_t minTaskSize) {
     this->asyncDispatch = asyncDispatch;
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..fa3b53dd839c6c6 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
 
 // -----
 
+// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+  %cst = arith.constant 0.0 : f32
+
+  %c0 = arith.constant 0 : index
+  %c22 = arith.constant 22 : index
+  %c1 = arith.constant 1 : index
+  %54 = memref.alloc() : memref<22xf32>
+  %alloc_4 = memref.alloc() : memref<22xf32>
+  scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+    memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+  }
+  llvm.return
+}
+
+// -----
+
 // Check that for statically known inner loop bound block size is aligned and
 // inner loop uses statically known loop trip counts.
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..963c52fd4191657 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
   assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
+         "op created without having initialized its dialect?");
   return name.getAttributeNames()[index];
 )";
       method->body() << formatv(getAttrName, attributes.size());



More information about the Mlir-commits mailing list