[Mlir-commits] [mlir] d03f35f - [MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 01:33:09 PST 2025


Author: xiaoleis-nv
Date: 2025-01-13T15:03:05+05:30
New Revision: d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79

URL: https://github.com/llvm/llvm-project/commit/d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79
DIFF: https://github.com/llvm/llvm-project/commit/d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79.diff

LOG: [MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)

The PR fixes the datatype error for `nvvm.mma.sync` when the operand is
`bf16`. This operation originally requires the A/B type to be `f16x2`
for the `bf16` MMA. However, it violates the NVVM intrinsic
[[here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)],
where the A/B operand type should be `i32`. This is a bug, and there are
no tests in MLIR that cover this datatype.

```
    // mma bf16 -> s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```

This PR addresses this bug and adds tests to guarantee correctness.

Co-authored-by: Xiaolei Shi <xiaoleis at nvidia.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     | f16      | .m8n8k4   | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
     |          | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
     |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
-    | bf16     | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
-    |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+    | bf16     | .m16n8k8  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
+    |          | .m16n8k16 | row     | col     | 4x i32   | 2x i32   | 4x f32            |
     | tf32     | .m16n8k4  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
     |          | .m16n8k8  | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4 f32 |
     | u8/s8    | .m8n8k16  | row     | col     | 1x i32   | 1x i32   | 2x i32            |

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
       break;
-    case MMATypes::f16:
     case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+      break;
+    case MMATypes::f16:
       kFactor = 8;
       multiplicandFragType = f16x2Ty;
       expectedResult.push_back(f16x2x2StructTy);

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // f32 return type, f16 accumulate type
 // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,


        


More information about the Mlir-commits mailing list