[Mlir-commits] [mlir] 5c3150e - [MLIR][NVVM] Introduce WGMMA Types

Guray Ozen llvmlistbot at llvm.org
Sat Aug 12 03:47:50 PDT 2023


Author: Guray Ozen
Date: 2023-08-12T12:47:45+02:00
New Revision: 5c3150e584b6449ac80434ed45a9bfb0188c444c

URL: https://github.com/llvm/llvm-project/commit/5c3150e584b6449ac80434ed45a9bfb0188c444c
DIFF: https://github.com/llvm/llvm-project/commit/5c3150e584b6449ac80434ed45a9bfb0188c444c.diff

LOG: [MLIR][NVVM] Introduce WGMMA Types

This work introduces `WGMMATypes` attributes for the `WgmmaMmaSyncOp`. This op, having been recently added to MLIR, previously used `MMATypes`. However, there arises a disparity in supported types between `MmaOp` and `WgmmaMmaSyncOp`. To address this discrepancy more effectively, a new set of attributes is introduced.

Furthermore, this patch refines and optimizes the verification mechanisms of the `WgmmaMmaSyncOp` Op.

It also adds support for f8 types, including `e4m3` and `e5m2`, within the `WgmmaMmaSyncOp`.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D157695

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/invalid.mlir
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7ddf68e0a5b1a4..cfee215c2e3517 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1497,6 +1497,28 @@ def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out">
   let assemblyFormat = "`<` $value `>`";
 }
 
