[Mlir-commits] [mlir] MLIR: add flag to conditionally disable automatic dialect loading (PR #120100)

William Moses llvmlistbot at llvm.org
Mon Dec 16 08:14:39 PST 2024


wsmoses wrote:

I was just about to flag this to you per @ftynse's suggestion :)

The broader context is that I'm building a system that automatically compiles custom kernels and inserts them into XLA as native MLIR; see https://github.com/EnzymeAD/Enzyme-JAX/pull/191

So instead of a regular `stablehlo.custom_call` to a magic unknown address, which acts as an optimization barrier, you can have something like the following (where the kernel can be written in the nvvm dialect, the gpu dialect, etc.):

```
module {
  llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
    llvm.unreachable
  }
  llvm.func internal ptx_kernelcc @kern(%arg0: !llvm.ptr<1>) {
    %0 = llvm.mlir.constant(63 : i32) : i32
    %1 = nvvm.read.ptx.sreg.tid.x : i32
    %2 = llvm.icmp "ugt" %1, %0 : i32
    llvm.cond_br %2, ^bb2, ^bb1
  ^bb1:  // pred: ^bb0
    %4 = llvm.zext %1 : i32 to i64
    %5 = llvm.getelementptr inbounds %arg0[%4] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
    %6 = llvm.load %5 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
    %7 = llvm.mul %6, %6 : i64
    llvm.store %7, %5 {alignment = 1 : i64} : i64, !llvm.ptr<1>
    llvm.return
  ^bb2:  // pred: ^bb0
    llvm.call fastcc @throw_boundserror_2676() : () -> ()
    llvm.unreachable
  }
  func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
    %c1 = stablehlo.constant dense<1> : tensor<i64>
    %c40 = stablehlo.constant dense<40> : tensor<i64>
    %0 = enzymexla.kernel_call @kern blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c40) (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
    return %0 : tensor<64xi64>
  }
}
```

Of course, to actually run this, someone eventually has to JIT the kernel and produce a pointer to it. That is what our LowerPass is trying to do. In particular, we want to run MLIR's GPU codegen, which is accessible via `mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);`. So I end up running a pass pipeline (on a new module that I create as part of the pass) from inside another pass.

Unfortunately that quickly hits some dialect loading hell. The error I get on my machine (but not @ftynse's) is:
```
enzymexlamlir-opt: external/llvm-project/mlir/lib/IR/MLIRContext.cpp:423: void mlir::MLIRContext::appendDialectRegistry(const mlir::DialectRegistry&): Assertion `impl->multiThreadedExecutionContext == 0 && "appending to the MLIRContext dialect registry while in a " "multi-threaded execution context"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: ./bazel-bin/enzymexlamlir-opt ./test/lit_tests/lowering/gpu.mlir --pass-pipeline=builtin.module(lower-kernel{jit=true})
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  enzymexlamlir-opt 0x00005fdb53c6c5ca
1  enzymexlamlir-opt 0x00005fdb53c6c983
2  enzymexlamlir-opt 0x00005fdb53c69fa0
```

This happens in the pass pipeline's dialect registration code, specifically here, I believe: https://github.com/llvm/llvm-project/blob/d1a7225076218ce224cd29c74259b715b393dc9d/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp#L486 (which is weird, as it should already be loaded).

I also tried building the codegen module in an entirely new MLIRContext (which wouldn't hit this same issue), but that ran into other dialect loading errors, specifically not finding an LLVM IR translation for llvm.mlir.addressof.

Since the "use the same context" approach works on Alex's machine but hits this issue on mine, I'm hoping that adding a flag to the PassManager to skip the double registration will help, or at least give a more informative error message than a crash.
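
To make the failure mode concrete, here is a small self-contained C++ toy model. It is only a sketch under stated assumptions: `ToyContext`, `ToyPassManager`, `autoLoadDialects`, and `nestedRunSucceeds` are hypothetical names invented for illustration and are not MLIR's API, and the real flag name is whatever the PR defines. The toy assumes that re-appending dialects the context already knows is a no-op, and it throws an exception where the real code asserts.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>

// Toy stand-in for a context that owns a dialect registry.
// Every name in this sketch is hypothetical; none of it is MLIR's real API.
struct ToyContext {
  std::set<std::string> registry; // dialects the context already knows
  int multiThreadedDepth = 0;     // nonzero while some pass manager is running

  // Appending is a no-op if nothing new is requested (an assumption of this
  // toy); otherwise it is rejected while a run is in progress.
  void appendRegistry(const std::set<std::string> &incoming) {
    bool alreadyKnown = true;
    for (const std::string &name : incoming)
      if (registry.count(name) == 0)
        alreadyKnown = false;
    if (alreadyKnown)
      return;
    if (multiThreadedDepth != 0)
      throw std::logic_error("appending to the dialect registry while in a "
                             "multi-threaded execution context");
    registry.insert(incoming.begin(), incoming.end());
  }
};

// Toy pass manager: optionally registers its dependent dialects, then runs
// the given body while marking the context as "in a run".
struct ToyPassManager {
  ToyContext &ctx;
  std::set<std::string> dependentDialects;
  bool autoLoadDialects; // hypothetical flag modeling the proposal

  ToyPassManager(ToyContext &c, std::set<std::string> deps, bool autoLoad = true)
      : ctx(c), dependentDialects(std::move(deps)), autoLoadDialects(autoLoad) {}

  template <typename Body>
  void run(Body body) {
    if (autoLoadDialects)
      ctx.appendRegistry(dependentDialects);
    ++ctx.multiThreadedDepth;
    try {
      body();
    } catch (...) {
      --ctx.multiThreadedDepth;
      throw;
    }
    --ctx.multiThreadedDepth;
  }
};

// Runs an inner pass manager from inside an outer one and reports whether
// the inner run got past dialect registration.
bool nestedRunSucceeds(bool innerAutoLoad, bool preloadInOuter) {
  ToyContext ctx;
  std::set<std::string> outerDeps = {"enzymexla"};
  if (preloadInOuter) {
    outerDeps.insert("gpu");
    outerDeps.insert("nvvm");
  }
  ToyPassManager outer(ctx, outerDeps);
  bool ok = true;
  outer.run([&] {
    ToyPassManager inner(ctx, {"gpu", "nvvm"}, innerAutoLoad);
    try {
      inner.run([] {});
    } catch (const std::logic_error &) {
      ok = false;
    }
  });
  return ok;
}
```

In this toy, `nestedRunSucceeds(true, false)` returns `false` because the inner manager tries to append new dialects while the outer run is active, whereas `nestedRunSucceeds(false, false)` and `nestedRunSucceeds(true, true)` both return `true`. Skipping automatic loading only avoids the registration step; the dialects the inner pipeline needs would still have to be loaded some other way before the outer run starts.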

https://github.com/llvm/llvm-project/pull/120100

