[Mlir-commits] [mlir] 22c6e7b - [mlir][nvvm] Fix support for tf32 data type in mma.sync

Christopher Bate llvmlistbot at llvm.org
Thu May 5 10:03:45 PDT 2022


Author: Christopher Bate
Date: 2022-05-05T11:02:03-06:00
New Revision: 22c6e7b277fbe6c65216d8c7a690d53c8122cb42

URL: https://github.com/llvm/llvm-project/commit/22c6e7b277fbe6c65216d8c7a690d53c8122cb42
DIFF: https://github.com/llvm/llvm-project/commit/22c6e7b277fbe6c65216d8c7a690d53c8122cb42.diff

LOG: [mlir][nvvm] Fix support for tf32 data type in mma.sync

The NVVM dialect test coverage for the type/shape combinations accepted by
the `nvvm.mma.sync` op is mostly complete; however, tests were missing for
the TF32 data type. This change adds tests for the one relevant shape/type
combination, which uncovered a small bug in the op verifier. This change
fixes that bug as well.
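
For reference, the shape/type combination in question looks as follows (a
minimal sketch distilled from the new tests below; the SSA value names are
illustrative). The A and B fragments are passed as i32 registers holding
tf32 data, while the accumulator and result fragments are f32:

  // m16n8k4 with tf32 multiplicands: the operand segments are typed
  // (i32, i32, f32), and the result struct holds four f32 values.
  %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
         {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
          multiplicandAPtxType = #nvvm.mma_type<tf32>,
          multiplicandBPtxType = #nvvm.mma_type<tf32>,
          shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}}
       : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>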

Differential Revision: https://reviews.llvm.org/D124975

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 31683b47e59227..345d90044b4c22 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -81,8 +81,10 @@ Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
     return NVVM::MMATypes::f64;
   if (operandElType.isF16() || operandElType == half2Type)
     return NVVM::MMATypes::f16;
-  if (operandElType.isF32())
+  if (operandElType.isF32() && isAccumulator)
     return NVVM::MMATypes::f32;
+  if (operandElType.isF32() && !isAccumulator)
+    return NVVM::MMATypes::tf32;
   if (operandElType.isa<IntegerType>()) {
     if (isAccumulator)
       return NVVM::MMATypes::s32;
@@ -291,7 +293,7 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
           parser.getNameLoc(),
           "expected one type for each operand segment but got " +
               Twine(operandTypes.size()) + " types");
-  for (const auto& iter : llvm::enumerate(operandTypes)) {
+  for (const auto &iter : llvm::enumerate(operandTypes)) {
     auto &frag = frags[iter.index()];
     frag.regTypes.resize(frag.regs.size(), iter.value());
     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
@@ -376,8 +378,9 @@ LogicalResult MmaOp::verify() {
     switch (multiplicandAPtxType().getValue()) {
     case MMATypes::tf32:
       kFactor = 4;
+      multiplicandFragType = i32Ty;
       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
-          context, {i32Ty, i32Ty, i32Ty, i32Ty}));
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
       break;
     case MMATypes::f16:
     case MMATypes::bf16:

diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index cf3679904ecc4e..dfe0443f7c4a78 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -152,6 +152,17 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,                                
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, 
                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {  
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>

diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index ad73a295359df5..fddfdda764832b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -203,6 +203,17 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
 
+llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+                                     %b0 : i32,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
 // CHECK-LABEL: @gpu_wmma_load_op

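The verifier bug mentioned in the log is visible in the first hunk above:
the tf32 case expected a result struct of four i32 values and did not force
the multiplicand fragment type to i32, so the correct f32-result form shown
in the new tests was rejected. After the fix, the i32-result form is the one
the verifier rejects; a sketch for contrast (hypothetical IR, not part of
the tests):

  // Now a verifier error: with tf32 multiplicands, the D fragment must be
  // four f32 values, not four i32 values.
  %e = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
         {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
          multiplicandAPtxType = #nvvm.mma_type<tf32>,
          multiplicandBPtxType = #nvvm.mma_type<tf32>,
          shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}}
       : (i32, i32, f32) -> !llvm.struct<(i32, i32, i32, i32)>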