+/// Enum attribute of the different PTX element types used for WGMMA operands.
+def WGMMATypeF16  : I32EnumAttrCase<"f16", 0>;
+def WGMMATypeTF32 : I32EnumAttrCase<"tf32", 1>;
+def WGMMATypeU8 : I32EnumAttrCase<"u8", 2>;
+def WGMMATypeS8 : I32EnumAttrCase<"s8", 3>;
+def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
+def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
+def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
+def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
+  [WGMMATypeF16, WGMMATypeTF32,
+    WGMMATypeU8, WGMMATypeS8,
+    WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, 
+    WGMMATypeF8E5M2]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMATypesAttr : EnumAttr<NVVM_Dialect, WGMMATypes, "wgmma_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+
 def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", 
               [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
                 PredOpTrait<"input struct and result struct must be the same type",
@@ -1508,15 +1530,14 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
     I64:$descriptorA, 
     I64:$descriptorB, 
     NVVM_MMAShapeAttr:$shape,
-    MMATypesAttr:$typeA,
-    MMATypesAttr:$typeB,
+    WGMMATypesAttr:$typeA,
+    WGMMATypesAttr:$typeB,
     WGMMAScaleOutAttr:$scaleD,
     WGMMAScaleInAttr:$scaleA,
     WGMMAScaleInAttr:$scaleB, 
     MMALayoutAttr:$layoutA,
     MMALayoutAttr:$layoutB,
     OptionalAttr<MMAIntOverflowAttr>:$satfinite
-    // OptionalAttr<UnitAttr>:$satfinite
   );  
   
    let assemblyFormat = [{ 
@@ -1536,44 +1557,50 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
 
     Supported shapes:  
     ```
-    |-------------------|--------------------|----------------|---------------|
-    |                   | f16 += f16 * f16   | s32 += s8 * s8 |               | 
-    |f32 += tf32 * tf32 | f32 += f16 * f16   | s32 += s8 * u8 |s32 += b1 * b1 |
-    |                   | f32 += bf16 * bf16 | s32 += u8 * u8 |               |
-    |-------------------|--------------------|----------------|---------------|
-    |   .m64n8k8        |  .m64n8k16         |  .m64n8k32     |  .m64n8k256   |
-    |   .m64n16k8       |  .m64n16k16        |  .m64n16k32    |  .m64n16k256  |
-    |   .m64n24k8       |  .m64n24k16        |  .m64n24k32    |  .m64n24k256  |
-    |   .m64n32k8       |  .m64n32k16        |  .m64n32k32    |  .m64n32k256  |
-    |   .m64n40k8       |  .m64n40k16        |  .m64n48k32    |  .m64n48k256  |
-    |   .m64n48k8       |  .m64n48k16        |  .m64n64k32    |  .m64n64k256  |
-    |   .m64n56k8       |  .m64n56k16        |  .m64n80k32    |  .m64n80k256  |
-    |   .m64n64k8       |  .m64n64k16        |  .m64n96k32    |  .m64n96k256  |
-    |   .m64n72k8       |  .m64n72k16        |  .m64n112k32   |  .m64n112k256 |
-    |   .m64n80k8       |  .m64n80k16        |  .m64n128k32   |  .m64n128k256 |
-    |   .m64n88k8       |  .m64n88k16        |  .m64n144k32   |  .m64n144k256 |
-    |   .m64n96k8       |  .m64n96k16        |  .m64n160k32   |  .m64n160k256 |
-    |   .m64n104k8      |  .m64n104k16       |  .m64n176k32   |  .m64n176k256 |
-    |   .m64n112k8      |  .m64n112k16       |  .m64n192k32   |  .m64n192k256 |
-    |   .m64n120k8      |  .m64n120k16       |  .m64n208k32   |  .m64n208k256 |
-    |   .m64n128k8      |  .m64n128k16       |  .m64n224k32   |  .m64n224k256 |
-    |   .m64n136k8      |  .m64n136k16       |  .m64n240k32   |  .m64n240k256 |
-    |   .m64n144k8      |  .m64n144k16       |  .m64n256k32   |  .m64n256k256 |
-    |   .m64n152k8      |  .m64n152k16       |                |               | 
-    |   .m64n160k8      |  .m64n160k16       |                |               | 
-    |   .m64n168k8      |  .m64n168k16       |                |               | 
-    |   .m64n176k8      |  .m64n176k16       |                |               | 
-    |   .m64n184k8      |  .m64n184k16       |                |               | 
-    |   .m64n192k8      |  .m64n192k16       |                |               |
-    |   .m64n200k8      |  .m64n200k16       |                |               |
-    |   .m64n208k8      |  .m64n208k16       |                |               |
-    |   .m64n216k8      |  .m64n216k16       |                |               |
-    |   .m64n224k8      |  .m64n224k16       |                |               |
-    |   .m64n232k8      |  .m64n232k16       |                |               |
-    |   .m64n240k8      |  .m64n240k16       |                |               |
-    |   .m64n248k8      |  .m64n248k16       |                |               |
-    |   .m64n256k8      |  .m64n256k16       |                |               |
-    |-------------------|--------------------|----------------|---------------|
+    |--------------|--------------|------------|--------------|---------------|
+    |              |              |            |              |f16+=e4m3*e4m3 |
+    |              |              |            |              |f16+=e5m2*e5m2 |
+    |f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 |
+    |              |f32+=f16 *f16 | s32+=u8*u8 |              |f16+=e4m3*e5m2 |
+    |              |f32+=bf16*bf16| s32+=u8*u8 |              |f16+=e4m3*e5m2 |
+    |              |f32+=bf16*bf16| s32+=s8*u8 |              |f32+=e4m3*e4m3 |
+    |              |              | s32+=u8*s8 |              |f32+=e5m2*e5m2 |
+    |              |              |            |              |f32+=e4m3*e5m2 |
+    |              |              |            |              |f32+=e4m3*e5m2 |
+    |--------------|--------------|------------|--------------|---------------|
+    |   .m64n8k8   |  .m64n8k16   | .m64n8k32  | .m64n8k256   | .m64n8k32     |
+    |   .m64n16k8  |  .m64n16k16  | .m64n16k32 | .m64n16k256  | .m64n16k32    |
+    |   .m64n24k8  |  .m64n24k16  | .m64n24k32 | .m64n24k256  | .m64n24k32    |
+    |   .m64n32k8  |  .m64n32k16  | .m64n32k32 | .m64n32k256  | .m64n32k32    |
+    |   .m64n40k8  |  .m64n40k16  | .m64n48k32 | .m64n48k256  | .m64n40k32    |
+    |   .m64n48k8  |  .m64n48k16  | .m64n64k32 | .m64n64k256  | .m64n48k32    |
+    |   .m64n56k8  |  .m64n56k16  | .m64n80k32 | .m64n80k256  | .m64n56k32    |
+    |   .m64n64k8  |  .m64n64k16  | .m64n96k32 | .m64n96k256  | .m64n64k32    |
+    |   .m64n72k8  |  .m64n72k16  | .m64n112k32| .m64n112k256 | .m64n72k32    |
+    |   .m64n80k8  |  .m64n80k16  | .m64n128k32| .m64n128k256 | .m64n80k32    |
+    |   .m64n88k8  |  .m64n88k16  | .m64n144k32| .m64n144k256 | .m64n88k32    |
+    |   .m64n96k8  |  .m64n96k16  | .m64n160k32| .m64n160k256 | .m64n96k32    |
+    |   .m64n104k8 |  .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32   |
+    |   .m64n112k8 |  .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32   |
+    |   .m64n120k8 |  .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32   |
+    |   .m64n128k8 |  .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32   |
+    |   .m64n136k8 |  .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32   |
+    |   .m64n144k8 |  .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32   |
+    |   .m64n152k8 |  .m64n152k16 |            |              | .m64n152k32   |
+    |   .m64n160k8 |  .m64n160k16 |            |              | .m64n160k32   |
+    |   .m64n168k8 |  .m64n168k16 |            |              | .m64n168k32   |
+    |   .m64n176k8 |  .m64n176k16 |            |              | .m64n176k32   |
+    |   .m64n184k8 |  .m64n184k16 |            |              | .m64n184k32   |
+    |   .m64n192k8 |  .m64n192k16 |            |              | .m64n192k32   |
+    |   .m64n200k8 |  .m64n200k16 |            |              | .m64n200k32   |
+    |   .m64n208k8 |  .m64n208k16 |            |              | .m64n208k32   |
+    |   .m64n216k8 |  .m64n216k16 |            |              | .m64n216k32   |
+    |   .m64n224k8 |  .m64n224k16 |            |              | .m64n224k32   |
+    |   .m64n232k8 |  .m64n232k16 |            |              | .m64n232k32   |
+    |   .m64n240k8 |  .m64n240k16 |            |              | .m64n240k32   |
+    |   .m64n248k8 |  .m64n248k16 |            |              | .m64n248k32   |
+    |   .m64n256k8 |  .m64n256k16 |            |              | .m64n256k32   |
+    |--------------|--------------|------------|--------------|---------------|
     ```
 
     See for more information:

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 70c774233797bb..d6794881b442ee 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -37,6 +37,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 #include <string>
 
@@ -708,6 +709,81 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
+  if (typeA == NVVM::WGMMATypes::tf32)
+    return 8;
+  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
+    return 16;
+  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
+    return 32;
+  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
+    return 32;
+  if (typeA == NVVM::WGMMATypes::b1)
+    return 256;
+  return failure();
+}
+
+LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+                                     NVVM::WGMMATypes typeB) {
+  switch (typeA) {
+  case NVVM::WGMMATypes::f16:
+    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+      return success();
+    break;
+  case NVVM::WGMMATypes::tf32:
+    if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+      return success();
+    break;
+  case NVVM::WGMMATypes::u8:
+  case NVVM::WGMMATypes::s8:
+    if (typeD.isInteger(32) &&
+        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
+      return success();
+    break;
+  case NVVM::WGMMATypes::b1:
+    if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+      return success();
+    break;
+  case NVVM::WGMMATypes::bf16:
+    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+      return success();
+    break;
+  case NVVM::WGMMATypes::e4m3:
+  case NVVM::WGMMATypes::e5m2:
+    if ((typeD.isF32() || typeD.isF16()) &&
+        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
+      return success();
+    break;
+  }
+  return failure();
+}
+
+LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
+  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
+                               72,  80,  88,  96,  104, 112, 120, 128,
+                               136, 144, 152, 160, 168, 176, 184, 192,
+                               200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
+                                    80,  96,  112, 128, 144, 160,
+                                    176, 192, 208, 224, 240, 256};
+  switch (typeA) {
+  case mlir::NVVM::WGMMATypes::f16:
+  case mlir::NVVM::WGMMATypes::tf32:
+  case mlir::NVVM::WGMMATypes::bf16:
+  case mlir::NVVM::WGMMATypes::e4m3:
+  case mlir::NVVM::WGMMATypes::e5m2:
+    if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; }))
+      return success();
+    break;
+  case mlir::NVVM::WGMMATypes::u8:
+  case mlir::NVVM::WGMMATypes::s8:
+  case mlir::NVVM::WGMMATypes::b1:
+    if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; }))
+      return success();
+  }
+  return failure();
+}
+
 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   Value outValue = getResults();
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
@@ -730,142 +806,49 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
     return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
   }
 
