[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)

Guray Ozen llvmlistbot at llvm.org
Mon Oct 2 02:09:41 PDT 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/65441

>From 7b71da55fca8fe2a7dbe4982b1959be6a6175fa1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 7 Sep 2023 11:52:38 +0200
Subject: [PATCH 1/6] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op
 for Hopper GPUs

This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to the given memref.

An example of fragmentation is given here :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates the indices per thread in the warp group and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented results to the given memref
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>
```

Depends on #65440
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 19 +++++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 83 ++++++++++++++++++-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 29 +++++++
 3 files changed, 129 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..e102ae0dc581013 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
+    in $matrixD to give memref. 
+
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+
+    Note that, the op must be run with warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+  
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..4f1a0bc651e81b7 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,11 +1301,88 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..fc85df1654198d5 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,34 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype =
+      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupResultType>()
+                           .getTensor()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << "  but must keep type of llvm struct";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // todo improve this limitation
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supporst only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal  ";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//

>From 4a1824f3e6ae955b78f7262178fa1b8e4608e3da Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Sep 2023 16:53:21 +0200
Subject: [PATCH 2/6] use new type `WarpgroupAccumulator`

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td     |  5 +++--
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp |  2 +-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp      | 11 +++++++----
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index e102ae0dc581013..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -726,12 +726,13 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
     in $matrixD to give memref. 
 
-    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
 
     Note that, the op must be run with warp group.
   }];
 
-  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
                        Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
   
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4f1a0bc651e81b7..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1382,7 +1382,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
-      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
@@ -1394,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index fc85df1654198d5..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -531,13 +531,16 @@ LogicalResult WarpgroupMmaOp::verify() {
 }
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
-  Type stype =
-      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+  Type stype = getMatrixD()
+                   .front()
+                   .getType()
+                   .cast<WarpgroupAccumulatorType>()
+                   .getFragmented();
 
   for (auto result : getMatrixD()) {
     auto resultStype = result.getType()
-                           .cast<WarpgroupResultType>()
-                           .getTensor()
+                           .cast<WarpgroupAccumulatorType>()
+                           .getFragmented()
                            .dyn_cast<LLVM::LLVMStructType>();
     if (!resultStype)
       return emitOpError() << "result is " << result.getType()

>From e60310d10c8e43669402e432cd130383cdf7a837 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 27 Sep 2023 09:41:45 +0200
Subject: [PATCH 3/6] add test

---
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 127 ++++++++++++++++++
 1 file changed, 127 insertions(+)

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index f011007e040ce9c..93123cecbc38f94 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -732,6 +732,133 @@ func.func @warpgroup_mma_128_128_64(
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_mma_store(
+    %result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+    %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, 
+    %matrixD: memref<128x128xf32,3>) {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[DB:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : 
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+
+// ### Store {d0, d1} of each thread ###
+
+// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]]  : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]]  : i32
+// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]]  : i32
+// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]]  : i32
+// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]]  : i32
+// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]]  : i32
+// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]]  : i32
+// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]]  : i32
+// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]]  : i32
+// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]]  : i32
+// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]]  : i32
+// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]]  : i32
+// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
+// ### Store {d2, d3} of each thread ###
+
+// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]]  : i32
+// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]]  : i32
+// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]]  : i32
+// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+// ### Store {d4, d5} of each thread ### 
+
+// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]]  : i32
+// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]]  : i32
+// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]]  : i32
+// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+// ### Store {d6, d7} of each thread ### 
+
+// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]]  : i32
+// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]]  : i32
+// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]]  : i32
+// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+// Pattern continues similarly 28x times until {... d62, d63}
+
+// ### Store {d64, d65} of each thread ### 
+
+// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]]  : i32
+// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]]  : i32
+// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
+// CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
+// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]]  : i32
+// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]]  : i32
+// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]]  : i32
+// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
+// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]]  : i32
+// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]]  : i32
+// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]]  : i32
+// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]]  : i32
+// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]]  : i32
+// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
+// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
+// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]]  : i32
+// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
+// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0] 
+// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]  
+// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
+
+// Pattern continues similarly 31x times until {... d126, d127}
+
+  nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+    to memref<128x128xf32,3>
+  return 
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 

>From fe03a52c573c287efba3e9c77837a5d91a1e3ad1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 27 Sep 2023 09:41:53 +0200
Subject: [PATCH 4/6] better verification

---
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 48 +++++++++++-----------
 1 file changed, 25 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 1486bba5d3e57f6..b9994aced0be7f4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -531,33 +531,35 @@ LogicalResult WarpgroupMmaOp::verify() {
 }
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
-  Type stype = getMatrixD()
-                   .front()
-                   .getType()
-                   .cast<WarpgroupAccumulatorType>()
-                   .getFragmented();
-
+  MemRefType dstMemrefType = getDstMemref().getType();
+  VectorType firstVtype = getMatrixD()
+                              .front()
+                              .getType()
+                              .cast<WarpgroupAccumulatorType>()
+                              .getFragmented();
+
+  int64_t totalFirstDimension = 0;
   for (auto result : getMatrixD()) {
-    auto resultStype = result.getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented()
-                           .dyn_cast<LLVM::LLVMStructType>();
-    if (!resultStype)
-      return emitOpError() << "result is " << result.getType()
-                           << "  but must keep type of llvm struct";
-    if (stype != resultStype)
-      return emitOpError() << "all results must be the same type";
-
-    // todo improve this limitation
-    if (!resultStype.getBody().front().isF32()) {
-      return emitOpError() << "supporst only f32 results for the time being";
+    VectorType vtype =
+        result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
+    if (vtype != firstVtype)
+      return emitOpError() << "all fragmented types must be the same";
+    // Limitation
+    if (!vtype.getElementType().isF32()) {
+      return emitOpError()
+             << "hit a limitation: only f32 results for the time being";
     }
+    totalFirstDimension += vtype.getDimSize(0);
   }
-
-  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
-    return emitOpError() << "all element types must be equal  ";
+  if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
+      firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+    return emitOpError() << "results [" << totalFirstDimension << "]["
+                         << firstVtype.getDimSize(1)
+                         << "] values. However, destination memref["
+                         << dstMemrefType.getDimSize(0) << "]["
+                         << dstMemrefType.getDimSize(1)
+                         << "]  does not have same size as results";
   }
-
   return success();
 }
 

>From 2ba9de5e479a6a680b4cacf2b8a90d8da87115c1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 27 Sep 2023 09:52:01 +0200
Subject: [PATCH 5/6] fix test

---
 mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 93123cecbc38f94..cd8222c1c0ce585 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -739,8 +739,7 @@ func.func @warpgroup_mma_store(
     %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, 
     %matrixD: memref<128x128xf32,3>) {
 // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[DB:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : 
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
 // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32

>From c894a472ba7f703a02e2da19fcd09a3e71e15517 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 2 Oct 2023 11:09:23 +0200
Subject: [PATCH 6/6] address @qcolombet comments

---
 .../lib/Conversion/NVGPUToNVVM/CMakeLists.txt |   1 +
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 105 +++++++++++++-----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  28 ++---
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 4 files changed, 92 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
index 00e775ce7dd22bc..a050749eb7da87e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
   MLIRLLVMDialect
   MLIRNVGPUDialect
   MLIRNVVMDialect
+  MLIRArithDialect
   MLIRPass
   MLIRSCFTransforms
   MLIRTransforms
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 006ecbef2546e3e..c02026977922e81 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -52,6 +53,16 @@ static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
 }
 
+/// Returns the warp size (`kWarpSize`) as a freshly built i32 LLVM constant.
+/// The constant is created at `b`'s insertion point on every call: an SSA
+/// value belongs to the IR it was created in, so caching it in a
+/// function-local static would hand later callers a value from another
+/// function (or one whose defining op was already erased).
+static Value getWarpSizeValue(ImplicitLocOpBuilder &b) {
+  return b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32),
+                                    b.getI32IntegerAttr(kWarpSize));
+}
+
 /// Returns the type for the intrinsic given the vectorResultType of the
 /// `gpu.mma.sync` operation.
 static Type inferIntrinsicResultType(Type vectorResultType) {
@@ -1306,47 +1317,80 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   using ConvertOpToLLVMPattern<
       nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
 
-  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
-                             OpAdaptor adaptor,
-                             ConversionPatternRewriter &rewriter,
+  /// This function stores a fragmented register matrix owned by a warp group
+  /// (128 threads) into a memref. Each thread has 64 registers, each the size
+  /// of a struct.
+  /// Here is what each thread (T) holds; each `d` is a struct value with a
+  /// number.
+  ///
+  /// Threads in the warp-group (128 threads) and what they own in matrixD:
+  /// 0-31    Warp-0  -> MatrixD[0:15 ][0:N]
+  /// 32-63   Warp-1  -> MatrixD[16:31][0:N]
+  /// 64-95   Warp-2  -> MatrixD[32:47][0:N]
+  /// 96-127  Warp-3  -> MatrixD[48:63][0:N]
+  ///
+  /// Matrix-D:
+  ///   +______________________________________________________________________+
+  ///   |     0-1  |    2-3  |    4-5  |    6-7  |   8-9  |   10-11|..|N-8,N-7 |
+  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
+  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
+  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
+  /// 16| T32:d0-d1|T33:d0-d1|T34:d0-d1|T35:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 32| T64:d0-d1|T65:d0-d1|T66:d0-d1|T67:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 48| T96:d0-d1|T97:d0-d1|T98:d0-d1|T99:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  ///   +______________________________________________________________________+
+  ///
+  /// \param b: The implicit-location builder used to create the operations.
+  /// \param matrixD: Result of the warp-group MMA operation (fragmented
+  /// matrix). It is held by a thread as a struct with 64 elements.
+  /// \param dstMemref: The memref where the registers will be stored.
+  /// \param offset: the offset within the memref where the registers will be
+  /// stored.
+  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
+                             TypedValue<MemRefType> dstMemref,
                              int offset) const {
-    Location loc = op->getLoc();
-    Type i32 = rewriter.getI32Type();
+    Type i32 = b.getI32Type();
 
     auto makeConst = [&](int32_t index) -> Value {
-      return rewriter.create<LLVM::ConstantOp>(
-          loc, i32, rewriter.getI32IntegerAttr(index));
+      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
     };
+    Value c1 = makeConst(1);
+    Value c2 = makeConst(2);
     Value c4 = makeConst(4);
-    Value c32 = makeConst(kWarpSize);
     Value c8 = makeConst(8);
-    Value c2 = makeConst(2);
-    Value c1 = makeConst(1);
     Value c16 = makeConst(16);
+    Value warpSize = getWarpSizeValue(b);
 
     auto makeMul = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
     };
     auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
-    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
-    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
-    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
-    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
+    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
+    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
+    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
 
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
-      Type it = rewriter.getIndexType();
-      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
-      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
-      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
-      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
-      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
-      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
-      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+      Type it = b.getIndexType();
+      Value idx = b.create<arith::IndexCastOp>(it, x);
+      Value idy0 = b.create<arith::IndexCastOp>(it, y);
+      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
+      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
+      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
+      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
+      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
     };
 
     Value tj = makeMul(lane4modId, c2);
@@ -1358,7 +1402,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
       for (int j = 0; j < 16; ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
         int sIndex = i * 2 + j * 4;
-        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+        makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
       }
     }
   }
@@ -1367,10 +1411,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int offset = 0;
-    for (auto result : adaptor.getMatrixD()) {
-      auto stype = result.getType().cast<LLVM::LLVMStructType>();
-      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
-      offset += stype.getBody().size();
+    ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
+    for (Value matrixD : adaptor.getMatrixD()) {
+      auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+      offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index cd8222c1c0ce585..7e9b2f3ed01c862 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -740,18 +740,18 @@ func.func @warpgroup_mma_store(
     %matrixD: memref<128x128xf32,3>) {
 // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
-// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
 // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
 
 // ### Store {d0, d1} of each thread ###
 
 // CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
-// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]]  : i32
-// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]]  : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]]  : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]]  : i32
 // CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]]  : i32
 // CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]]  : i32
 // CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]]  : i32
@@ -816,20 +816,22 @@ func.func @warpgroup_mma_store(
 
 // Pattern continues similarly 28x times until {... d62, d63}
 
+// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+
 // ### Store {d64, d65} of each thread ### 
 
+// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
-// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
 // CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
-// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
-// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
-// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]]  : i32
-// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]]  : i32
-// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
-// CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
-// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]]  : i32
+// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WarpSize]]  : i32
+// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WarpSize]]  : i32
+// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]]  : i32
+// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]]  : i32
+// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]]  : i32
 // CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]]  : i32
 // CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]]  : i32
 // CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3c30afabfbb204b..8e9d22016526ceb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5237,6 +5237,7 @@ cc_library(
         ":LLVMCommonConversion",
         ":LLVMDialect",
         ":MemRefDialect",
+        ":MLIRArithDialect",
         ":NVGPUDialect",
         ":NVVMDialect",
         ":Pass",



More information about the Mlir-commits mailing list