[Mlir-commits] [mlir] Allow bf16 operands on new MFMAs (PR #144925)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 09:36:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Umang Yadav (umangyadav)

<details>
<summary>Changes</summary>

The new gfx950 MFMA instructions allow bf16 operands.

https://github.com/llvm/llvm-project/blob/c0cc81cdc03c97473ba771bbc3a2330bd22396bc/llvm/include/llvm/IR/IntrinsicsAMDGPU.td#L3434 

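The current lowering always bitcasts bf16 operands to i16, which fails to compile for the newer bf16 MFMAs, e.g. `v_mfma_f32_16x16x32bf16`.
The backend expects the operands of those newer MFMAs to have bf16 type. This patch fixes the lowering so that bf16 operands are kept as bf16 when the selected intrinsic accepts them.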

---
Full diff: https://github.com/llvm/llvm-project/pull/144925.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+21-7) 
- (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+2-2) 


``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 074404add47f1..700563460f525 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
+/// allows bf16. Newer MFMAs support bf16 types on operand, check
+/// IntrinsicsAMDGPU.td file for reference.
 /// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
 /// instead, which is what the f8f6f4 intrinsics use.
 /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input) {
+                                      Location loc, Value input,
+                                      bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    if (vectorType.getElementType().isBF16())
+    if (vectorType.getElementType().isBF16() && !allowBf16)
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
     if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
 
     StringRef intrinsicName =
         isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
+    // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
+    bool allowBf16 = [&]() {
+      if (chipset < kGfx950)
+        return false;
+      if (isScaled)
+        return true;
+      return intrinsicName.contains("16x16x32.bf16") ||
+             intrinsicName.contains("32x32x16.bf16");
+    }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC()});
+    loweredOp.addOperands({convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 52a5d39f668c6..39c31d5bf2fa3 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -11,9 +11,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
   amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/144925


More information about the Mlir-commits mailing list