+  mlir::NVVM::WGMMATypes typeA = getTypeA();
+  mlir::NVVM::WGMMATypes typeB = getTypeB();
+  if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
+    return emitOpError() << outputType
+                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
+                         << NVVM::stringifyWGMMATypes(typeB)
+                         << ", it is not supported.";
+  }
+
   // Check M
   if (getShape().getM() != 64)
     return emitOpError() << "shape 'm' must be 64";
 
   // Check K
-  mlir::NVVM::MMATypes typeA = getTypeA();
-  mlir::NVVM::MMATypes typeB = getTypeB();
-  switch (typeA) {
-  case mlir::NVVM::MMATypes::bf16:
-  case mlir::NVVM::MMATypes::f16:
-    if (typeA != typeB) {
-      return emitOpError() << "input types must be same but got "
-                           << NVVM::stringifyMMATypes(typeA) << " and "
-                           << NVVM::stringifyMMATypes(typeB);
-    }
-    if (getShape().getK() != 16) {
-      return emitOpError() << "shape 'k'  must be 16 "
-                           << "for input type "
-                           << NVVM::stringifyMMATypes(typeA);
-    }
-    break;
-  case mlir::NVVM::MMATypes::tf32:
-    if (typeA != typeB) {
-      return emitOpError() << "input types must be same but got "
-                           << NVVM::stringifyMMATypes(typeA) << " and "
-                           << NVVM::stringifyMMATypes(typeB);
-    }
-    if (getShape().getK() != 8) {
-      return emitOpError() << "shape 'k'  must be 8 "
-                           << "for input type "
-                           << NVVM::stringifyMMATypes(typeA);
-    }
-    break;
-  case mlir::NVVM::MMATypes::s8:
-  case mlir::NVVM::MMATypes::u8:
-    if (typeB != mlir::NVVM::MMATypes::s8 &&
-        typeB != mlir::NVVM::MMATypes::u8) {
-      return emitOpError() << "input type of rhs could be "
-                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8)
-                           << " or "
-                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8)
-                           << " same but got and "
-                           << NVVM::stringifyMMATypes(typeB);
-    }
-    if (getShape().getK() != 32) {
-      emitOpError() << "shape 'k'  must be 32 "
-                    << "for input type " << NVVM::stringifyMMATypes(typeA);
-    }
-    break;
-  case mlir::NVVM::MMATypes::b1:
-    if (typeA != typeB) {
-      return emitOpError() << "input types must be same but got "
-                           << NVVM::stringifyMMATypes(typeA) << " and "
-                           << NVVM::stringifyMMATypes(typeB);
-    }
-    if (getShape().getK() != 256) {
-      return emitOpError() << "shape 'k'  must be 256 "
-                           << "for input type "
-                           << NVVM::stringifyMMATypes(typeA);
-    }
-    break;
-  default:
-    return emitOpError() << "Unsupported input type "
-                         << NVVM::stringifyMMATypes(typeA) << " and "
-                         << NVVM::stringifyMMATypes(typeB);
-  }
+  FailureOr<int> allowedK = getAllowedSizeK(typeA);
+  if (failed(allowedK) || allowedK.value() != getShape().getK())
+    return emitOpError() << "shape 'k' must be " << allowedK.value()
+                         << " for input type "
+                         << NVVM::stringifyWGMMATypes(typeA);
 
   // Check N
