[Mlir-commits] [mlir] [MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (PR #122664)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 20:13:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: None (xiaoleis-nv)

<details>
<summary>Changes</summary>

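This PR fixes the operand type check of `nvvm.mma.sync` for `bf16` multiplicands. The op's verifier previously required the A/B operands of a `bf16` MMA to be `f16x2`, reusing the `f16` rule. This contradicts the NVVM intrinsic definition [[here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)], which declares the A/B operands of a `bf16` MMA as `i32` (each `i32` holds two packed `bf16` values). With this change, the verifier also requires the C operands and the result of a `bf16` MMA to be four `f32` values, since `bf16` MMA only accumulates in `f32`. The bug went unnoticed because no MLIR test covered this datatype.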

```
    // mma bf16 -> s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```

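This PR fixes the verifier and the `bf16` rows of the op's documentation table. It also adds `bf16` tests for the `m16n8k8` and `m16n8k16` shapes to `mlir/test/Dialect/LLVMIR/nvvm.mlir`, and an LLVM IR translation test for `m16n8k16` to `mlir/test/Target/LLVMIR/nvvmir.mlir`.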

---
Full diff: https://github.com/llvm/llvm-project/pull/122664.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+2-2) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+6-1) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+23) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     | f16      | .m8n8k4   | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
     |          | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
     |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
-    | bf16     | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
-    |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+    | bf16     | .m16n8k8  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
+    |          | .m16n8k16 | row     | col     | 4x i32   | 2x i32   | 4x f32            |
     | tf32     | .m16n8k4  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
     |          | .m16n8k8  | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4 f32 |
     | u8/s8    | .m8n8k16  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
       break;
-    case MMATypes::f16:
     case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+      break;
+    case MMATypes::f16:
       kFactor = 8;
       multiplicandFragType = f16x2Ty;
       expectedResult.push_back(f16x2x2StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // f32 return type, f16 accumulate type
 // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/122664


More information about the Mlir-commits mailing list