[Mlir-commits] [mlir] 88f0e4c - [mlir][async] Avoid crash when not using `func.func` (#72801)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 20 23:40:21 PST 2023
Author: Rik Huijzer
Date: 2023-11-21T08:40:16+01:00
New Revision: 88f0e4c75c1ac498f2223fc640c4ff6c572c5ed1
URL: https://github.com/llvm/llvm-project/commit/88f0e4c75c1ac498f2223fc640c4ff6c572c5ed1
DIFF: https://github.com/llvm/llvm-project/commit/88f0e4c75c1ac498f2223fc640c4ff6c572c5ed1.diff
LOG: [mlir][async] Avoid crash when not using `func.func` (#72801)
The `createParallelComputeFunction` function crashed when calling
`getFunctionTypeAttrName` while creating a new `FuncOp` inside the
pass. The problem is that `getFunctionTypeAttrName` looks up the
attribute name for the function type on the `func.func` operation.
However, `name.getAttributeNames()` was empty when clients used
`llvm.func` instead of `func.func`, because the `func` dialect had never
been loaded and `func.func` was therefore not a registered operation.
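To fix this, the `func` dialect is now registered as a dependent
dialect of the `async-parallel-for` pass. In addition, the generated
attribute-name getters now assert that the operation is registered,
which should save others some debugging time.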
Fixes #71281, fixes #64326.
Added:
Modified:
mlir/include/mlir/Dialect/Async/Passes.td
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index c7ee4ba39aecdf0..f0ef83ca3fd4f1a 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -36,6 +36,7 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
let dependentDialects = [
"arith::ArithDialect",
"async::AsyncDialect",
+ "func::FuncDialect",
"scf::SCFDialect"
];
}
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..6f068c0e8d74cc7 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
// -----
+// Smoke test that parallel for doesn't crash when func dialect is not used.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+ %cst = arith.constant 0.0 : f32
+
+ %c0 = arith.constant 0 : index
+ %c22 = arith.constant 22 : index
+ %c1 = arith.constant 1 : index
+ %54 = memref.alloc() : memref<22xf32>
+ %alloc_4 = memref.alloc() : memref<22xf32>
+ scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+ memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+ }
+ llvm.return
+}
+
+// -----
+
// Check that for statically known inner loop bound block size is aligned and
// inner loop uses statically known loop trip counts.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..57392434285ff89 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
+ assert(name.isRegistered() && "Operation isn't registered, missing a "
+ "dependent dialect loading?");
return name.getAttributeNames()[index];
)";
method->body() << formatv(getAttrName, attributes.size());
More information about the Mlir-commits
mailing list