-  SmallVector<int> allowedNShapesF16 = {8,   16,  24,  32,  40,  48,  56,  64,
-                                        72,  80,  88,  96,  104, 112, 120, 128,
-                                        136, 144, 152, 160, 168, 176, 184, 192,
-                                        200, 208, 216, 224, 232, 240, 248, 256};
-  SmallVector<int> allowedNShapesU8S8B1 = {8,   16,  24,  32,  48,  64,
-                                           80,  96,  112, 128, 144, 160,
-                                           176, 192, 208, 224, 240, 256};
-
-  bool validGEMMType = false;
-  // f16 += f16 * f16
-  if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) {
-    if (!llvm::any_of(allowedNShapesF16,
-                      [&](int n) { return getShape().getN() == n; })) {
-      return emitOpError() << "has input type "
-                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
-                           << getShape().getN() << ", it is not supported.";
-    }
-    validGEMMType = true;
-  }
-  // f32 += tf32 * tf32| f32 += f16 * f16| f16 += bf16 * bf16
-  if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 ||
-                             typeA == mlir::NVVM::MMATypes::tf32 ||
-                             typeA == mlir::NVVM::MMATypes::f16)) {
-    if (!llvm::any_of(allowedNShapesF16,
-                      [&](int n) { return getShape().getN() == n; })) {
-      return emitOpError() << "has input type "
-                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
-                           << getShape().getN() << ", it is not supported.";
-    }
-    validGEMMType = true;
-  }
-  // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1
-  if (outputType.isInteger(32) &&
-      (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 ||
-       typeA == mlir::NVVM::MMATypes::b1)) {
-    if (!llvm::any_of(allowedNShapesU8S8B1,
-                      [&](int n) { return getShape().getN() == n; })) {
-      return emitOpError() << "has input type "
-                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
-                           << getShape().getN() << ", it is not supported.";
-    }
-    validGEMMType = true;
+  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
+    return emitOpError() << "has input type "
+                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+                         << getShape().getN() << ", it is not supported.";
   }
 
