[Mlir-commits] [mlir] d03f35f - [MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 13 01:33:09 PST 2025
Author: xiaoleis-nv
Date: 2025-01-13T15:03:05+05:30
New Revision: d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79
URL: https://github.com/llvm/llvm-project/commit/d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79
DIFF: https://github.com/llvm/llvm-project/commit/d03f35f9b6d031d6a9375d90ccf7cc285f8e4b79.diff
LOG: [MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)
This PR fixes a datatype error in `nvvm.mma.sync` when the operands are
`bf16`. The op verifier previously required the A/B operand type to be
`f16x2` for `bf16` MMA. However, this contradicts the NVVM intrinsic definition
[[here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)],
which specifies `i32` as the A/B operand type. This is a bug, and no
existing MLIR tests cover this datatype.
```
// mma bf16 -> s32 @ m16n8k16/m16n8k8
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```
This PR fixes the bug and adds tests that cover `nvvm.mma.sync` with `bf16` operands.
Co-authored-by: Xiaolei Shi <xiaoleis at nvidia.com>
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
- | bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
- | | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+ | bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 |
+ | | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 |
| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 |
| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 |
| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 |
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
- case MMATypes::f16:
case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ break;
+ case MMATypes::f16:
kFactor = 8;
multiplicandFragType = f16x2Ty;
expectedResult.push_back(f16x2x2StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
}
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
%c0 : i32, %c1 : i32) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
// f32 return type, f16 accumulate type
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
More information about the Mlir-commits
mailing list