[Mlir-commits] [mlir] [mlir][gpu] Add pass for emulating unsupported types. (PR #138087)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Mon May 5 16:15:01 PDT 2025
mshahneo wrote:
> I claim that this pass is not _necessary_ at the GPU level.
>
> This issue - where MLIR has a type that the backend doesn't have and it has to be lowered to bytes + special operations - is already handled a bunch in the LLVM case - see everything to do with the 8-bit floats. Yes, it makes the GPU to {LLVM, SPIR-V, what have you} lowering a bit more complicated, but for something like "oh, SPIR-V doesn't have a bfloat type and just calls it i16", the existing pattern is to handle that in the lowering.
>
> Secondly, this "make it look like bytes" change isn't generally applicable: only SPIR-V needs it. LLVM (the other target) can handle all this just fine.
>
> ... Also, `gpu.launch_func` should _already_ be comfortable with post-lowering mismatch, given that, in the LLVM pipeline, you get `memref<>` on the host side but `!llvm.ptr` on the device side. So I don't think that's a problem either ... and so you might not need any sort of `memref.bitcast` because the device-side code will be post-memref and so doesn't care. If I'm wrong about that, do let me know - I don't know the SPIR-V lowering all that well. That is, `memref<...xbf16>` will be some flavor of `!spirv.ptr` and you can just swap bf16 for i16 during that process.
>
> Now, passes that make this transformation more straightforward, like EmulateUnsupportedFloats, I'd be fine with, since it reduces the problem at lowering-time to arith.extf, arith.truncf, and any "native" operations like matmul primitives.
>
> If people strongly want a "replace bf16 memrefs with i16 memrefs" pass to add to that suite of preprocessing steps, we could do that, but ... it certainly shouldn't have all this linearization stuff that's cloned from the narrow type emulation (where it's absolutely needed as a pre-processing step because there's no good way to handle sub-byte addressing). This SPIR-V situation ... isn't that, unless I'm missing a whole pile of context.
>
> (That being said, a `memref.reinterpret_elements` that'll let you take a `memref<...S x bf16>` into a `memref<...S x i16>` is something I'd be OK with, though I don't think we need it here.)
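The quoted review's point that SPIR-V "just calls it i16" depends on bf16 being the upper 16 bits of an IEEE-754 binary32 value. That is why, once bf16 data is stored as i16, `arith.extf` and `arith.truncf` can be reduced to bit manipulation. The following is a minimal Python sketch of that bit-level behavior and is not the actual MLIR lowering. It assumes round-to-nearest-even for truncation and maps every NaN to a canonical quiet NaN. The helper names are hypothetical.

```python
import math
import struct


def f32_to_bits(x: float) -> int:
    """Bit pattern of x as IEEE-754 binary32 (Python floats are doubles,
    so packing with 'f' first rounds to binary32, standing in for an f32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret a 32-bit unsigned int as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def truncf_to_bf16_bits(x: float) -> int:
    """Hypothetical f32 -> bf16 truncation returning the bf16 bits as an
    i16-sized integer, i.e. what would be stored in an i16 buffer."""
    bits = f32_to_bits(x)
    if math.isnan(x):
        # Canonical quiet NaN, keeping the sign, so rounding cannot
        # turn a NaN payload into infinity.
        return 0x7FC0 | ((bits >> 16) & 0x8000)
    lsb = (bits >> 16) & 1          # lowest mantissa bit kept by bf16
    rounded = bits + 0x7FFF + lsb   # round to nearest, ties to even
    return (rounded >> 16) & 0xFFFF


def extf_from_bf16_bits(h: int) -> float:
    """Hypothetical bf16 -> f32 extension: the 16 stored bits become the
    high half of an f32 and the low half is zero, so this step is exact."""
    return bits_to_f32((h & 0xFFFF) << 16)
```

For example, `truncf_to_bf16_bits(1.0)` yields `0x3F80`, and extending that value gives back `1.0`. The extension is a plain shift. Only truncation has to make a rounding decision.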
Thanks a lot for your explanation :).
It turns out `arith.bitcast` actually supports memrefs, so we don't really need a new memref element-cast op for this case. I've updated the PR. Please let me know what you think.
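The following is a rough analogy, not MLIR code, for what a same-width element bitcast does to a buffer, such as the `memref<...xbf16>` to `memref<...xi16>` case the reply describes. Only the element type of the view changes. The underlying storage is shared, and no bits are modified or copied. This minimal Python sketch assumes the bf16 values are already held as raw 16-bit patterns in an `array('H')`, since Python has no native bf16 type. The helper `reinterpret_as_i16` is hypothetical.

```python
from array import array


def reinterpret_as_i16(buf: array) -> memoryview:
    """Hypothetical helper: view a buffer of 16-bit elements as signed
    16-bit integers without copying. memoryview.cast requires going
    through a byte format, hence the intermediate 'B' cast."""
    if buf.itemsize != 2:
        raise ValueError("expected 16-bit elements")
    return memoryview(buf).cast("B").cast("h")


# bf16 bit patterns for 1.0 and -2.0, stored as unsigned 16-bit values.
bf16_bits = array("H", [0x3F80, 0xC000])
as_i16 = reinterpret_as_i16(bf16_bits)
```

Here `as_i16[1]` reads `0xC000` as the signed value `-16384`, and writing through `as_i16` changes `bf16_bits`. This mirrors the fact that a bitcast retypes the same memory rather than converting values.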
https://github.com/llvm/llvm-project/pull/138087