[Mlir-commits] [mlir] MLIR: add flag to conditionally disable automatic dialect loading (PR #120100)
William Moses
llvmlistbot at llvm.org
Mon Dec 16 08:14:39 PST 2024
wsmoses wrote:
I was just about to flag this to you per @ftynse's suggestion :)
The broader context: I'm building a system for automatically compiling and inserting custom kernels into XLA as native MLIR; see https://github.com/EnzymeAD/Enzyme-JAX/pull/191
So instead of an optimization barrier of a regular `stablehlo.custom_call` to a magic unknown address, you can have something like the following (which can contain nvvm, the gpu dialect, etc.):
```
module {
  llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
    llvm.unreachable
  }
  llvm.func internal ptx_kernelcc @kern(%arg0: !llvm.ptr<1>) {
    %0 = llvm.mlir.constant(63 : i32) : i32
    %1 = nvvm.read.ptx.sreg.tid.x : i32
    %2 = llvm.icmp "ugt" %1, %0 : i32
    llvm.cond_br %2, ^bb2, ^bb1
  ^bb1: // pred: ^bb0
    %4 = llvm.zext %1 : i32 to i64
    %5 = llvm.getelementptr inbounds %arg0[%4] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
    %6 = llvm.load %5 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
    %7 = llvm.mul %6, %6 : i64
    llvm.store %7, %5 {alignment = 1 : i64} : i64, !llvm.ptr<1>
    llvm.return
  ^bb2: // pred: ^bb0
    llvm.call fastcc @throw_boundserror_2676() : () -> ()
    llvm.unreachable
  }
  func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
    %c1 = stablehlo.constant dense<1> : tensor<i64>
    %c40 = stablehlo.constant dense<40> : tensor<i64>
    %0 = enzymexla.kernel_call @kern blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c40) (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
    return %0 : tensor<64xi64>
  }
}
```
Of course, to actually run this, someone eventually needs to JIT the kernel and produce a pointer to it. That is what our lowering pass tries to do. In particular, we want to run MLIR's GPU codegen, which is accessible via `mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);`. So I end up running a pass pipeline (on a new module that I create as part of the pass) from within another pass.
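To make the pattern concrete, here is a minimal sketch of running a nested pipeline on a fresh module from inside another pass. The helper name `jitKernelModule` and the surrounding setup are hypothetical; only `PassManager`, `GPUToNVVMPipelineOptions`, and `buildLowerToNVVMPassPipeline` are real MLIR APIs:

```cpp
// Hedged sketch of the problematic pattern: a nested PassManager run on a
// freshly created module, from within another pass's runOnOperation().
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper called from inside the outer lowering pass.
static mlir::LogicalResult jitKernelModule(mlir::ModuleOp gpuModule) {
  // Reusing the parent context is what triggers the assertion below:
  // the nested PassManager appends its passes' dependent dialects to the
  // context's registry while the outer pass manager is still executing
  // (possibly multi-threaded).
  mlir::PassManager pm(gpuModule->getContext());
  mlir::gpu::GPUToNVVMPipelineOptions options;
  mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);
  return pm.run(gpuModule);
}
```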
Unfortunately that quickly hits some dialect loading hell. The error I get on my machine (but not @ftynse's) is:
```
enzymexlamlir-opt: external/llvm-project/mlir/lib/IR/MLIRContext.cpp:423: void mlir::MLIRContext::appendDialectRegistry(const mlir::DialectRegistry&): Assertion `impl->multiThreadedExecutionContext == 0 && "appending to the MLIRContext dialect registry while in a " "multi-threaded execution context"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: ./bazel-bin/enzymexlamlir-opt ./test/lit_tests/lowering/gpu.mlir --pass-pipeline=builtin.module(lower-kernel{jit=true})
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 enzymexlamlir-opt 0x00005fdb53c6c5ca
1 enzymexlamlir-opt 0x00005fdb53c6c983
2 enzymexlamlir-opt 0x00005fdb53c69fa0
```
This is happening in the dialect-registration code run by the pass pipeline; specifically, I believe it is https://github.com/llvm/llvm-project/blob/d1a7225076218ce224cd29c74259b715b393dc9d/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp#L486 (which is weird, as the LLVM dialect should already be loaded at that point).
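For context, the conventional mitigation is to declare everything the nested pipeline will need in the outer pass's `getDependentDialects`, so nothing has to be appended to the context registry mid-run. A sketch (the pass name is a placeholder, and as this thread shows, the translation-interface registration can still differ between builds, so this alone may not be sufficient):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"

// Hypothetical outer pass, for illustration only.
struct LowerKernelPassSketch
    : public mlir::PassWrapper<LowerKernelPassSketch,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Pre-declare the dialects the nested NVVM pipeline will produce,
    // so they are loaded before multi-threaded execution begins.
    registry.insert<mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect>();
  }
  void runOnOperation() override {
    // ... build the kernel module and run the nested pipeline here ...
  }
};
```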
I also tried building the codegen module in an entirely new MLIRContext (which wouldn't hit this same issue), but that hit other dialect-loading errors (specifically, no LLVMIR translation interface was found for `llvm.mlir.addressof`).
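For reference, a sketch of the fresh-context approach with the LLVMIR translation interfaces registered up front (which is presumably what the `llvm.mlir.addressof` failure indicates was missing). The helper name is hypothetical; `registerAllDialects`, `registerAllToLLVMIRTranslations`, and the `MLIRContext(DialectRegistry)` constructor are real MLIR APIs:

```cpp
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include <memory>

// Hypothetical helper: build an isolated context for kernel codegen.
static std::unique_ptr<mlir::MLIRContext> makeCodegenContext() {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  // The crucial bit: attach the dialect-to-LLVMIR translation interfaces
  // before any pass runs, so nothing needs to be appended to the
  // registry during execution.
  mlir::registerAllToLLVMIRTranslations(registry);
  return std::make_unique<mlir::MLIRContext>(registry);
}
```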
Since the "use the same context" approach works on Alex's machine but crashes on mine, I'm hoping that adding a flag to the PassManager to skip the duplicate registration will help, or at least produce a more informative error message than a crash.
https://github.com/llvm/llvm-project/pull/120100