[Mlir-commits] [mlir] 22c6e7b - [mlir][nvvm] Fix support for tf32 data type in mma.sync
Christopher Bate
llvmlistbot at llvm.org
Thu May 5 10:03:45 PDT 2022
Author: Christopher Bate
Date: 2022-05-05T11:02:03-06:00
New Revision: 22c6e7b277fbe6c65216d8c7a690d53c8122cb42
URL: https://github.com/llvm/llvm-project/commit/22c6e7b277fbe6c65216d8c7a690d53c8122cb42
DIFF: https://github.com/llvm/llvm-project/commit/22c6e7b277fbe6c65216d8c7a690d53c8122cb42.diff
LOG: [mlir][nvvm] Fix support for tf32 data type in mma.sync
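The NVVM dialect test coverage for all possible type/shape combinations
in the `nvvm.mma.sync` op is mostly complete. However, tests were
missing for TF32 data type support. This change adds tests for the one
relevant shape/type combination (m16n8k4). Writing them uncovered a small
bug in the op verifier, which this change also fixes: for TF32
multiplicands the verifier expected a result struct of i32 instead of f32
and did not set the multiplicand fragment type to i32. Operand type
inference now also maps non-accumulator f32 operands to tf32.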
Differential Revision: https://reviews.llvm.org/D124975
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 31683b47e59227..345d90044b4c22 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -81,8 +81,10 @@ Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
return NVVM::MMATypes::f16;
- if (operandElType.isF32())
+ if (operandElType.isF32() && isAccumulator)
return NVVM::MMATypes::f32;
+ if (operandElType.isF32() && !isAccumulator)
+ return NVVM::MMATypes::tf32;
if (operandElType.isa<IntegerType>()) {
if (isAccumulator)
return NVVM::MMATypes::s32;
@@ -291,7 +293,7 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(),
"expected one type for each operand segment but got " +
Twine(operandTypes.size()) + " types");
- for (const auto& iter : llvm::enumerate(operandTypes)) {
+ for (const auto &iter : llvm::enumerate(operandTypes)) {
auto &frag = frags[iter.index()];
frag.regTypes.resize(frag.regs.size(), iter.value());
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
@@ -376,8 +378,9 @@ LogicalResult MmaOp::verify() {
switch (multiplicandAPtxType().getValue()) {
case MMATypes::tf32:
kFactor = 4;
+ multiplicandFragType = i32Ty;
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
- context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
case MMATypes::f16:
case MMATypes::bf16:
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index cf3679904ecc4e..dfe0443f7c4a78 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -152,6 +152,17 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index ad73a295359df5..fddfdda764832b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -203,6 +203,17 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
llvm.return %0 : !llvm.struct<(f64, f64)>
}
+llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+ %b0 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
// CHECK-LABEL: @gpu_wmma_load_op
More information about the Mlir-commits mailing list