[Mlir-commits] [mlir] [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (PR #78713)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 19 05:43:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

The current implementation of the `nvvm.wgmma.mma_async` Op deduces the data type of the output matrix from the element type of the result struct. This can be non-intuitive, especially when values are packed, e.g. two `f16` values (`2xf16`) stored in a single `i32` struct member.

This PR addresses the issue by extending the Op with an explicit data-type attribute for the output matrix.

The modified Op now takes an explicit data type for matrix D (`<f16>` in the example below) and looks as follows:

```
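%result = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
%out = nvvm.wgmma.mma_async
    %descA, %descB, %result,
    #nvvm.shape<m = 64, n = 32, k = 16>,
    D [<f16>, #nvvm.wgmma_scale_out<zero>],
    A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
    B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
    : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>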
```

---

Patch is 42.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78713.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+7-3) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+11-4) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+58-61) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+9-9) 
- (modified) mlir/test/Conversion/NVVMToLLVM/invalid.mlir (+21-21) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+28-28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..b1bd3a95068076 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
 def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
 def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
 def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
+def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
+def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
+
 def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
   [WGMMATypeF16, WGMMATypeTF32,
     WGMMATypeU8, WGMMATypeS8,
     WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, 
-    WGMMATypeF8E5M2]> {
+    WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
     NVVM_MMAShapeAttr:$shape,
     WGMMATypesAttr:$typeA,
     WGMMATypesAttr:$typeB,
+    WGMMATypesAttr:$typeD,
     WGMMAScaleOutAttr:$scaleD,
     WGMMAScaleInAttr:$scaleA,
     WGMMAScaleInAttr:$scaleB, 
@@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
   );  
   
    let assemblyFormat = [{ 
-      $descriptorA `,` $descriptorB `,` $shape `,` 
-      `D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
+      $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
+      `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
       `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,` 
       `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
       attr-dict `:` 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9950499817789d 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering
     }
 
     /// Generates WGMMATypesAttr from MLIR Type
-    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
-      auto getWgmmaType = [](Type elemType) {
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
+                                           bool useF32 = false) const {
+      auto getWgmmaType = [=](Type elemType) {
         if (elemType.isF32() || elemType.isTF32())
-          return NVVM::WGMMATypes::tf32;
+          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
         if (elemType.isF16())
           return NVVM::WGMMATypes::f16;
         if (elemType.isBF16())
@@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering
           return NVVM::WGMMATypes::s8;
         if (elemType.isUnsignedInteger(8))
           return NVVM::WGMMATypes::u8;
+        if (elemType.isInteger(32))
+          return NVVM::WGMMATypes::s32;
         llvm_unreachable("unsupported type");
       };
       return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
@@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering
       Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
       NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
 
+      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
+      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
+
       NVVM::MMAShapeAttr shape = generateWgmmaShape();
       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
@@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering
 
       return b.create<NVVM::WgmmaMmaAsyncOp>(
           matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
-          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..bb720603819407 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   return failure();
 }
 
-LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
+LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+                                     NVVM::WGMMATypes typeA,
                                      NVVM::WGMMATypes typeB) {
   switch (typeA) {
   case NVVM::WGMMATypes::f16:
-    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::f16)
       return success();
     break;
   case NVVM::WGMMATypes::tf32:
-    if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
+    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
       return success();
     break;
   case NVVM::WGMMATypes::u8:
   case NVVM::WGMMATypes::s8:
-    if (typeD.isInteger(32) &&
+    if (typeD == NVVM::WGMMATypes::s32 &&
         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
       return success();
     break;
   case NVVM::WGMMATypes::b1:
-    if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
+    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
       return success();
     break;
   case NVVM::WGMMATypes::bf16:
-    if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::bf16)
       return success();
     break;
   case NVVM::WGMMATypes::e4m3:
   case NVVM::WGMMATypes::e5m2:
-    if ((typeD.isF32() || typeD.isF16()) &&
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
       return success();
     break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    llvm_unreachable("unsupported input types");
+    break;
   }
   return failure();
 }
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
   switch (typeA) {
-  case mlir::NVVM::WGMMATypes::f16:
-  case mlir::NVVM::WGMMATypes::tf32:
-  case mlir::NVVM::WGMMATypes::bf16:
-  case mlir::NVVM::WGMMATypes::e4m3:
-  case mlir::NVVM::WGMMATypes::e5m2:
+  case WGMMATypes::f16:
+  case WGMMATypes::tf32:
+  case WGMMATypes::bf16:
+  case WGMMATypes::e4m3:
+  case WGMMATypes::e5m2:
     if (llvm::is_contained(allowedN, sizeN))
       return success();
     break;
-  case mlir::NVVM::WGMMATypes::u8:
-  case mlir::NVVM::WGMMATypes::s8:
-  case mlir::NVVM::WGMMATypes::b1:
+  case WGMMATypes::u8:
+  case WGMMATypes::s8:
+  case WGMMATypes::b1:
     if (llvm::is_contained(allowedNshort, sizeN))
       return success();
+    break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    llvm_unreachable("unsupported input types");
+    break;
   }
   return failure();
 }
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
   if (!stype)
     return emitOpError() << "expected results to be struct";
-  Type outputType = stype.getBody().front();
   int outputSize = stype.getBody().size();
+  WGMMATypes typeD = getTypeD();
+  WGMMATypes typeA = getTypeA();
+  WGMMATypes typeB = getTypeB();
+
   for (Type t : stype.getBody()) {
-    if (t != outputType)
+    if (t != stype.getBody().front())
       return emitOpError()
              << "all elements in struct must be same type but there is " << t;
   }
 
-  if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+      typeD != WGMMATypes::s32) {
     return emitOpError() << "does not support the given output type "
-                         << outputType;
+                         << NVVM::stringifyWGMMATypes(typeD);
   }
-  if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
-                                   getScaleB() == NVVM::WGMMAScaleIn::neg)) {
-    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  if (typeD == WGMMATypes::s32 &&
+      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg ";
   }
 
-  mlir::NVVM::WGMMATypes typeA = getTypeA();
-  mlir::NVVM::WGMMATypes typeB = getTypeB();
-  if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
-    return emitOpError() << outputType
+  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                          << NVVM::stringifyWGMMATypes(typeB)
                          << ", it is not supported.";
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   // Check transpose (only available for f16/bf16)
-  if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
-       typeA != mlir::NVVM::WGMMATypes::bf16) &&
+  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
        getLayoutB() == mlir::NVVM::MMALayout::col)) {
     return emitOpError()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
            << " for input types " << stringifyWGMMATypes(typeA) << " and "
            << stringifyWGMMATypes(typeB)
            << " requires transpose. However, this is only supported for: "
-           << stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
-           << stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
+           << stringifyMMATypes(MMATypes::f16) << " and "
+           << stringifyMMATypes(MMATypes::bf16);
   }
 
   // Check result registers
-  int expectedOutput;
-  if (outputType.isF32() || outputType.isInteger(32))
+  int expectedOutput = 0;
+  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
     expectedOutput = getShape().getN() / 2;
-  if (outputType.isF16())
+  if (typeD == WGMMATypes::f16)
     expectedOutput = getShape().getN() / 4;
   if (outputSize != expectedOutput) {
     return emitOpError() << "results " << expectedOutput
                          << ", however output struct has " << outputSize
                          << " elements";
   }
-  // Check satfinite (only availalbe for s32 accumulator)
-  if (!outputType.isInteger(32) &&
+  // Check satfinite (only available for s32 accumulator)
+  if (typeD != WGMMATypes::s32 &&
       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
           NVVM::MMAIntOverflow::satfinite) {
     return emitOpError()
            << " `satfinite` can be only used with s32 accumulator, however "
               "the current accumulator is "
-           << outputType;
+           << NVVM::stringifyWGMMATypes(typeD);
   }
 
   return success();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
 
   int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
-  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
-               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
 
-  Value outValue = getResults() ? getResults() : getInouts();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  Type outputType = stype.getBody().front();
-  std::string outputTypeName;
-  if (outputType.isF16())
-    outputTypeName = "f16";
-  else if (outputType.isF32())
-    outputTypeName = "f32";
-  else if (outputType.isInteger(32))
-    outputTypeName = "s32";
-  else
-    assert(false && "unsupported output type");
+  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
 
-  int expectedOutputRegisters;
-  if (outputType.isF32() || outputType.isInteger(32))
-    expectedOutputRegisters = getShape().getN() / 2;
-  if (outputType.isF16())
+  int expectedOutputRegisters = 0;
+  if (getTypeD() == WGMMATypes::f16)
     expectedOutputRegisters = getShape().getN() / 4;
+  else
+    expectedOutputRegisters = getShape().getN() / 2;
 
   std::string ptx;
   llvm::raw_string_ostream ss(ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << " $" << (regCnt) << ","
      << " $" << (regCnt + 1) << ","
      << " p";
-  if (!outputType.isInteger(32)) {
+  if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
   // Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
     RewriterBase &rewriter,
     llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
         &asmValues) {
-  Value outValue = getResults() ? getResults() : getInouts();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  Type outputType = stype.getBody().front();
-  bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
-               getTypeA() == mlir::NVVM::WGMMATypes::bf16;
+  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
   if (getResults())
     asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
   if (getInouts())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
   asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                        mlir::NVVM::PTXRegisterMod::Read});
-  if (!outputType.isInteger(32)) {
+  if (getTypeD() != WGMMATypes::s32) {
     asmValues.push_back(
         {makeConstantI32(rewriter,
                          getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..3ca970f412833f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -880,41 +880,41 @@ func.func @warpgroup_mma_128_128_64(
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[UD:.+]] =  llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S2]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
 // CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
 // CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
 // CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]]  : i64
-// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], %[[S4]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
 // CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]]  : i64
 // CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
 // CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]]  : i64
-// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], %[[S9]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
 // CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]]  : i64
 // CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
 // CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]]  : i64
-// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], %[[S14]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
 // CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
 // CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]]  : i64
-// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f1...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/78713


More information about the Mlir-commits mailing list