[Mlir-commits] [mlir] [mlir][gpu] Add pass for emulating unsupported types. (PR #138087)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri May 2 14:20:04 PDT 2025


krzysz00 wrote:

I claim that this pass is not _necessary_ at the GPU level.

This issue, where MLIR has a type that the backend doesn't have and so it has to be lowered to bytes plus special operations, is already handled in plenty of places in the LLVM case; see everything to do with the 8-bit floats. Yes, it makes the GPU-to-{LLVM, SPIR-V, what have you} lowering a bit more complicated. But for something like "oh, SPIR-V doesn't have a bfloat type and just calls it i16", the existing pattern is to handle that in the lowering.
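To make "SPIR-V just calls it i16" concrete: bf16 is exactly the upper 16 bits of an IEEE-754 f32. That means a bf16 value can sit unchanged in 16-bit integer storage, and `arith.extf` to f32 is just a 16-bit left shift of the bits. Below is a minimal Python sketch at the bit level. The helper names are hypothetical, and the code models only the arithmetic, not any MLIR or SPIR-V API.

```python
import struct


def f32_bits(x: float) -> int:
    """IEEE-754 binary32 bit pattern of x (struct rounds the f64 to f32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret a 32-bit pattern as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_from_i16_storage(raw: int) -> float:
    """Widen bf16 bits held in 16-bit integer storage to f32 (an extf).

    bf16 is the high half of an f32, so this is a shift by 16 and no bits
    are lost; that is why the storage can simply be typed as i16.
    """
    return bits_to_f32((raw & 0xFFFF) << 16)


def bf16_to_i16_storage(x: float) -> int:
    """High 16 bits of an f32 value that is already exactly a bf16 value."""
    bits = f32_bits(x)
    if bits & 0xFFFF:
        raise ValueError("not exactly representable in bf16; needs a truncf")
    return bits >> 16
```

For example, the i16 bit pattern `0x3F80` widens to `1.0`, and `0.5` (f32 bits `0x3F000000`) is stored as `0x3F00`. Values that don't already fit in bf16 need a rounding `truncf`. That is the part that actually requires an operation rather than a type rename.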

Secondly, this "make it look like bytes" change isn't generally applicable: only SPIR-V needs it. LLVM (the other target) can handle all of this just fine.

Also, `gpu.launch_func` should *already* be comfortable with a post-lowering type mismatch, given that in the LLVM pipeline you get a `memref<>` on the host side but an `!llvm.ptr` on the device side. So I don't think that's a problem either. You might not need any sort of `memref.bitcast`, because the device-side code will be post-memref and so won't care. If I'm wrong about that, do let me know; I don't know the SPIR-V lowering all that well. That is, `memref<...xbf16>` will become some flavor of `!spirv.ptr`, and you can just swap bf16 for i16 during that conversion.

Now, I'd be fine with passes that make this transformation more straightforward, like EmulateUnsupportedFloats. They reduce the problem at lowering time to `arith.extf`, `arith.truncf`, and any "native" operations like matmul primitives.
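The extf / op / truncf shape that EmulateUnsupportedFloats leaves behind can be modelled the same way. Widen both bf16 operands, do the arithmetic in f32, then truncate back with round-to-nearest-even. This is a hedged Python sketch with hypothetical helper names. The NaN and overflow handling here is one reasonable choice, not necessarily what any particular MLIR lowering emits.

```python
import math
import struct


def f32_bits(x: float) -> int:
    """f32 bit pattern of x; struct rounds the Python f64 to f32 first."""
    try:
        return struct.unpack("<I", struct.pack("<f", x))[0]
    except OverflowError:
        # struct raises when a finite f64 rounds to infinity in f32.
        return 0xFF800000 if x < 0 else 0x7F800000


def bits_to_f32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def extf_bf16(raw: int) -> float:
    # bf16 bits in 16-bit storage -> f32: move them into the high half.
    return bits_to_f32((raw & 0xFFFF) << 16)


def truncf_bf16(x: float) -> int:
    # f32 -> bf16 bits, rounding the dropped low 16 bits to nearest-even.
    bits = f32_bits(x)
    if math.isnan(x):
        # Keep the high half and force the quiet bit so it stays a NaN.
        return ((bits >> 16) | 0x0040) & 0xFFFF
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bf16 value
    return ((bits + bias) >> 16) & 0xFFFF


def emulated_bf16_add(a_raw: int, b_raw: int) -> int:
    # extf both operands -> add -> truncf. f32_bits (inside truncf_bf16)
    # rounds the f64 sum to f32. f64 has enough precision that this matches
    # a true f32 add, so this models "add in f32, then truncf".
    return truncf_bf16(extf_bf16(a_raw) + extf_bf16(b_raw))
```

For instance, adding bf16 `1.0` (`0x3F80`) to itself this way yields `0x4000`, i.e. `2.0`. The only bf16-specific pieces are the two conversions. The add itself is an ordinary f32 operation.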

If people strongly want a "replace bf16 memrefs with i16 memrefs" pass to add to that suite of preprocessing steps, we could do that. But it certainly shouldn't have all this linearization logic cloned from the narrow type emulation. There, linearization is absolutely needed as a preprocessing step, because there's no good way to handle sub-byte addressing otherwise. This SPIR-V situation isn't that, unless I'm missing a whole pile of context.
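The difference shows up in the index arithmetic. With i4 packed into bytes, a multi-dimensional index has to be linearized and split into a byte index plus a bit offset, and row boundaries can land in the middle of a byte. Swapping bf16 for i16 keeps every element in its own 16-bit slot, so indices map one-to-one. Below is a rough Python sketch with hypothetical helpers. It assumes row-major layout and low-nibble-first packing, which is just one possible convention.

```python
def linearize(indices, shape):
    """Row-major linear index of `indices` within a buffer of `shape`."""
    lin = 0
    for i, d in zip(indices, shape):
        lin = lin * d + i
    return lin


def container_slot(lin, elem_bits, container_bits=8):
    """Split a linear element index into (container index, bit offset)."""
    per_container = container_bits // elem_bits
    return lin // per_container, (lin % per_container) * elem_bits


def load_i4(storage: bytes, indices, shape) -> int:
    """Load an unsigned 4-bit element from packed bytes (low nibble first)."""
    byte, bit = container_slot(linearize(indices, shape), 4)
    return (storage[byte] >> bit) & 0xF
```

In a 3x5 buffer of i4, element `(0, 4)` and element `(1, 0)` share byte 2, at bit offsets 0 and 4 respectively. That is why sub-byte emulation must linearize. With 16-bit elements in 16-bit containers, `(1, 0)` is simply slot 5 at offset 0, the same position it had as bf16.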

(That being said, a `memref.reinterpret_elements` that lets you turn a `memref<...S x bf16>` into a `memref<...S x i16>` is something I'd be OK with, though I don't think we need it here.)

https://github.com/llvm/llvm-project/pull/138087

