[Mlir-commits] [mlir] Allow bf16 operands on new MFMAs (PR #144925)
llvmlistbot at llvm.org
Thu Jun 19 09:36:56 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu
Author: Umang Yadav (umangyadav)
Changes:
The new gfx950 MFMAs allow bf16 operands:
https://github.com/llvm/llvm-project/blob/c0cc81cdc03c97473ba771bbc3a2330bd22396bc/llvm/include/llvm/IR/IntrinsicsAMDGPU.td#L3434
The current lowering always bitcasts bf16 to i16, which fails to compile for the newer bf16 MFMAs, e.g. `v_mfma_f32_16x16x32bf16`: the backend expects bf16-typed operands for those intrinsics. This patch fixes that.
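For illustration, a minimal sketch of the behavior this patch enables (the function name is hypothetical; the op and the expected lowering mirror the updated `mfma-gfx950.mlir` test in the diff below):

```mlir
// Hypothetical standalone example; mirrors the updated mfma-gfx950.mlir test.
func.func @bf16_mfma(%a : vector<8xbf16>, %c : vector<4xf32>) -> vector<4xf32> {
  // With this patch, lowering for gfx950 keeps the bf16 element type:
  //   rocdl.mfma.f32.16x16x32.bf16 : (vector<8xbf16>, vector<8xbf16>,
  //                                   vector<4xf32>, i32, i32, i32) -> vector<4xf32>
  // Previously the operands were bitcast to vector<8xi16>, which the backend
  // rejects for this intrinsic.
  %0 = amdgpu.mfma %a * %a + %c { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
  return %0 : vector<4xf32>
}
```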
---
Full diff: https://github.com/llvm/llvm-project/pull/144925.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+21-7)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+2-2)
``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 074404add47f1..700563460f525 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16, unless the ROCDL
+/// intrinsic allows bf16 directly. Newer MFMAs support bf16 operands; see
+/// IntrinsicsAMDGPU.td for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
- if (vectorType.getElementType().isBF16())
+ if (vectorType.getElementType().isBF16() && !allowBf16)
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
StringRef intrinsicName =
isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+ // Determine whether the intrinsic accepts bf16 operands directly. The newer
+ // MFMAs on gfx950+ allow bf16 inputs; see IntrinsicsAMDGPU.td for reference.
+ bool allowBf16 = [&]() {
+ if (chipset < kGfx950)
+ return false;
+ if (isScaled)
+ return true;
+ return intrinsicName.contains("16x16x32.bf16") ||
+ intrinsicName.contains("32x32x16.bf16");
+ }();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands(
- {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
- convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
- adaptor.getDestC()});
+ loweredOp.addOperands({convertMFMAVectorOperand(
+ rewriter, loc, adaptor.getSourceA(), allowBf16),
+ convertMFMAVectorOperand(
+ rewriter, loc, adaptor.getSourceB(), allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 52a5d39f668c6..39c31d5bf2fa3 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -11,9 +11,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
// CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
``````````
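For contrast, a hedged sketch of the path that is unchanged by this patch (the function name is hypothetical, and the chipset and intrinsic shape are assumptions based on the existing pre-gfx950 tests): the `allowBf16` lambda returns false below gfx950, and on gfx950 it returns false for any intrinsic name other than the two new bf16 variants (scaled MFMAs aside), so those operands are still bitcast to i16.

```mlir
// Illustrative only: on an older target such as gfx942, a bf16 MFMA still
// lowers with i16 operands, because allowBf16 is false below gfx950.
// Expected lowering (assumed, following the existing mfma.mlir tests):
//   rocdl.mfma.f32.32x32x8bf16.1k : (vector<4xi16>, vector<4xi16>,
//                                    vector<16xf32>, i32, i32, i32) -> vector<16xf32>
func.func @bf16_mfma_pre_gfx950(%a : vector<4xbf16>, %c : vector<16xf32>) -> vector<16xf32> {
  %0 = amdgpu.mfma %a * %a + %c { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
  return %0 : vector<16xf32>
}
```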
https://github.com/llvm/llvm-project/pull/144925