[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.wargroup.mma.store` Op for Hopper GPUs (PR #65441)

Guray Ozen llvmlistbot at llvm.org
Tue Sep 5 23:37:01 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/65441:

[MLIR][NVGPU] Introduce `nvgpu.wargroup.mma.store` Op for Hopper GPUs

This work introduces a new operation called `wargroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to the given memref.

An example of the fragmentation layout is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `wargroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates per-thread indices within the warp group and stores the data into the given memref (a sketch of this index arithmetic follows the example below).

Here is an example usage of the `nvgpu.wargroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.wargroup.mma ...

// Stores the fragmented results to the given memref
nvgpu.wargroup.mma.store [%res1, %res2], %matrixD : 
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>, 
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>> 
                to memref<128x128xf32,3>
```
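
For reference, below is a minimal sketch of the per-thread index arithmetic that the lowering in this patch emits for the first element pair {d0, d1} of a fragmented result. The SSA names are illustrative (not the exact generated values), and `%c1`, `%c2`, `%c4`, `%c16`, `%c32` are assumed to be `i32` constants:

```mlir
// Addressing for {d0, d1} of each thread (illustrative sketch).
%tidx    = nvvm.read.ptx.sreg.tid.x : i32
%laneId  = llvm.urem %tidx, %c32 : i32        // lane within the warp
%warpId  = llvm.udiv %tidx, %c32 : i32        // warp within the warp group
%quadId  = llvm.udiv %laneId, %c4 : i32       // laneId / 4
%quadLn  = llvm.urem %laneId, %c4 : i32       // laneId % 4
%rowOff  = llvm.mul %warpId, %c16 : i32
%row     = llvm.add %quadId, %rowOff : i32    // row = laneId/4 + warpId*16
%col     = llvm.mul %quadLn, %c2 : i32        // col = (laneId%4) * 2
%colP1   = llvm.add %col, %c1 : i32
%rowIdx  = arith.index_cast %row   : i32 to index
%colIdx  = arith.index_cast %col   : i32 to index
%colIdx1 = arith.index_cast %colP1 : i32 to index
%d0 = llvm.extractvalue %res1[0] : !llvm.struct<...>
%d1 = llvm.extractvalue %res1[1] : !llvm.struct<...>
memref.store %d0, %matrixD[%rowIdx, %colIdx]  : memref<128x128xf32, 3>
memref.store %d1, %matrixD[%rowIdx, %colIdx1] : memref<128x128xf32, 3>
```

The remaining element pairs reuse the same base indices, advancing the row and column by multiples of 8, which mirrors the register fragment layout linked above.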

Depends on #65440



>From c013a2133589f59aac77786e63eed2e1b52ff72f Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 1 Sep 2023 09:30:56 +0200
Subject: [PATCH 1/2] [MLIR][NVGPU] Adding `nvgpu.wargroup.mma` Op for Hopper
 GPUs

This work introduces a new operation called `wargroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Achieving the desired computation shape requires issuing multiple instances of `nvvm.wgmma.mma_async`. The new `nvgpu.wargroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The `nvgpu.wargroup.mma` op does the following:
1) Maps to multiple `wgmma` instructions.
2) Iterates over the input matrix descriptors to achieve the desired computation shape.
3) Groups the `wgmma` instructions, runs them asynchronously, and eventually waits for them. This is done with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Returns the results as fragmented matrices (a sketch of the generated NVVM sequence follows the example below).

Here's an example usage of the `nvgpu.wargroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      vector<128x128xf32>
      -> !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
  		 !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
```
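
For the 128x128x64 f16 case above, the lowering splits M into two tiles of 64 rows and K into four steps of 16, issuing eight `nvvm.wgmma.mma_async` ops inside a single fenced group. A rough, heavily elided sketch of the generated NVVM code (descriptor address increments and the full 64-element result structs are omitted):

```mlir
nvvm.wgmma.fence.aligned
%acc = llvm.mlir.undef : !llvm.struct<...>
%r0  = nvvm.wgmma.mma_async %descA, %descB, <m = 64, n = 128, k = 16>,
         D[%acc, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>]
         : !llvm.struct<...> -> !llvm.struct<...>
// ... three more k-steps for rows 0..63, each accumulating into the previous
// result with incremented descriptors, then four more wgmma ops for rows
// 64..127 starting from a fresh accumulator ...
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 1
```

The exact generated sequence is checked in `mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir` in this patch.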

Differential Revision: https://reviews.llvm.org/D158434
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  48 +++++
 .../mlir/Dialect/NVGPU/IR/NVGPUDialect.h      |   2 +
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 169 +++++++++++++++++-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 108 ++++++++++-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  66 ++++++-
 mlir/test/Dialect/NVGPU/invalid.mlir          |  61 +++++++
 6 files changed, 446 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196eed1..f891aae136eba25 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,15 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> {
+  let parameters = (ins "Type":$tensor);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    This type represents the fragmented result matrix produced by `nvgpu.wargroup.mma`.
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +673,44 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
+  let description = [{
+    The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that lowers to
+    `nvvm.wgmma.mma_async`.
+
+    The operands `descriptorA` and `descriptorB` are wgmma matrix descriptors
+    that describe the properties of the matrices in shared memory. The results
+    represent each thread's ownership of a fragment of the warpgroup-level mma
+    result. The shape is deduced from the descriptor types and the output vector.
+
+    The op maps to multiple `nvvm.wgmma.mma_async` operations to complete the
+    given shape. As `nvvm.wgmma.mma_async` is asynchronous, this op groups the
+    generated `nvvm.wgmma.mma_async` ops and surrounds them with
+    `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and
+    `wgmma.wait.group.sync.aligned` ops.
+
+    Example:
+    ```mlir
+      %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc: 
+                    !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
+                    !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+                    vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = ...>
+    ```
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, 
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,                        
+                       AnyVector:$matrixC,
+                       DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+                       OptionalAttr<UnitAttr>:$transposeA,
+                       OptionalAttr<UnitAttr>:$transposeB);
+  let results = (outs Variadic<NVGPU_WarpgroupResult>:$matrixD);
+  let assemblyFormat = [{    
+    $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )? attr-dict
+    `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
+  }];
+  let hasVerifier = 1;
+}
 
 #endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 192afcb2dba7913..96af26842dafea2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
 
+constexpr int kWarpSize = 32;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..90d138bd206e010 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -34,6 +35,10 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Number of bits that need to be excluded when building the matrix
+/// descriptor for wgmma operations.
+constexpr int exclude4LSB = 4;
+
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
 static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
@@ -984,10 +989,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
                                          shiftLeft(val, startBit));
     };
 
-    int ex4LSB = 4;
     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
-    uint64_t strideDimVal = (layout << 3) >> ex4LSB;
-    uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
+    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
+    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
     uint64_t offsetVal = 0;
 
     Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1145,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(8)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    TypeRange resultTypes = {resultStructType};
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // TODO: Support input element types other than f16.
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
+        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    return res;
+  }
+
+  static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+                                    int sizeN) {
+    int outputElements = 0;
+    if (outElemType.isF32() || outElemType.isInteger(32))
+      outputElements = sizeN / 2;
+    if (outElemType.isF16())
+      outputElements = sizeN / 4;
+    SmallVector<Type> structBody;
+    for (int i = 0; i < outputElements; i++)
+      structBody.push_back(outElemType);
+    return LLVM::LLVMStructType::getLiteral(ctx, structBody);
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> wgmmaResults;
+
+    int64_t sizeM = op.getMatrixC().getType().getDimSize(0);
+    int64_t sizeN = op.getMatrixC().getType().getDimSize(1);
+    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+
+    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+                      << sizeN << "] ---===\n");
+
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
+                             wgmmaShapeN, wgmmaShapeK))) {
+      return failure();
+    }
+
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+
+    //  Generate wgmma group
+
+    auto loc = op->getLoc();
+    Type outElemType = op.getMatrixC().getType().getElementType();
+    Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN);
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    auto iterateDescA = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int tileShapeA = typeTensorA.getDimSize(1);
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
+                        << iterK << "] [wgmma descriptors] Descriptor A + "
+                        << incrementVal << " | \t ");
+      return incrementVal
+                 ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
+                 : desc;
+    };
+
+    auto iterateDescB = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle row major
+      int byte = typeTensorB.getElementTypeBitWidth() / 8;
+      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+      return incrementVal
+                 ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
+                 : desc;
+    };
+
+    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
+      Value undefOp = rewriter.create<LLVM::UndefOp>(loc, stype);
+      Value inout = undefOp;
+      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
+                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
+                        << ":" << wgmmaShapeN << "] += \n");
+      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
+        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
+        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
+        LLVM_DEBUG(DBGS() << "\t wgmma."
+                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
+                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
+                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
+                          << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
+                          << " B[" << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
+                          << ":" << wgmmaShapeN << "])\n");
+        inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
+                                    wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+                                    stype, inout, descA, descB);
+      }
+      wgmmaResults.push_back(inout);
+    }
+
+    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+    ValueRange myres(wgmmaResults);
+    rewriter.replaceOp(op, myres);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1156,6 +1318,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,           // nvgpu.wargroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d832a983a132d61..cd0d65ddd9a65c0 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -151,7 +151,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
   //  - For F32 (TF32), F16, S8, and S4 data
   //    types the fundamental tensor core operation is of shape 8-by-8-by-128b.
   //  - F64 is an exception and is of shape 8-by-8-by-256b.
-  constexpr int kThreads = 32; // 32 threads per warp
   int64_t shapeM = 8;
   int64_t shapeN = 8;
   int64_t shapeK; // set based on data type (128b for all data types except F64)
@@ -206,17 +205,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
 
   // verify warp-wide size for vector a
   int64_t sparseFactor = sparse ? 2 : 1;
-  if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
     return op->emitOpError()
            << "expected " << m * k << " warp-wide matrix A elements";
 
   // verify warp-wide size for vector b
-  if (bShape[0] * bShape[1] * kThreads != k * n)
+  if (bShape[0] * bShape[1] * kWarpSize != k * n)
     return op->emitOpError()
            << "expected " << k * n << " warp-wide matrix B elements";
 
   // verify warp-wide size for vector c
-  if (cShape[0] * cShape[1] * kThreads != m * n)
+  if (cShape[0] * cShape[1] * kWarpSize != m * n)
     return op->emitOpError()
            << "expected " << m * n << " warp-wide matrix C elements";
 
@@ -402,6 +401,107 @@ LogicalResult GenerateGmmaDescriptorOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
+  // F32 += F16 + F16
+  // F16 += F16 + F16
+  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F32 += TF32 + TF32
+  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
+    return success();
+  // s32 += i8 + i8
+  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
+    return success();
+  // s32 += i1 + i1
+  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
+    return success();
+  // F32 += BF16 + BF16
+  // F16 += BF16 + BF16
+  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F16 += f8 + f8
+  // F32 += f8 + f8
+  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
+      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+      (typeD.isF32() || typeD.isF16()))
+    return success();
+
+  return failure();
+}
+
+LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
+  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
+                               72,  80,  88,  96,  104, 112, 120, 128,
+                               136, 144, 152, 160, 168, 176, 184, 192,
+                               200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
+                                    80,  96,  112, 128, 144, 160,
+                                    176, 192, 208, 224, 240, 256};
+  if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+    if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; }))
+      return success();
+
+  if (typeA.isInteger(8) || typeA.isInteger(1))
+    if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; }))
+      return success();
+  return failure();
+}
+
+LogicalResult WarpgroupMmaOp::verify() {
+  if (getTransposeA() && !getTransposeB())
+    return emitOpError() << "supports non-transpose A (Row Major) "
+                            "and transpose B (Column Major) for the time being";
+  auto matrixA = getDescriptorA().getType().getTensor();
+  auto matrixB = getDescriptorB().getType().getTensor();
+  auto matrixC = getMatrixC().getType();
+  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
+      matrixC.getRank() != 2)
+    return emitOpError()
+           << "has input matrices A, B and D, they must be 2 dimensional";
+
+  if (matrixA.getShape()[1] != matrixB.getShape()[0])
+    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
+                         << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
+                         << " )";
+  if (matrixA.getShape()[0] != matrixC.getShape()[0])
+    return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
+                         << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
+                         << " )";
+  if (matrixB.getShape()[1] != matrixC.getShape()[1])
+    return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
+                         << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
+                         << " )";
+
+  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
+                                    matrixA.getElementType(),
+                                    matrixB.getElementType())))
+    return emitOpError() << matrixC.getElementType()
+                         << " += " << matrixA.getElementType() << " * "
+                         << matrixB.getElementType()
+                         << ", it is not supported.";
+  // Check N
+  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
+    return emitOpError() << "has input type " << matrixB << " n is set to "
+                         << matrixB.getDimSize(1) << ", it is not supported";
+  }
+
+  // Currently, f16/bf16 supported
+  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
+      !matrixA.getElementType().isBF16()) {
+    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
+                         << " += " << matrixA.getElementType() << " * "
+                         << matrixB.getElementType()
+                         << ", it is not supported yet";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0d7ace52ccb36c9..cafeb785e31ffe2 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -672,6 +672,70 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
   func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
 }
 
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @warpgroup_mma_128_128_64(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_mma_128_128_64(
+      %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
+      %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+      %D: memref<128x128xf32,3>) 
+{
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%3, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
+// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
+// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]]  : i64
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
+// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]]  : i64
+// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
+// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]]  : i64
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
+// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]]  : i64
+// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
+// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]]  : i64
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
+// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]]  : i64
+// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S20]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64
+// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]]  : i64
+// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64
+// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]]  : i64
+// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], <m = 64, n = 128, k = 16>, D[%[[S23]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64
+// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]]  : i64
+// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64
+// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]]  : i64
+// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], <m = 64, n = 128, k = 16>, D[%[[S28]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64
+// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]]  : i64
+// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
+// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]]  : i64
+// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: nvvm.wgmma.commit.group.sync.aligned
+// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32>
+  %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}: 
+      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
+      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+      vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+  
+  return
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 
@@ -681,5 +745,5 @@ transform.sequence failures(propagate) {
   } with type_converter {
     transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
       {use_opaque_pointers = true}
-  } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
+  } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index ef721b18014071a..d7af22085c10b9e 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -221,3 +221,64 @@ func.func @async_cp_size_invalid_f64(
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 3: memref<128x128xf64> to memref<3x16x128xf64, 3>
   return
 }
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+  // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}  
+  %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+  return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) {
+  // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}  
+  %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
+  return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+  // expected-error @+1 {{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}  
+  %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+  return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
+  // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}  
+  %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+  return
+}

>From 3cd57425d5d46427f267f6d910dc36b832e9eb8e Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 6 Sep 2023 08:34:15 +0200
Subject: [PATCH 2/2] [MLIR][NVGPU] Adding `nvgpu.wargroup.mma` Op for Hopper
 GPUs

This work introduces a new operation called `wargroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Achieving the desired computation shape requires issuing multiple instances of `nvvm.wgmma.mma_async`. The new `nvgpu.wargroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The `nvgpu.wargroup.mma` op does the following:
1) Maps to multiple `wgmma` instructions.
2) Iterates over the input matrix descriptors to achieve the desired computation shape.
3) Groups the `wgmma` instructions, runs them asynchronously, and eventually waits for them. This is done with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Returns the results as fragmented matrices.

Here's an example usage of the `nvgpu.wargroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      vector<128x128xf32>
      -> !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
  		 !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
```

Differential Revision: https://reviews.llvm.org/D158434
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 19 ++++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 86 ++++++++++++++++++-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 31 +++++++
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 85 +++++++++++++++++-
 4 files changed, 218 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index f891aae136eba25..443abe6e40db82f 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -713,4 +713,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"wargroup.mma.store"> {
+  let description = [{
+    The `nvgpu.wargroup.mma.store` op stores the fragmented results given in
+    $matrixD to the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+
+    Note that the op must be run by a full warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+  
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 90d138bd206e010..90f999d36e6f07f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -408,8 +409,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -424,6 +425,9 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
+    converter.addConversion([&](nvgpu::WarpgroupResultType type) -> Type {
+      return converter.convertType(type.getTensor());
+    });
     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
@@ -441,6 +445,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -1303,11 +1308,88 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.wargroup.mma.store
       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index cd0d65ddd9a65c0..baee82a941ee245 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -22,6 +23,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -502,6 +505,34 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype =
+      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupResultType>()
+                           .getTensor()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but must be an LLVM struct type";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // TODO: Relax this limitation.
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supports only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index cafeb785e31ffe2..68a944b4c99a996 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -732,10 +732,93 @@ func.func @warpgroup_mma_128_128_64(
       !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
       !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
       vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
-  
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.result<tensor = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>
+func.func @warpgroup_mma_store(%result1 : !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, %matrixD: memref<128x128xf32,3>) {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : 
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+
+// ### Store {d0, d1} of each thread ###
+
+// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]]  : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]]  : i32
+// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]]  : i32
+// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]]  : i32
+// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]]  : i32
+// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]]  : i32
+// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]]  : i32
+// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]]  : i32
+// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]]  : i32
+// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]]  : i32
+// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]]  : i32
+// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]]  : i32
+// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: memref.store %[[S26]], %[[arg1]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S27]], %[[arg1]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
+// ### Store {d2, d3} of each thread ###
+
+// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]]  : i32
+// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]]  : i32
+// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]]  : i32
+// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: memref.store %[[S35]], %[[arg1]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S36]], %[[arg1]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+// ### Store {d4, d5} of each thread ### 
+
+// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]]  : i32
+// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]]  : i32
+// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]]  : i32
+// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: memref.store %[[S44]], %[[arg1]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S45]], %[[arg1]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+// ### Store {d6, d7} of each thread ### 
+
+// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]]  : i32
+// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]]  : i32
+// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]]  : i32
+// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: memref.store %[[S53]], %[[arg1]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S54]], %[[arg1]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+// The same pattern repeats for the remaining element pairs, up to {d62, d63}.
+
+  nvgpu.wargroup.mma.store [%result1], %matrixD : !nvgpu.warpgroup.result<tensor = !accMatrixStruct> to memref<128x128xf32,3>
+  return 
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 


