[Mlir-commits] [mlir] [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (PR #68728)

Guray Ozen llvmlistbot at llvm.org
Mon Oct 16 05:58:50 PDT 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/68728

>From acce6abce8e5f0da1df64b0978b64dce64ba4d6d Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 10 Oct 2023 02:33:20 +0200
Subject: [PATCH 1/2] [mlir][nvgpu] Simplify `NVGPU_WarpgroupAccumulator` type

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  10 +-
 .../mlir/Dialect/NVGPU/IR/NVGPUDialect.h      |   3 +
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 112 +++++++++++-------
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    |  99 ++++++----------
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  65 +++++-----
 5 files changed, 145 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 79183acfb71b61e..fd16376be366912 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
                        DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
                        OptionalAttr<UnitAttr>:$transposeA,
                        OptionalAttr<UnitAttr>:$transposeB,
-                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
-  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+                       NVGPU_WarpgroupAccumulator:$matrixC);
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
   let assemblyFormat = [{    
     $descriptorA`,` $descriptorB`,` $matrixC attr-dict
     `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,11 +739,11 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     Note that, the op must be run with warp group.
   }];
 
-  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+  let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
                        Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
   
   let assemblyFormat = [{
-    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
@@ -755,7 +755,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
     This Op generates and initializes the accumulator matrix for 
     `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
   }];
-  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
   let assemblyFormat = "attr-dict `->` type($matrixC)";
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 96af26842dafea2..e6bba7e6082964b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// Size of the M dimension of a single wgmma.mma_async instruction.
+constexpr int kWgmmaSizeM = 64;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 84f53a4572294ad..2d43230938526b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
     converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
-      VectorType vtype = type.getFragmented();
+      Type elemType = type.getFragmented().getElementType();
+      int64_t sizeM = type.getFragmented().getDimSize(0);
+      int64_t sizeN = type.getFragmented().getDimSize(1);
+
+      unsigned numMembers;
+      if (elemType.isF32() || elemType.isInteger(32))
+        numMembers = sizeN / 2;
+      else if (elemType.isF16())
+        numMembers = sizeN / 4;
+      else
+        llvm_unreachable("unsupported type for warpgroup accumulator");
+
+      SmallVector<Type> innerStructBody;
+      for (unsigned i = 0; i < numMembers; i++)
+        innerStructBody.push_back(elemType);
+      auto innerStructType =
+          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
       SmallVector<Type> structBody;
-      for (unsigned i = 0; i < vtype.getDimSize(0); i++)
-        structBody.push_back(vtype.getElementType());
+      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+        structBody.push_back(innerStructType);
+
       auto convertedType =
           LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
       return converter.convertType(convertedType);
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
     nvgpu::WarpgroupMmaOp op;
     ImplicitLocOpBuilder b;
     OpAdaptor adaptor;
-    const LLVMTypeConverter &typeConverter;
 
     // Entire shape of the given Op
     int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
 
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
-    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+    Value generateWgmma(int i, int j, int k, Value matrixC) {
       LLVM_DEBUG(DBGS() << "\t wgmma."
                         << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                         << "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
       auto overflow = NVVM::MMAIntOverflowAttr::get(
           op->getContext(), NVVM::MMAIntOverflow::wrapped);
 
-      Type resultStructType = typeConverter.convertType(matrixD.getType());
-
       return b.create<NVVM::WgmmaMmaAsyncOp>(
-          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
           itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
-    SmallVector<Value> generateWgmmaGroup() {
-      SmallVector<Value> wgmmaResults;
+    Value generateWgmmaGroup() {
+      Value wgmmaResult =
+          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
 
       // Perform GEMM
+      SmallVector<Value> wgmmaResults;
       for (int i = 0; i < iterationM; ++i) {
-        Value matrixC = adaptor.getMatrixC()[i];
-        Value matrixD = op.getMatrixD()[i];
+        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
         for (int j = 0; j < iterationN; ++j)
           for (int k = 0; k < iterationK; ++k)
-            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+            matrixC = generateWgmma(i, j, k, matrixC);
         wgmmaResults.push_back(matrixC);
       }
-
-      return wgmmaResults;
+      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+                                                    wgmmaResult, matrix, idx);
+      }
+      return wgmmaResult;
     }
 
   public:
     WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
-                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
-        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+                  OpAdaptor adaptor)
+        : op(op), b(b), adaptor(adaptor) {
       // Find the entire GEMM Shape
       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
     /// instructions and group synchronization, as well as waiting
     /// (WgmmaGroupSyncAlignedOp) for group synchronization
     /// (WgmmaWaitGroupSyncOp) after the instructions.
-    SmallVector<Value> generateWarpgroupMma() {
+    Value generateWarpgroupMma() {
       b.create<NVVM::WgmmaFenceAlignedOp>();
-      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      Value wgmmaResult = generateWgmmaGroup();
       b.create<NVVM::WgmmaGroupSyncAlignedOp>();
       b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
-      return wgmmaResults;
+      return wgmmaResult;
     }
   };
-
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
     // Step 1. Build a helper class
-    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+    WarpgroupGemm warpgroupGemm(op, b, adaptor);
 
     // Step 2. Get the entire GEMM Shape
-    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
 
     // Step 3. Replace fragmented result struct with the op results
-    rewriter.replaceOp(op, wgmmaResults);
+    rewriter.replaceOp(op, wgmmaResult);
     return success();
   }
 };
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int offset = 0;
-    ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
-    for (Value matrixD : adaptor.getMatrixD()) {
-      auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
-      storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    Value matriDValue = adaptor.getMatrixD();
+    auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+      auto structType = matrixD.cast<LLVM::LLVMStructType>();
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
@@ -1554,23 +1576,27 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    SmallVector<Value> results;
-    for (OpResult m : op.getMatrixC()) {
-      nvgpu::WarpgroupAccumulatorType mType =
-          m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
-      Type stype = getTypeConverter()->convertType(mType);
-      Value undefStruct = b.create<LLVM::UndefOp>(stype);
-      Type elemType = mType.getFragmented().getElementType();
-      int64_t elemSize = mType.getFragmented().getDimSize(0);
-      Value zero =
-          b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
-      for (int64_t i = 0; i < elemSize; ++i) {
-        undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
-                                                    ArrayRef<int64_t>({i}));
+    LLVM::LLVMStructType structType =
+        getTypeConverter()
+            ->convertType(op.getMatrixC().getType())
+            .cast<LLVM::LLVMStructType>();
+    Type elemType = structType.getBody()
+                        .front()
+                        .cast<LLVM::LLVMStructType>()
+                        .getBody()
+                        .front();
+    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+    Value structValue = b.create<LLVM::UndefOp>(structType);
+    for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+      auto innerStructType = s.cast<LLVM::LLVMStructType>();
+      int ii = idx;
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+      for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+        innerStructValue = b.create<LLVM::InsertValueOp>(
+            innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
       }
-      results.push_back(undefStruct);
     }
-    rewriter.replaceOp(op, results);
+    rewriter.replaceOp(op, structValue);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index fe71eae899cd63d..f5b02fe1b515591 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -435,7 +435,11 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
   return failure();
 }
 
-LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
+LogicalResult isAllowedSizeM(int sizeM) {
+  if (sizeM % kWgmmaSizeM)
+    return failure();
+  return success();
+}
 
 LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
@@ -458,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
 
 LogicalResult WarpgroupMmaOp::verify() {
   if (getTransposeA() && !getTransposeB())
-    return emitOpError() << "supports non-transpose A (Row Major) "
-                            "and transpose B (Column Major) for the time being";
+    return emitOpError()
+           << "supports non-transpose A (Row Major) "
+              "and transpose B (Column Major) for the time being ";
   MemRefType matrixA = getDescriptorA().getType().getTensor();
   MemRefType matrixB = getDescriptorB().getType().getTensor();
-  VectorType matrixC = getMatrixC()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  VectorType matrixD = getMatrixD()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  unsigned sizeAcc = getMatrixC().size();
-
-  if (getMatrixC().size() != getMatrixD().size())
-    return emitOpError() << "number of matrix C and matrix D must be the same";
-
-  if (llvm::all_of(getMatrixC(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix C must be the same";
-  }
-  if (llvm::all_of(getMatrixD(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix D must be the same as matrix C";
-  }
+  VectorType matrixC = getMatrixC().getType().getFragmented();
+  VectorType matrixD = getMatrixD().getType().getFragmented();
+
+  if (matrixC != matrixD)
+    return emitOpError() << "type of matrix C and matrix D must be the same";
 
   if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
       matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -498,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
     return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                          << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
                          << " )";
-  if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+  if (matrixA.getShape()[0] != matrixC.getShape()[0])
     return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                          << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
                          << " )";
@@ -534,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
   MemRefType dstMemrefType = getDstMemref().getType();
-  VectorType firstVtype = getMatrixD()
-                              .front()
-                              .getType()
-                              .cast<WarpgroupAccumulatorType>()
-                              .getFragmented();
-
-  int64_t totalFirstDimension = 0;
-  for (Value result : getMatrixD()) {
-    VectorType vtype =
-        result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
-    if (vtype != firstVtype)
-      return emitOpError() << "all fragmented types must be the same";
-    // Limitation
-    if (!vtype.getElementType().isF32()) {
-      return emitOpError()
-             << "hit a limitation: only f32 results for the time being";
-    }
-    totalFirstDimension += vtype.getDimSize(0);
+  VectorType vtype = getMatrixD().getType().getFragmented();
+
+  // Limitation: only f32 accumulator elements can be stored for now.
+  if (!vtype.getElementType().isF32()) {
+    return emitOpError()
+           << "hit a limitation: only f32 results for the time being";
   }
-  if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
-      firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
-    return emitOpError() << "results [" << totalFirstDimension << "]["
-                         << firstVtype.getDimSize(1)
+  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                          << "] values. However, destination memref["
                          << dstMemrefType.getDimSize(0) << "]["
                          << dstMemrefType.getDimSize(1)
@@ -570,19 +542,18 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
-  for (OpResult matrix : getMatrixC()) {
-    VectorType vectorType = matrix.getType()
-                                .cast<nvgpu::WarpgroupAccumulatorType>()
-                                .getFragmented();
-    // Check [M][N] shape
-    if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
-        failed(isAllowedSizeN(vectorType.getDimSize(1),
-                              vectorType.getElementType()))) {
-      return emitOpError() << "has type " << vectorType
-                           << ". It does not fit into warp-group "
-                              "level (wgmma) matrix multiplication instruction "
-                              "(or not supported yet)";
-    }
+
+  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+  int64_t sizeM = accType.getFragmented().getDimSize(0);
+  int64_t sizeN = accType.getFragmented().getDimSize(1);
+  Type elemType = accType.getFragmented().getElementType();
+
+  if (failed(isAllowedSizeM(sizeM)) ||
+      failed(isAllowedSizeN(sizeN, elemType))) {
+    return emitOpError() << "has type " << accType.getFragmented()
+                         << ". It does not fit into warp-group "
+                            "level (wgmma) matrix multiplication instruction "
+                            "(or not supported yet)";
   }
   return success();
 }
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index ca030575e5e961e..bf660e2683158e5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -713,18 +713,18 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.
 }
 
 // CHECK-LABEL: @warpgroup_mma_128_128_64(  
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
 func.func @warpgroup_mma_128_128_64(
       %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
       %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
-      %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
-      %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) 
+      %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) 
 {
 // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
 // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[UD:.+]] =  llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
 // CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
@@ -741,6 +741,7 @@ func.func @warpgroup_mma_128_128_64(
 // CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
 // CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]]  : i64
 // CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
 // CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]]  : i64
 // CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
@@ -759,27 +760,26 @@ func.func @warpgroup_mma_128_128_64(
 // CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
 // CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]]  : i64
 // CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S40:.+]] = llvm.insertvalue %[[S19]], %[[UD]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+// CHECK: %[[S41:.+]] = llvm.insertvalue %[[S38]], %[[S40]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: nvvm.wgmma.commit.group.sync.aligned
 // CHECK: nvvm.wgmma.wait.group.sync.aligned 1  
-  %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}: 
+  %wgmmaResult = nvgpu.warpgroup.mma %descA, %descB, %acc {transposeB}: 
       !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
       !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
-      !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
-      !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> 
       -> 
-      !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, 
-      !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>  
+      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>  
   return
 }
 
 // CHECK-LABEL: @warpgroup_mma_store(  
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
 func.func @warpgroup_mma_store(
-    %result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
-    %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, 
+    %result : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, 
     %matrixD: memref<128x128xf32,3>) {
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[EX1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
@@ -807,8 +807,8 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
 // CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]]  : i32
 // CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
-// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
-// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[EX1]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[EX1]][1] : !llvm.struct
 // CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
 // CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
 
@@ -821,8 +821,8 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
 // CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]]  : i32
 // CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
-// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
-// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[EX1]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[EX1]][5] : !llvm.struct<
 // CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
 // CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
 
@@ -835,8 +835,8 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
 // CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]]  : i32
 // CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
-// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
-// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[EX1]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[EX1]][9] : !llvm.struct<
 // CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
 // CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
 
@@ -849,8 +849,8 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
 // CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]]  : i32
 // CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
-// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
-// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[EX1]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[EX1]][13] : !llvm.struct<
 // CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
 // CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
 
@@ -860,7 +860,7 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
 
 // ### Store {d64, d65} of each thread ### 
-
+// CHECK: %[[EX2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
 // CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
@@ -887,24 +887,24 @@ func.func @warpgroup_mma_store(
 // CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
 // CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]]  : i32
 // CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
-// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0] 
-// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]  
+// CHECK: %[[S337:.+]] = llvm.extractvalue %[[EX2]][0] 
+// CHECK: %[[S338:.+]] = llvm.extractvalue %[[EX2]][1]  
 // CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
 // CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
 
 // Pattern continues similarly 31x times until {... d126, d127}
 
-  nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD : 
-    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
-    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+  nvgpu.warpgroup.mma.store %result, %matrixD : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<128x128xf32>> 
     to memref<128x128xf32,3>
   return 
 }
 
 func.func @warpgroup_mma_init() {
-  //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
-  //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[S0]][0] : !llvm.struct
+  //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+  //CHECK: %[[EX:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+  //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[EX]][0] : !llvm.struct
   //CHECK: %[[S3:.+]] = llvm.insertvalue %[[S1]], %[[S2]][1] : !llvm.struct
   //CHECK: %[[S4:.+]] = llvm.insertvalue %[[S1]], %[[S3]][2] : !llvm.struct
   //CHECK: %[[S5:.+]] = llvm.insertvalue %[[S1]], %[[S4]][3] : !llvm.struct
@@ -968,10 +968,11 @@ func.func @warpgroup_mma_init() {
   //CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
   //CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
   //CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
-  %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+  %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<128x128xf32>>
   return 
 }
 
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 

>From 04582c019cec9c7c8fa9122b35ace2e301e6348c Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 16 Oct 2023 14:58:16 +0200
Subject: [PATCH 2/2] fix transform dialect

---
 .../NVGPU/TransformOps/NVGPUTransformOps.cpp  | 24 ++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 94d7d565ff1a905..eaaadbbea4d0a75 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -62,10 +62,28 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
   });
   llvmTypeConverter.addConversion(
       [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
-        VectorType vtype = type.getFragmented();
+        Type elemType = type.getFragmented().getElementType();
+        int64_t sizeM = type.getFragmented().getDimSize(0);
+        int64_t sizeN = type.getFragmented().getDimSize(1);
+
+        unsigned numMembers;
+        if (elemType.isF32() || elemType.isInteger(32))
+          numMembers = sizeN / 2;
+        else if (elemType.isF16())
+          numMembers = sizeN / 4;
+        else
+          llvm_unreachable("unsupported type for warpgroup accumulator");
+
+        SmallVector<Type> innerStructBody;
+        for (unsigned i = 0; i < numMembers; i++)
+          innerStructBody.push_back(elemType);
+        auto innerStructType = LLVM::LLVMStructType::getLiteral(
+            type.getContext(), innerStructBody);
+
         SmallVector<Type> structBody;
-        for (unsigned i = 0; i < vtype.getDimSize(0); i++)
-          structBody.push_back(vtype.getElementType());
+        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+          structBody.push_back(innerStructType);
+
         auto convertedType =
             LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
         return llvmTypeConverter.convertType(convertedType);



More information about the Mlir-commits mailing list