-  if (!validGEMMType) {
-    return emitOpError() << outputType
-                         << " += " << NVVM::stringifyMMATypes(typeA) << " * "
-                         << NVVM::stringifyMMATypes(typeB)
-                         << ", it is not supported.";
-  }
-
-  // Check transpose is needed from the given layouts. It is only
-  // supported for bf16 or f16.
-  if ((typeA != mlir::NVVM::MMATypes::f16 &&
-       typeA != mlir::NVVM::MMATypes::bf16) &&
+  // Check transpose (only available for f16/bf16)
+  if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
+       typeA != mlir::NVVM::WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
        getLayoutB() == mlir::NVVM::MMALayout::col)) {
     return emitOpError()
            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
            << " and layout_b = " << stringifyMMALayout(getLayoutB())
-           << " for input types " << stringifyMMATypes(typeA) << " and "
-           << stringifyMMATypes(typeB)
+           << " for input types " << stringifyWGMMATypes(typeA) << " and "
+           << stringifyWGMMATypes(typeB)
            << " requires transpose. However, this is only supported for: "
            << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
            << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
   }
 
-  // Check number of result registers
+  // Check result registers
   int expectedOutput;
   if (outputType.isF32() || outputType.isInteger(32))
     expectedOutput = getShape().getN() / 2;
@@ -876,7 +859,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
                          << ", however output struct has " << outputSize
                          << " elements";
   }
-  // Check satfinite is set. It is only for s32 accumulator
+  // Check satfinite (only availalbe for s32 accumulator)
   if (!outputType.isInteger(32) &&
       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
           NVVM::MMAIntOverflow::satfinite) {
@@ -892,8 +875,8 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
 
   int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
-  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
-               getTypeA() == mlir::NVVM::MMATypes::bf16;
+  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
+               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
 
   Value outValue = getResults() ? getResults() : getInouts();
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
@@ -901,10 +884,13 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   std::string outputTypeName;
   if (outputType.isF16())
     outputTypeName = "f16";
-  if (outputType.isF32())
+  else if (outputType.isF32())
     outputTypeName = "f32";
   else if (outputType.isInteger(32))
     outputTypeName = "s32";
+  else
+    assert(false && "unsupported output type");
+
   int expectedOutputRegisters;
   if (outputType.isF32() || outputType.isInteger(32))
     expectedOutputRegisters = getShape().getN() / 2;
@@ -921,7 +907,8 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
      << ", 0;\n"
         "wgmma.mma_async.sync.aligned.m"
      << m << "n" << n << "k" << k << "." << outputTypeName << "."
-     << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB());
+     << stringifyWGMMATypes(getTypeA()) << "."
+     << stringifyWGMMATypes(getTypeB());
   if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
       NVVM::MMAIntOverflow::satfinite)
     ss << ".satfinite";
@@ -959,8 +946,8 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   Value outValue = getResults() ? getResults() : getInouts();
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
   Type outputType = stype.getBody().front();
-  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
-               getTypeA() == mlir::NVVM::MMATypes::bf16;
+  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
+               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
   if (getResults())
     asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
   if (getInouts())

diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 7f8fa5267c343f..8b9df5fa598014 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -62,7 +62,7 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
 
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // expected-error @+1 {{op shape 'k'  must be 16 for input type f16}}
+  // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
   %res = nvvm.wgmma.mma_async %descA, %descB, 
       #nvvm.shape<m = 64, n = 16, k = 3>, 
       D [%result, <zero>], 
@@ -116,4 +116,19 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
       : !llvm.struct<(i32, i32, i32, i32)> 
       -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
   return 
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 16>, 
+      D [%result, <zero>], 
+      A [<bf16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  return 
+}

diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index e1e2bcba40a256..3bcb41c18d4fb4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -258,14 +258,71 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
   %result1 = nvvm.wgmma.mma_async %descA, %descB, 
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [%result, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, 
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [%result1, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
+
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
+func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [%result, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [%result1, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
+func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [%result, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [%result1, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }


        


More information about the Mlir-commits mailing list