[Mlir-commits] [mlir] 18e161f - [MLIR][NVVM] Introduction of the `wgmma.mma_async` Op

Guray Ozen llvmlistbot at llvm.org
Wed Aug 9 14:08:06 PDT 2023


Author: Guray Ozen
Date: 2023-08-09T23:08:00+02:00
New Revision: 18e161f9e15b036faf48bfd8813d9330e06e2ee3

URL: https://github.com/llvm/llvm-project/commit/18e161f9e15b036faf48bfd8813d9330e06e2ee3
DIFF: https://github.com/llvm/llvm-project/commit/18e161f9e15b036faf48bfd8813d9330e06e2ee3.diff

LOG: [MLIR][NVVM] Introduction of the `wgmma.mma_async` Op

This work introduces the `wgmma.mma_async` Op along with its PTX generation, using `BasicPtxBuilderOpInterface`. The Op executes the matrix multiply-and-accumulate operation across a warpgroup (128 threads). Note that this operation requires a device with the sm_90a capability.

The matrix multiply-and-accumulate operation can take one of the following forms. In both cases, matrix D is referred to as the accumulator:
	D = A * B + D 	: The result is added to the accumulator matrix D.
	D = A * B 		: The input from the accumulator matrix D is not used.
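
For example, an f32 += f16 * f16 GEMM with shape m64n32k16, as exercised in
the accompanying tests (here `!mat64f32` stands for the `!llvm.struct` that
holds the 16 f32 accumulator registers):

    %r = nvvm.wgmma.mma_async %descA, %descB,
        #nvvm.shape<m = 64, n = 32, k = 16>,
        D [%acc, #nvvm.wgmma_scale_out<zero>],
        A [#nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>],
        B [#nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>]
        : !mat64f32 -> !mat64f32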

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D157370

Added: 
    mlir/test/Conversion/NVVMToLLVM/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 490a0db9baa028..4845a226ec9c17 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -168,7 +168,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
         /*desc=*/[{Generate constant value.}],
         /*retType=*/"::mlir::Value",
         /*methodName=*/"makeConstantI32",
-        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "unsigned" : $val),
+        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
         /*methodBody=*/"",
         /*defaultImpl=*/ [{
             mlir::Operation* op = $_op;
@@ -1473,6 +1473,121 @@ def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
   }];
 }
 
+/// Enum attribute type for negating the input operands
+def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>;
+def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA input scale options",
+  [WGMMAScaleInOne, WGMMAScaleInNeg]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Enum attribute type for scaling the output operand
+def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>;
+def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>;
+def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA output scale options",
+  [WGMMAScaleOutZero, WGMMAScaleOutOne]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> {
+  let assemblyFormat = "`<` $value `>`";
+}
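+/// In the IR these attributes render as, e.g., #nvvm.wgmma_scale_in<neg>
+/// and #nvvm.wgmma_scale_out<zero>.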
+
+def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async",
+              [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+                PredOpTrait<"input struct and result struct must be the same type",
+                  TCresIsSameAsOpBase<0, 0>>]>
+{
+  let results = (outs LLVM_AnyAggregate:$results);
+  let arguments = (ins 
+    LLVM_AnyAggregate:$inouts,
+    I64:$descriptorA, 
+    I64:$descriptorB, 
+    NVVM_MMAShapeAttr:$shape,
+    MMATypesAttr:$typeA,
+    MMATypesAttr:$typeB,
+    WGMMAScaleOutAttr:$scaleD,
+    WGMMAScaleInAttr:$scaleA,
+    WGMMAScaleInAttr:$scaleB, 
+    MMALayoutAttr:$layoutA,
+    MMALayoutAttr:$layoutB,
+    OptionalAttr<MMAIntOverflowAttr>:$satfinite
+  );
+  
+   let assemblyFormat = [{ 
+      $descriptorA `,` $descriptorB `,` $shape `,` 
+      `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+      `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,` 
+      `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
+      attr-dict `:` 
+      type($inouts) `->` type($results)
+    }];
+  
+  let description = [{
+    The warpgroup (128 threads) level matrix multiply-and-accumulate operation
+    takes one of the following forms, where matrix D is the accumulator:
+      D = A * B + D
+      D = A * B, where the input from the accumulator D is disabled.
+
+    Supported shapes:  
+    ```
+    |-------------------|--------------------|----------------|---------------|
+    |                   | f16 += f16 * f16   | s32 += s8 * s8 |               | 
+    |f32 += tf32 * tf32 | f32 += f16 * f16   | s32 += s8 * u8 |s32 += b1 * b1 |
+    |                   | f32 += bf16 * bf16 | s32 += u8 * u8 |               |
+    |-------------------|--------------------|----------------|---------------|
+    |   .m64n8k8        |  .m64n8k16         |  .m64n8k32     |  .m64n8k256   |
+    |   .m64n16k8       |  .m64n16k16        |  .m64n16k32    |  .m64n16k256  |
+    |   .m64n24k8       |  .m64n24k16        |  .m64n24k32    |  .m64n24k256  |
+    |   .m64n32k8       |  .m64n32k16        |  .m64n32k32    |  .m64n32k256  |
+    |   .m64n40k8       |  .m64n40k16        |  .m64n48k32    |  .m64n48k256  |
+    |   .m64n48k8       |  .m64n48k16        |  .m64n64k32    |  .m64n64k256  |
+    |   .m64n56k8       |  .m64n56k16        |  .m64n80k32    |  .m64n80k256  |
+    |   .m64n64k8       |  .m64n64k16        |  .m64n96k32    |  .m64n96k256  |
+    |   .m64n72k8       |  .m64n72k16        |  .m64n112k32   |  .m64n112k256 |
+    |   .m64n80k8       |  .m64n80k16        |  .m64n128k32   |  .m64n128k256 |
+    |   .m64n88k8       |  .m64n88k16        |  .m64n144k32   |  .m64n144k256 |
+    |   .m64n96k8       |  .m64n96k16        |  .m64n160k32   |  .m64n160k256 |
+    |   .m64n104k8      |  .m64n104k16       |  .m64n176k32   |  .m64n176k256 |
+    |   .m64n112k8      |  .m64n112k16       |  .m64n192k32   |  .m64n192k256 |
+    |   .m64n120k8      |  .m64n120k16       |  .m64n208k32   |  .m64n208k256 |
+    |   .m64n128k8      |  .m64n128k16       |  .m64n224k32   |  .m64n224k256 |
+    |   .m64n136k8      |  .m64n136k16       |  .m64n240k32   |  .m64n240k256 |
+    |   .m64n144k8      |  .m64n144k16       |  .m64n256k32   |  .m64n256k256 |
+    |   .m64n152k8      |  .m64n152k16       |                |               | 
+    |   .m64n160k8      |  .m64n160k16       |                |               | 
+    |   .m64n168k8      |  .m64n168k16       |                |               | 
+    |   .m64n176k8      |  .m64n176k16       |                |               | 
+    |   .m64n184k8      |  .m64n184k16       |                |               | 
+    |   .m64n192k8      |  .m64n192k16       |                |               |
+    |   .m64n200k8      |  .m64n200k16       |                |               |
+    |   .m64n208k8      |  .m64n208k16       |                |               |
+    |   .m64n216k8      |  .m64n216k16       |                |               |
+    |   .m64n224k8      |  .m64n224k16       |                |               |
+    |   .m64n232k8      |  .m64n232k16       |                |               |
+    |   .m64n240k8      |  .m64n240k16       |                |               |
+    |   .m64n248k8      |  .m64n248k16       |                |               |
+    |   .m64n256k8      |  .m64n256k16       |                |               |
+    |-------------------|--------------------|----------------|---------------|
+    ```
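+
+    Example (mirrors the `wgmma_f32_f16_f16` conversion test added in this
+    commit; for an f32 accumulator the struct carries n/2 `f32` registers,
+    16 for n = 32):
+
+    ```mlir
+    %r = nvvm.wgmma.mma_async %descA, %descB,
+        #nvvm.shape<m = 64, n = 32, k = 16>,
+        D [%acc, #nvvm.wgmma_scale_out<zero>],
+        A [#nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>],
+        B [#nvvm.mma_type<f16>, #nvvm.wgmma_scale_in<neg>, #nvvm.mma_layout<col>]
+        : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32,
+                        f32, f32, f32, f32, f32, f32, f32, f32)>
+        -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32,
+                         f32, f32, f32, f32, f32, f32, f32, f32)>
+    ```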
+
+    For more information, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions
+  }];
+  
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    void getAsmValues(RewriterBase &rewriter, 
+        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 52abbe998872ab..2d7a441e950045 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -53,7 +53,7 @@ using namespace NVVM;
 namespace {
 
 class PtxBuilder {
-  Operation *op;
+  NVVM::BasicPtxBuilderInterface op;
   PatternRewriter &rewriter;
   std::string asmStr;
   SmallVector<Value> asmVals;
@@ -62,30 +62,35 @@ class PtxBuilder {
   bool hasResult = false;
 
   // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
-  char getRegisterType(Value v) {
-    if (v.getDefiningOp<LLVM::ConstantOp>())
-      return 'n';
-    if (v.getType().isInteger(16))
+  char getRegisterType(Type type) {
+    if (type.isInteger(16))
       return 'h';
-    if (v.getType().isInteger(32))
+    if (type.isInteger(32))
       return 'r';
-    if (v.getType().isInteger(64))
+    if (type.isInteger(64))
       return 'l';
-    if (v.getType().isF32())
+    if (type.isF32())
       return 'f';
-    if (v.getType().isF64())
+    if (type.isF64())
       return 'd';
-    if (auto ptr = v.getType().dyn_cast<LLVM::LLVMPointerType>()) {
+    if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
       // Shared address spaces is addressed with 32-bit pointers.
       if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
         return 'r';
       }
       return 'l';
     }
-    assert(false && "Register type is not handled yet");
+    op->emitError() << "Register type could not be deduced from MLIR type: "
+                    << type;
     return ' ';
   }
 
+  char getRegisterType(Value v) {
+    if (v.getDefiningOp<LLVM::ConstantOp>())
+      return 'n';
+    return getRegisterType(v.getType());
+  }
+
 public:
   PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
              bool sideEffects = false)
@@ -93,26 +98,60 @@ class PtxBuilder {
         sideEffects(sideEffects) {}
 
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
-    llvm::raw_string_ostream ss(asmConstraints);
-    if (itype == PTXRegisterMod::Read) {
-      asmVals.push_back(v);
-    } else if (itype == PTXRegisterMod::ReadWrite) {
-      asmVals.push_back(v);
-      ss << "+";
-      hasResult = true;
-    } else if (itype == PTXRegisterMod::Write) {
-      ss << "=";
+    LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+    auto getModifier = [&]() -> const char * {
+      if (itype == PTXRegisterMod::ReadWrite) {
+        assert(false && "Read-Write modifier is not supported. Try setting the "
+                        "same value separately as Write and Read.");
+        return "+";
+      }
+      if (itype == PTXRegisterMod::Write) {
+        return "=";
+      }
+      return "";
+    };
+    auto addValue = [&](Value v) {
+      if (itype == PTXRegisterMod::Read) {
+        asmVals.push_back(v);
+        return;
+      }
+      if (itype == PTXRegisterMod::ReadWrite)
+        asmVals.push_back(v);
       hasResult = true;
+    };
+
+    llvm::raw_string_ostream ss(asmConstraints);
+    // Handle Structs
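+    // For example, a write-only struct<(i32, i32, i32, i32)> result
+    // contributes "=r,=r,=r,=r" to the constraints, while a read-write
+    // struct is decomposed into one extractvalue per element with tied
+    // constraints "0,1,2,3" (cf. the conversion tests).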
+    if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+      if (itype == PTXRegisterMod::Write) {
+        addValue(v);
+      }
+      for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+        if (itype != PTXRegisterMod::Write) {
+          Value extractValue =
+              rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
+          addValue(extractValue);
+        }
+        if (itype == PTXRegisterMod::ReadWrite) {
+          ss << idx << ",";
+        } else {
+          ss << getModifier() << getRegisterType(t) << ",";
+        }
+        ss.flush();
+      }
+      return;
     }
-    ss << getRegisterType(v) << ",";
+    // Handle Scalars
+    addValue(v);
+    ss << getModifier() << getRegisterType(v) << ",";
     ss.flush();
   }
 
   LLVM::InlineAsmOp build() {
     auto asmDialectAttr =
         LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
-    Type resultType = hasResult ? op->getResult(0).getType()
-                                : LLVM::LLVMVoidType::get(op->getContext());
+
+    auto resultTypes = op->getResultTypes();
 
     // Remove the last comma from the constraints string.
     if (!asmConstraints.empty() &&
@@ -123,7 +162,8 @@ class PtxBuilder {
     std::replace(asmStr.begin(), asmStr.end(), '%', '$');
 
     return rewriter.create<LLVM::InlineAsmOp>(
-        op->getLoc(), resultType,
+        op->getLoc(),
+        /*result types=*/resultTypes,
         /*operands=*/asmVals,
         /*asm_string=*/llvm::StringRef(asmStr),
         /*constraints=*/asmConstraints.data(),
@@ -159,6 +199,7 @@ struct PtxLowering
     }
 
     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
+    LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
     PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
 
     op.getAsmValues(rewriter, asmValues);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f086af6b74b1b0..25f837547b46f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -26,7 +26,9 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
@@ -34,6 +36,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <string>
 
@@ -705,6 +708,287 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::WgmmaMmaSyncOp::verify() {
+  Value outValue = getResults();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  if (!stype)
+    return emitOpError() << "expected results to be struct";
+  Type outputType = stype.getBody().front();
+  int outputSize = stype.getBody().size();
+  for (Type t : stype.getBody()) {
+    if (t != outputType)
+      return emitOpError()
+             << "all elements in struct must be same type but there is " << t;
+  }
+
+  if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+    return emitOpError() << "does not support the given output type "
+                         << outputType;
+  }
+  if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
+                                   getScaleB() == NVVM::WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  }
+
+  // Check M
+  if (getShape().getM() != 64)
+    return emitOpError() << "shape 'm' must be 64";
+
+  // Check K
+  mlir::NVVM::MMATypes typeA = getTypeA();
+  mlir::NVVM::MMATypes typeB = getTypeB();
+  switch (typeA) {
+  case mlir::NVVM::MMATypes::bf16:
+  case mlir::NVVM::MMATypes::f16:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 16) {
+      return emitOpError() << "shape 'k'  must be 16 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::tf32:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 8) {
+      return emitOpError() << "shape 'k'  must be 8 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::s8:
+  case mlir::NVVM::MMATypes::u8:
+    if (typeB != mlir::NVVM::MMATypes::s8 &&
+        typeB != mlir::NVVM::MMATypes::u8) {
+      return emitOpError() << "input type of rhs could be "
+                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::s8)
+                           << " or "
+                           << NVVM::stringifyMMATypes(mlir::NVVM::MMATypes::u8)
+                           << " same but got and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 32) {
+      return emitOpError() << "shape 'k' must be 32 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  case mlir::NVVM::MMATypes::b1:
+    if (typeA != typeB) {
+      return emitOpError() << "input types must be same but got "
+                           << NVVM::stringifyMMATypes(typeA) << " and "
+                           << NVVM::stringifyMMATypes(typeB);
+    }
+    if (getShape().getK() != 256) {
+      return emitOpError() << "shape 'k'  must be 256 "
+                           << "for input type "
+                           << NVVM::stringifyMMATypes(typeA);
+    }
+    break;
+  default:
+    return emitOpError() << "Unsupported input type "
+                         << NVVM::stringifyMMATypes(typeA) << " and "
+                         << NVVM::stringifyMMATypes(typeB);
+  }
+
+  // Check N
+  SmallVector<int> allowedNShapesF16 = {8,   16,  24,  32,  40,  48,  56,  64,
+                                        72,  80,  88,  96,  104, 112, 120, 128,
+                                        136, 144, 152, 160, 168, 176, 184, 192,
+                                        200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNShapesU8S8B1 = {8,   16,  24,  32,  48,  64,
+                                           80,  96,  112, 128, 144, 160,
+                                           176, 192, 208, 224, 240, 256};
+
+  bool validGEMMType = false;
+  // f16 += f16 * f16
+  if (outputType.isF16() && typeA == mlir::NVVM::MMATypes::f16) {
+    if (!llvm::any_of(allowedNShapesF16,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
+                           << getShape().getN() << ", it is not supported.";
+    }
+    validGEMMType = true;
+  }
+  // f32 += tf32 * tf32 | f32 += f16 * f16 | f32 += bf16 * bf16
+  if (outputType.isF32() && (typeA == mlir::NVVM::MMATypes::bf16 ||
+                             typeA == mlir::NVVM::MMATypes::tf32 ||
+                             typeA == mlir::NVVM::MMATypes::f16)) {
+    if (!llvm::any_of(allowedNShapesF16,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
+                           << getShape().getN() << ", it is not supported.";
+    }
+    validGEMMType = true;
+  }
+  // s32 += s8 * s8 | s32 += s8 * u8 | s32 += u8 * u8 | s32 += b1 * b1
+  if (outputType.isInteger(32) &&
+      (typeA == mlir::NVVM::MMATypes::s8 || typeA == mlir::NVVM::MMATypes::u8 ||
+       typeA == mlir::NVVM::MMATypes::b1)) {
+    if (!llvm::any_of(allowedNShapesU8S8B1,
+                      [&](int n) { return getShape().getN() == n; })) {
+      return emitOpError() << "has input type "
+                           << NVVM::stringifyMMATypes(typeA) << " n is set to "
+                           << getShape().getN() << ", it is not supported.";
+    }
+    validGEMMType = true;
+  }
+
+  if (!validGEMMType) {
+    return emitOpError() << outputType
+                         << " += " << NVVM::stringifyMMATypes(typeA) << " * "
+                         << NVVM::stringifyMMATypes(typeB)
+                         << ", it is not supported.";
+  }
+
+  // Check transpose is needed from the given layouts. It is only
+  // supported for bf16 or f16.
+  if ((typeA != mlir::NVVM::MMATypes::f16 &&
+       typeA != mlir::NVVM::MMATypes::bf16) &&
+      (getLayoutA() == mlir::NVVM::MMALayout::col ||
+       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+    return emitOpError()
+           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
+           << " and layout_b = " << stringifyMMALayout(getLayoutB())
+           << " for input types " << stringifyMMATypes(typeA) << " and "
+           << stringifyMMATypes(typeB)
+           << " requires transpose. However, this is only supported for: "
+           << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
+           << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+  }
+
+  // Check number of result registers
+  int expectedOutput = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutput = getShape().getN() / 2;
+  if (outputType.isF16())
+    expectedOutput = getShape().getN() / 4;
+  if (outputSize != expectedOutput) {
+    return emitOpError() << "expected " << expectedOutput
+                         << " elements in the output struct, but it has "
+                         << outputSize << " elements";
+  }
+  // Check that satfinite is only used with an s32 accumulator.
+  if (!outputType.isInteger(32) &&
+      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+          NVVM::MMAIntOverflow::satfinite) {
+    return emitOpError()
+           << "`satfinite` can only be used with an s32 accumulator, however "
+              "the current accumulator is "
+           << outputType;
+  }
+
+  return success();
+}
+
+std::string NVVM::WgmmaMmaSyncOp::getPtx() {
+
+  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
+  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+               getTypeA() == mlir::NVVM::MMATypes::bf16;
+
+  Value outValue = getResults() ? getResults() : getInouts();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  Type outputType = stype.getBody().front();
+  std::string outputTypeName;
+  if (outputType.isF16())
+    outputTypeName = "f16";
+  else if (outputType.isF32())
+    outputTypeName = "f32";
+  else if (outputType.isInteger(32))
+    outputTypeName = "s32";
+  int expectedOutputRegisters = 0;
+  if (outputType.isF32() || outputType.isInteger(32))
+    expectedOutputRegisters = getShape().getN() / 2;
+  else if (outputType.isF16())
+    expectedOutputRegisters = getShape().getN() / 4;
+
+  std::string ptx;
+  llvm::raw_string_ostream ss(ptx);
+
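+  // For example, for f32 += f16 * f16 with shape m64n32k16 (16 output
+  // registers), the emitted PTX looks like the following; $0..$15 are the
+  // accumulator registers, $16/$17 the descriptors, $18 feeds the scale-d
+  // predicate, $19/$20 are scale-a/scale-b, and $21/$22 the layout flags
+  // (see the conversion tests):
+  //   {
+  //   .reg .pred p;
+  //   setp.ne.b32 p, $18, 0;
+  //   wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
+  //     {$0, $1, ..., $15}, $16, $17, p, $19, $20, $21, $22;
+  //   }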
+  ss << "{\n"
+        ".reg .pred p;\n"
+        "setp.ne.b32 p, $"
+     << (expectedOutputRegisters + 2)
+     << ", 0;\n"
+        "wgmma.mma_async.sync.aligned.m"
+     << m << "n" << n << "k" << k << "." << outputTypeName << "."
+     << stringifyMMATypes(getTypeA()) << "." << stringifyMMATypes(getTypeB());
+  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+      NVVM::MMAIntOverflow::satfinite)
+    ss << ".satfinite";
+  ss << " {";
+  int regCnt = 0;
+  for (; regCnt < expectedOutputRegisters; ++regCnt) {
+    ss << "$" << regCnt;
+    if (regCnt != expectedOutputRegisters - 1)
+      ss << ", ";
+  }
+
+  ss << "},";
+  ss << " $" << (expectedOutputRegisters) << ","
+     << " $" << (expectedOutputRegisters + 1) << ","
+     << " p";
+  if (!outputType.isInteger(32)) {
+    ss << ", $" << (expectedOutputRegisters + 3) << ",  $"
+       << (expectedOutputRegisters + 4);
+  }
+  // Don't add transpose parameters unless needed.
+  if (isF16) {
+    ss << ", $" << (expectedOutputRegisters + 5) << ",  $"
+       << (expectedOutputRegisters + 6);
+  }
+  ss << ";\n"
+     << "}\n";
+  ss.flush();
+  return ptx;
+}
+
+void NVVM::WgmmaMmaSyncOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
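+  // The order of the values pushed here defines the inline-asm register
+  // numbering and must stay in sync with the indices hard-coded in getPtx():
+  // results/inouts first, then descriptorA, descriptorB, scaleD, followed by
+  // scaleA/scaleB for non-s32 accumulators and layoutA/layoutB for f16/bf16
+  // inputs.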
+  Value outValue = getResults() ? getResults() : getInouts();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  Type outputType = stype.getBody().front();
+  bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 ||
+               getTypeA() == mlir::NVVM::MMATypes::bf16;
+  if (getResults())
+    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
+  if (getInouts())
+    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
+  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
+  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
+  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
+                       mlir::NVVM::PTXRegisterMod::Read});
+  if (!outputType.isInteger(32)) {
+    asmValues.push_back(
+        {makeConstantI32(rewriter,
+                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+         mlir::NVVM::PTXRegisterMod::Read});
+    asmValues.push_back(
+        {makeConstantI32(rewriter,
+                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+         mlir::NVVM::PTXRegisterMod::Read});
+  }
+  if (isF16) {
+    asmValues.push_back(
+        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
+         mlir::NVVM::PTXRegisterMod::Read});
+    asmValues.push_back(
+        {makeConstantI32(rewriter, static_cast<int>(getLayoutB())),
+         mlir::NVVM::PTXRegisterMod::Read});
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
new file mode 100644
index 00000000000000..7f8fa5267c343f
--- /dev/null
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file -verify-diagnostics %s
+
+!mat64f32 = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32)>
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{  
+  %result = llvm.mlir.undef : !mat64f32
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op expected 64 elements in the output struct, but it has 7 elements}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 128, k = 16>, 
+      D [%result, <zero>],
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !mat64f32 -> !mat64f32
+  return %res : !mat64f32
+}
+
+// -----
+
+func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {  
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error @+1 {{`satfinite` can only be used with an s32 accumulator, however the current accumulator is 'f32'}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 16, k = 16>, 
+      D [%result, <zero>, <satfinite>], 
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return 
+}
+
+// -----
+
+func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error @+1 {{shape 'm' must be 64}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 32, n = 16, k = 16>, 
+      D [%result, <zero>], 
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  return 
+}
+
+// -----
+
+func.func @wgmma_struct_mixed_types(%descA : i64, %descB : i64) {
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> 
+  // expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 16, k = 16>, 
+      D [%result, <zero>], 
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> 
+      -> !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> 
+  return 
+}
+
+// -----
+
+func.func @wgmma_wrong_k(%descA : i64, %descB : i64) {
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 16, k = 3>, 
+      D [%result, <zero>], 
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  return 
+}
+
+// -----
+
+func.func @wgmma_transpose(%descA : i64, %descB : i64) {
+  %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. However, this is only supported for: f16 and bf16}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 16, k = 8>, 
+      D [%result, <zero>], 
+      A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  return 
+}
+
+// -----
+
+func.func @wgmma_unsupported_type(%descA : i64, %descB : i64) {
+  %result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)>
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 16, k = 8>, 
+      D [%result, <zero>], 
+      A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
+      :!llvm.struct<(f16, f16, f16, f16)>
+      -> !llvm.struct<(f16, f16, f16, f16)>
+  return 
+}
+
+// -----
+
+func.func @wgmma_inout_result_mismatch(%descA : i64, %descB : i64) {
+  %result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
+  // expected-error @+1 {{input struct and result struct must be the same type}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 16>, 
+      D [%result, <zero>], 
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !llvm.struct<(i32, i32, i32, i32)> 
+      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  return 
+}
\ No newline at end of file

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 9ba913b9d3ea2a..342fc2fc0430b1 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
 // Same below, but using the `ConvertToLLVMPatternInterface` entry point
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
@@ -103,3 +103,168 @@ func.func @wgmma_execute() {
   // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned %0;", "n" %{{.*}} : (i32)
   return
 }
+
+// -----
+
+!mat64f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_f16_f16(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{  
+  // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+  // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
+  // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19,  $20, $21,  $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
+  // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
+  // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
+  // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $18, 0;\0Awgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $16, $17, p, $19,  $20, $21,  $22;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  %result = llvm.mlir.undef : !mat64f32
+  %result1 = nvvm.wgmma.mma_async 
+      %descA, %descB, 
+      #nvvm.shape<m = 64, n = 32, k = 16>, 
+      D [%result, #nvvm.wgmma_scale_out<zero>],
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      :!mat64f32 -> !mat64f32
+  %c2 = arith.constant 2 : i64
+  %descAnext = arith.addi %descA, %c2 : i64
+  %descBnext = arith.addi %descB, %c2 : i64
+  %result2 = nvvm.wgmma.mma_async 
+      %descAnext, %descBnext, 
+      #nvvm.shape<m = 64, n = 32, k = 16>, 
+      D [%result1, #nvvm.wgmma_scale_out<zero>],
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !mat64f32 -> !mat64f32
+  return %result2 : !mat64f32
+}
+
+// -----
+
+!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
+
+// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{  
+  %result = llvm.mlir.undef : !mat16i32
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result1, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result2, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  return %result3 : !mat16i32
+}
+
+// CHECK-LABEL: @wgmma_s32_u8_u8(
+  // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {  
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $6, 0;\0Awgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $4, $5, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
+  %result = llvm.mlir.undef : !mat16i32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result1, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [%result2, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      : !mat16i32 -> !mat16i32
+  return %result3 : !mat16i32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_tf32_tf32
+func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A.reg .pred p;\0Asetp.ne.b32 p, $34, 0;\0Awgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $32, $33, p, $35,  $36;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 8>, 
+      D [%result, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+      #nvvm.shape<m = 64, n = 64, k = 8>, 
+      D [%result1, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.mma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}


        

