[Mlir-commits] [mlir] [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (PR #78713)

Guray Ozen llvmlistbot at llvm.org
Fri Jan 19 05:43:11 PST 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/78713

The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data type of the output matrix from the type of the result struct's members, which can be non-intuitive, especially in cases where types like `2xf16` are packed into `i32`.

This PR addresses this issue by improving the Op to include an explicit data type for the output matrix.

The modified Op now includes an explicit data type for Matrix-D (<f16>), and looks as follows:

```
%result = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, ...
nvvm.wgmma.mma_async
    %descA, %descB, %result,
    #nvvm.shape<m = 64, n = 32, k = 16>,
    D [<f16>, #nvvm.wgmma_scale_out<zero>],
    A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
    B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
```

From 3b6e6f00f6c70970c03640de36901c8cfa1f0c87 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 19 Jan 2024 14:41:30 +0100
Subject: [PATCH] [MLIR][NVVM] Explicit Data Type for Output Matrix in
 nvvm.wgmma.mma_async

The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data type of the output matrix from the type of the result struct's members, which can be non-intuitive, especially in cases where types like `2xf16` are packed into `i32`.

This PR addresses this issue by improving the Op to include an explicit data type for the output matrix.

The modified Op now includes an explicit data type for Matrix-D (<f16>), and looks as follows:

```
%result = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, ...
nvvm.wgmma.mma_async
    %descA, %descB, %result,
    #nvvm.shape<m = 64, n = 32, k = 16>,
    D [<f16>, #nvvm.wgmma_scale_out<zero>],
    A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
    B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
```
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  10 +-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  15 ++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 119 +++++++++---------
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  18 +--
 mlir/test/Conversion/NVVMToLLVM/invalid.mlir  |  42 +++----
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  56 ++++-----
 6 files changed, 134 insertions(+), 126 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..b1bd3a95068076 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
 def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
 def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
 def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
+def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
+
 def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
   [WGMMATypeF16, WGMMATypeTF32,
     WGMMATypeU8, WGMMATypeS8,
     WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, 
-    WGMMATypeF8E5M2]> {
+    WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
     NVVM_MMAShapeAttr:$shape,
     WGMMATypesAttr:$typeA,
     WGMMATypesAttr:$typeB,
+    WGMMATypesAttr:$typeD,
     WGMMAScaleOutAttr:$scaleD,
     WGMMAScaleInAttr:$scaleA,
     WGMMAScaleInAttr:$scaleB, 
@@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
   );  
   
    let assemblyFormat = [{ 
-      $descriptorA `,` $descriptorB `,` $shape `,` 
-      `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+      $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
+      `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
       `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,` 
       `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
       attr-dict `:` 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9950499817789d 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering
     }
 
     /// Generates WGMMATypesAttr from MLIR Type
-    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
-      auto getWgmmaType = [](Type elemType) {
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
+                                           bool useF32 = false) const {
+      auto getWgmmaType = [=](Type elemType) {
         if (elemType.isF32() || elemType.isTF32())
-          return NVVM::WGMMATypes::tf32;
+          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
         if (elemType.isF16())
           return NVVM::WGMMATypes::f16;
         if (elemType.isBF16())
@@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering
           return NVVM::WGMMATypes::s8;
         if (elemType.isUnsignedInteger(8))
           return NVVM::WGMMATypes::u8;
+        if (elemType.isInteger(32))
+          return NVVM::WGMMATypes::s32;
         llvm_unreachable("unsupported type");
       };
       return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
@@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering
       Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
       NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
 
+      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
+      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
+
       NVVM::MMAShapeAttr shape = generateWgmmaShape();
       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
@@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering
 
       return b.create<NVVM::WgmmaMmaAsyncOp>(
           matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
-          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..bb720603819407 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   return failure();
 }
 
-LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+                                     NVVM::WGMMATypes typeA,
                                      NVVM::WGMMATypes typeB) {
   switch (typeA) {
   case NVVM::WGMMATypes::f16:
-    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::f16)
       return success();
     break;
   case NVVM::WGMMATypes::tf32:
-    if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
       return success();
     break;
   case NVVM::WGMMATypes::u8:
   case NVVM::WGMMATypes::s8:
-    if (typeD.isInteger(32) &&
+    if (typeD == NVVM::WGMMATypes::s32 &&
         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
       return success();
     break;
   case NVVM::WGMMATypes::b1:
-    if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
       return success();
     break;
   case NVVM::WGMMATypes::bf16:
-    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::bf16)
       return success();
     break;
   case NVVM::WGMMATypes::e4m3:
   case NVVM::WGMMATypes::e5m2:
-    if ((typeD.isF32() || typeD.isF16()) &&
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
       return success();
     break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    llvm_unreachable("unsupported input types");
+    break;
   }
   return failure();
 }
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
   switch (typeA) {
-  case mlir::NVVM::WGMMATypes::f16:
-  case mlir::NVVM::WGMMATypes::tf32:
-  case mlir::NVVM::WGMMATypes::bf16:
-  case mlir::NVVM::WGMMATypes::e4m3:
-  case mlir::NVVM::WGMMATypes::e5m2:
+  case WGMMATypes::f16:
+  case WGMMATypes::tf32:
+  case WGMMATypes::bf16:
+  case WGMMATypes::e4m3:
+  case WGMMATypes::e5m2:
     if (llvm::is_contained(allowedN, sizeN))
       return success();
     break;
-  case mlir::NVVM::WGMMATypes::u8:
-  case mlir::NVVM::WGMMATypes::s8:
-  case mlir::NVVM::WGMMATypes::b1:
+  case WGMMATypes::u8:
+  case WGMMATypes::s8:
+  case WGMMATypes::b1:
     if (llvm::is_contained(allowedNshort, sizeN))
       return success();
+    break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    llvm_unreachable("unsupported input types");
+    break;
   }
   return failure();
 }
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
   if (!stype)
     return emitOpError() << "expected results to be struct";
-  Type outputType = stype.getBody().front();
   int outputSize = stype.getBody().size();
+  WGMMATypes typeD = getTypeD();
+  WGMMATypes typeA = getTypeA();
+  WGMMATypes typeB = getTypeB();
+
   for (Type t : stype.getBody()) {
-    if (t != outputType)
+    if (t != stype.getBody().front())
       return emitOpError()
              << "all elements in struct must be same type but there is " << t;
   }
 
-  if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+      typeD != WGMMATypes::s32) {
     return emitOpError() << "does not support the given output type "
-                         << outputType;
+                         << NVVM::stringifyWGMMATypes(typeD);
   }
-  if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
-                                   getScaleB() == NVVM::WGMMAScaleIn::neg)) {
-    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  if (typeD == WGMMATypes::s32 &&
+      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg ";
   }
 
-  mlir::NVVM::WGMMATypes typeA = getTypeA();
-  mlir::NVVM::WGMMATypes typeB = getTypeB();
-  if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
-    return emitOpError() << outputType
+  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                          << NVVM::stringifyWGMMATypes(typeB)
                          << ", it is not supported.";
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   // Check transpose (only available for f16/bf16)
-  if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
-       typeA != mlir::NVVM::WGMMATypes::bf16) &&
+  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
        getLayoutB() == mlir::NVVM::MMALayout::col)) {
     return emitOpError()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
            << " for input types " << stringifyWGMMATypes(typeA) << " and "
            << stringifyWGMMATypes(typeB)
            << " requires transpose. However, this is only supported for: "
-           << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
-           << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+           << stringifyMMATypes(MMATypes::f16) << " and "
+           << stringifyMMATypes(MMATypes::bf16);
   }
 
   // Check result registers
-  int expectedOutput;
-  if (outputType.isF32() || outputType.isInteger(32))
+  int expectedOutput = 0;
+  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
     expectedOutput = getShape().getN() / 2;
-  if (outputType.isF16())
+  if (typeD == WGMMATypes::f16)
     expectedOutput = getShape().getN() / 4;
   if (outputSize != expectedOutput) {
     return emitOpError() << "results " << expectedOutput
                          << ", however output struct has " << outputSize
                          << " elements";
   }
-  // Check satfinite (only availalbe for s32 accumulator)
-  if (!outputType.isInteger(32) &&
+  // Check satfinite (only available for s32 accumulator)
+  if (typeD != WGMMATypes::s32 &&
       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
           NVVM::MMAIntOverflow::satfinite) {
     return emitOpError()
            << " `satfinite` can be only used with s32 accumulator, however "
               "the current accumulator is "
-           << outputType;
+           << NVVM::stringifyWGMMATypes(typeD);
   }
 
   return success();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
 
   int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
-  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
-               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
 
-  Value outValue = getResults() ? getResults() : getInouts();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  Type outputType = stype.getBody().front();
-  std::string outputTypeName;
-  if (outputType.isF16())
-    outputTypeName = "f16";
-  else if (outputType.isF32())
-    outputTypeName = "f32";
-  else if (outputType.isInteger(32))
-    outputTypeName = "s32";
-  else
-    assert(false && "unsupported output type");
+  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
 
-  int expectedOutputRegisters;
-  if (outputType.isF32() || outputType.isInteger(32))
-    expectedOutputRegisters = getShape().getN() / 2;
-  if (outputType.isF16())
+  int expectedOutputRegisters = 0;
+  if (getTypeD() == WGMMATypes::f16)
     expectedOutputRegisters = getShape().getN() / 4;
+  else
+    expectedOutputRegisters = getShape().getN() / 2;
 
   std::string ptx;
   llvm::raw_string_ostream ss(ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << " $" << (regCnt) << ","
      << " $" << (regCnt + 1) << ","
      << " p";
-  if (!outputType.isInteger(32)) {
+  if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
   // Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
     RewriterBase &rewriter,
     llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
         &asmValues) {
-  Value outValue = getResults() ? getResults() : getInouts();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  Type outputType = stype.getBody().front();
-  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
-               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
   if (getResults())
     asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
   if (getInouts())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
   asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                        mlir::NVVM::PTXRegisterMod::Read});
-  if (!outputType.isInteger(32)) {
+  if (getTypeD() != WGMMATypes::s32) {
     asmValues.push_back(
         {makeConstantI32(rewriter,
                          getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..3ca970f412833f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -880,41 +880,41 @@ func.func @warpgroup_mma_128_128_64(
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[UD:.+]] =  llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S2]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
 // CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
 // CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
 // CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]]  : i64
-// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], %[[S4]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
 // CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]]  : i64
 // CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
 // CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]]  : i64
-// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], %[[S9]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
 // CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]]  : i64
 // CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
 // CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]]  : i64
-// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], %[[S14]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
 // CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]]  : i64
-// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], %[[S3]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64
 // CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]]  : i64
 // CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64
 // CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]]  : i64
-// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], <m = 64, n = 128, k = 16>, D[%[[S23]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], %[[S23]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64
 // CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]]  : i64
 // CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64
 // CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]]  : i64
-// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], <m = 64, n = 128, k = 16>, D[%[[S28]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], %[[S28]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64
 // CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]]  : i64
 // CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
 // CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]]  : i64
-// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], %[[S33]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S40:.+]] = llvm.insertvalue %[[S19]], %[[UD]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S41:.+]] = llvm.insertvalue %[[S38]], %[[S40]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: nvvm.wgmma.commit.group.sync.aligned
@@ -1169,7 +1169,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, <m = 64, n = 128, k = 16>, D[%[[S138]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 34c8de9f7ed8c6..9ebe3a009adf25 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -4,9 +4,9 @@
 func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{  
   %result = llvm.mlir.undef : !mat64f32
   // expected-error @+1 {{'nvvm.wgmma.mma_async' op results 64, however output struct has 7 elements}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 128, k = 16>, 
-      D [%result, <zero>],
+      D [<f32>, <zero>],
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !mat64f32 -> !mat64f32
@@ -17,10 +17,10 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
 
 func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is 'f32'}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  // expected-error @+1 {{`satfinite` can be only used with s32 accumulator, however the current accumulator is f32}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 16, k = 16>, 
-      D [%result, <zero>, <satfinite>], 
+      D [<f32>, <zero>, <satfinite>], 
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -33,9 +33,9 @@ func.func @wgmma_f32_satfinite(%descA : i64, %descB : i64) {
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   // expected-error @+1 {{shape 'm' must be 64}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 32, n = 16, k = 16>, 
-      D [%result, <zero>], 
+      D [<f32>, <zero>], 
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -48,9 +48,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) { 
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> 
   // expected-error @+1 {{op all elements in struct must be same type but there is 'i32'}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 16, k = 16>, 
-      D [%result, <zero>], 
+      D [<f32>, <zero>], 
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, i32, f32, f32, f32, f32, f32)> 
@@ -63,9 +63,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
   // expected-error @+1 {{op shape 'k' must be 16 for input type f16}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 16, k = 3>, 
-      D [%result, <zero>], 
+      D [<f32>, <zero>], 
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
@@ -78,9 +78,9 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
 func.func @wgmma_transpose(%descA : i64, %descB : i64) {
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
   // expected-error @+1 {{op given layouts layout_a = col and layout_b = col for input types tf32 and tf32 requires transpose. However, this is only supported for: f16 and bf16}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 16, k = 8>, 
-      D [%result, <zero>], 
+      D [<f32>, <zero>], 
       A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
@@ -92,10 +92,10 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) {
 
 func.func @wgmma_transpose(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16)>
-  // expected-error @+1 {{'nvvm.wgmma.mma_async' op 'f16' += tf32 * tf32, it is not supported.}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  // expected-error @+1 {{'nvvm.wgmma.mma_async' op f16 += tf32 * tf32, it is not supported.}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 16, k = 8>, 
-      D [%result, <zero>], 
+      D [<f16>, <zero>], 
       A [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<tf32>, #nvvm.wgmma_scale_in<neg>, <col>]
       :!llvm.struct<(f16, f16, f16, f16)>
@@ -108,9 +108,9 @@ func.func @wgmma_transpose(%descA : i64, %descB : i64) {
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
   // expected-error @+1 {{input struct and result struct must be the same type}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 8, k = 16>, 
-      D [%result, <zero>], 
+      D [<f16>, <zero>], 
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(i32, i32, i32, i32)> 
@@ -122,10 +122,10 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
 
 func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {  
   %result = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // expected-error @+1 {{op 'f32' += bf16 * f16, it is not supported}}
-  %res = nvvm.wgmma.mma_async %descA, %descB, 
+  // expected-error @+1 {{op f32 += bf16 * f16, it is not supported}}
+  %res = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 8, k = 16>, 
-      D [%result, <zero>], 
+      D [<f32>, <zero>], 
       A [<bf16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218..9c7c27c49eb11d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -329,9 +329,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
     // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} 
   %result = llvm.mlir.undef : !mat64f32
   %result1 = nvvm.wgmma.mma_async 
-      %descA, %descB, 
+      %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 32, k = 16>, 
-      D [%result, #nvvm.wgmma_scale_out<zero>],
+      D [<f32>, #nvvm.wgmma_scale_out<zero>],
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       :!mat64f32 -> !mat64f32
@@ -339,9 +339,9 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
   %descAnext = arith.addi %descA, %c2 : i64
   %descBnext = arith.addi %descB, %c2 : i64
   %result2 = nvvm.wgmma.mma_async 
-      %descAnext, %descBnext, 
+      %descAnext, %descBnext, %result1,
       #nvvm.shape<m = 64, n = 32, k = 16>, 
-      D [%result1, #nvvm.wgmma_scale_out<zero>],
+      D [<f32>, #nvvm.wgmma_scale_out<zero>],
       A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
       B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
       : !mat64f32 -> !mat64f32
@@ -393,21 +393,21 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
 // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
 // CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" 
 // CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result1, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
-  %result3 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result2, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
@@ -454,21 +454,21 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
 // CHECK-SAME:}\0A", 
 // CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
   %result = llvm.mlir.undef : !mat16i32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result, #nvvm.wgmma_scale_out<one>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result1, #nvvm.wgmma_scale_out<one>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
-  %result3 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [%result2, #nvvm.wgmma_scale_out<one>],
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
       B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
       : !mat16i32 -> !mat16i32
@@ -496,15 +496,15 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
   // CHECK-SAME: setp.ne.b32 p, $66, 0;
   // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
   %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 64, k = 8>, 
-      D [%result, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
        : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 8>, 
-      D [%result1, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
       : !mat32f32 -> !mat32f32
@@ -529,15 +529,15 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
   // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
   // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
   %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [%result, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
        : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [%result1, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
       : !mat32f32 -> !mat32f32
@@ -561,15 +561,15 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
   // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
   // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
   %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [%result, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
        : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, 
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [%result1, #nvvm.wgmma_scale_out<one>],
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
       B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
       : !mat32f32 -> !mat32f32



More information about the Mlir-commits mailing list