[Mlir-commits] [mlir] [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (PR #68728)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 10 10:56:58 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-nvgpu

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that holds the accumulator matrix used by warp-group level matrix multiplication. A dedicated type is handy here because the matrix is distributed among the threads of the warp-group. However, the current transformations require creating and using multiple `WarpgroupAccumulator` values whenever the GEMM shape is larger than the shape a single `wgmma.mma_async` instruction supports; for example, a 128x128 GEMM needs two of them, since a `wgmma` tile is at most 64 rows tall (`kWgmmaSizeM`). This makes the IR unnecessarily dense.

This PR improves the handling of the `WarpgroupAccumulator` type in every nvgpu op that uses it, so a single accumulator value now covers the entire GEMM shape.

**Example: Current GEMM in NVGPU-IR**
```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->  
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>  


// Epilogue 
nvgpu.warpgroup.mma.store [%r1, %r2], %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    to memref<128x128xf32,3>
```

**Example: With this PR, the IR simplifies to:**
```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->  
           !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>  

// Epilogue 
nvgpu.warpgroup.mma.store %r1, %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    to memref<128x128xf32,3>
```
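
For context, the type conversion in this patch (in `NVGPUToNVVM.cpp` below) lowers the single accumulator to a nested LLVM struct: one inner struct per 64-row `wgmma` tile (`kWgmmaSizeM`), with `N/2` members for `f32`/`i32` elements and `N/4` for `f16`. A minimal sketch of the converted type, using a deliberately small hypothetical shape so the struct can be written out in full:

```
// Hypothetical accumulator (not from the patch), for illustration only:
//   !nvgpu.warpgroup.accumulator<fragmented = vector<128x8xf32>>
//   M = 128 -> 128/64 = 2 inner structs
//   f32     -> N/2    = 4 members per inner struct
!llvm.struct<(struct<(f32, f32, f32, f32)>, struct<(f32, f32, f32, f32)>)>
```

The `vector<128x128xf32>` accumulator in the example above follows the same pattern: two inner structs of 64 `f32` members each.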

---

Patch is 43.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68728.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+16-4) 
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+3) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+84-28) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+46-51) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+101-29) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 57cd1a3806c2ed6..fd16376be366912 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
                        DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
                        OptionalAttr<UnitAttr>:$transposeA,
                        OptionalAttr<UnitAttr>:$transposeB,
-                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
-  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+                       NVGPU_WarpgroupAccumulator:$matrixC);
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
   let assemblyFormat = [{    
     $descriptorA`,` $descriptorB`,` $matrixC attr-dict
     `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,13 +739,25 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     Note that, the op must be run with warp group.
   }];
 
-  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+  let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
                        Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
   
   let assemblyFormat = [{
-    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {  
+  let summary = "Initializes the accumulator matrix";
+
+  let description = [{
+    This Op generates and initializes the accumulator matrix for 
+    `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
+  }];
+  let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
+  let assemblyFormat = "attr-dict `->` type($matrixC)";
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 96af26842dafea2..e6bba7e6082964b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// M size of wgmma.mma_async instruction
+constexpr int kWgmmaSizeM = 64;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 99c4d4223351352..2d43230938526b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
     converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
-      VectorType vtype = type.getFragmented();
+      Type elemType = type.getFragmented().getElementType();
+      int64_t sizeM = type.getFragmented().getDimSize(0);
+      int64_t sizeN = type.getFragmented().getDimSize(1);
+
+      unsigned numMembers;
+      if (elemType.isF32() || elemType.isInteger(32))
+        numMembers = sizeN / 2;
+      else if (elemType.isF16())
+        numMembers = sizeN / 4;
+      else
+        llvm_unreachable("unsupported type for warpgroup accumulator");
+
+      SmallVector<Type> innerStructBody;
+      for (unsigned i = 0; i < numMembers; i++)
+        innerStructBody.push_back(elemType);
+      auto innerStructType =
+          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
       SmallVector<Type> structBody;
-      for (unsigned i = 0; i < vtype.getDimSize(0); i++)
-        structBody.push_back(vtype.getElementType());
+      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+        structBody.push_back(innerStructType);
+
       auto convertedType =
           LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
       return converter.convertType(convertedType);
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
     nvgpu::WarpgroupMmaOp op;
     ImplicitLocOpBuilder b;
     OpAdaptor adaptor;
-    const LLVMTypeConverter &typeConverter;
 
     // Entire shape of the given Op
     int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
 
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
-    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+    Value generateWgmma(int i, int j, int k, Value matrixC) {
       LLVM_DEBUG(DBGS() << "\t wgmma."
                         << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                         << "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
       auto overflow = NVVM::MMAIntOverflowAttr::get(
           op->getContext(), NVVM::MMAIntOverflow::wrapped);
 
-      Type resultStructType = typeConverter.convertType(matrixD.getType());
-
       return b.create<NVVM::WgmmaMmaAsyncOp>(
-          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
           itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
-    SmallVector<Value> generateWgmmaGroup() {
-      SmallVector<Value> wgmmaResults;
+    Value generateWgmmaGroup() {
+      Value wgmmaResult =
+          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
 
       // Perform GEMM
+      SmallVector<Value> wgmmaResults;
       for (int i = 0; i < iterationM; ++i) {
-        Value matrixC = adaptor.getMatrixC()[i];
-        Value matrixD = op.getMatrixD()[i];
+        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
         for (int j = 0; j < iterationN; ++j)
           for (int k = 0; k < iterationK; ++k)
-            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+            matrixC = generateWgmma(i, j, k, matrixC);
         wgmmaResults.push_back(matrixC);
       }
-
-      return wgmmaResults;
+      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+                                                    wgmmaResult, matrix, idx);
+      }
+      return wgmmaResult;
     }
 
   public:
     WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
-                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
-        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+                  OpAdaptor adaptor)
+        : op(op), b(b), adaptor(adaptor) {
       // Find the entire GEMM Shape
       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
     /// instructions and group synchronization, as well as waiting
     /// (WgmmaGroupSyncAlignedOp) for group synchronization
     /// (WgmmaWaitGroupSyncOp) after the instructions.
-    SmallVector<Value> generateWarpgroupMma() {
+    Value generateWarpgroupMma() {
       b.create<NVVM::WgmmaFenceAlignedOp>();
-      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      Value wgmmaResult = generateWgmmaGroup();
       b.create<NVVM::WgmmaGroupSyncAlignedOp>();
       b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
-      return wgmmaResults;
+      return wgmmaResult;
     }
   };
-
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
     // Step 1. Build a helper class
-    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+    WarpgroupGemm warpgroupGemm(op, b, adaptor);
 
     // Step 2. Get the entire GEMM Shape
-    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
 
     // Step 3. Replace fragmented result struct with the op results
-    rewriter.replaceOp(op, wgmmaResults);
+    rewriter.replaceOp(op, wgmmaResult);
     return success();
   }
 };
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int offset = 0;
-    ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
-    for (Value matrixD : adaptor.getMatrixD()) {
-      auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
-      storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    Value matriDValue = adaptor.getMatrixD();
+    auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+      auto structType = matrixD.cast<LLVM::LLVMStructType>();
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
@@ -1546,6 +1568,39 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    LLVM::LLVMStructType structType =
+        getTypeConverter()
+            ->convertType(op.getMatrixC().getType())
+            .cast<LLVM::LLVMStructType>();
+    Type elemType = structType.getBody()
+                        .front()
+                        .cast<LLVM::LLVMStructType>()
+                        .getBody()
+                        .front();
+    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+    Value structValue = b.create<LLVM::UndefOp>(structType);
+    for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+      auto innerStructType = s.cast<LLVM::LLVMStructType>();
+      int ii = idx;
+      Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+      for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+        innerStructValue = b.create<LLVM::InsertValueOp>(
+            innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+      }
+    }
+    rewriter.replaceOp(op, structValue);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1563,6 +1618,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
       NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
+      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index e8ecd0faa4c86d3..f5b02fe1b515591 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -435,6 +435,12 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
   return failure();
 }
 
+LogicalResult isAllowedSizeM(int sizeM) {
+  if (sizeM % kWgmmaSizeM)
+    return failure();
+  return success();
+}
+
 LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                                72,  80,  88,  96,  104, 112, 120, 128,
@@ -443,7 +449,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
-  if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
       typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
     if (llvm::is_contained(allowedN, sizeN))
       return success();
@@ -456,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
 
 LogicalResult WarpgroupMmaOp::verify() {
   if (getTransposeA() && !getTransposeB())
-    return emitOpError() << "supports non-transpose A (Row Major) "
-                            "and transpose B (Column Major) for the time being";
+    return emitOpError()
+           << "supports non-transpose A (Row Major) "
+              "and transpose B (Column Major) for the time being ";
   MemRefType matrixA = getDescriptorA().getType().getTensor();
   MemRefType matrixB = getDescriptorB().getType().getTensor();
-  VectorType matrixC = getMatrixC()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  VectorType matrixD = getMatrixD()
-                           .front()
-                           .getType()
-                           .cast<WarpgroupAccumulatorType>()
-                           .getFragmented();
-  unsigned sizeAcc = getMatrixC().size();
-
-  if (getMatrixC().size() != getMatrixD().size())
-    return emitOpError() << "number of matrix C and matrix D must be the same";
-
-  if (llvm::all_of(getMatrixC(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix C must be the same";
-  }
-  if (llvm::all_of(getMatrixD(),
-                   [&](Value rhs) { return rhs.getType() == matrixC; })) {
-    return emitOpError()
-           << "types of all operands in matrix D must be the same as matrix C";
-  }
+  VectorType matrixC = getMatrixC().getType().getFragmented();
+  VectorType matrixD = getMatrixD().getType().getFragmented();
+
+  if (matrixC != matrixD)
+    return emitOpError() << "type of matrix C and matrix D must be the same";
 
   if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
       matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -496,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
     return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                          << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
                          << " )";
-  if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+  if (matrixA.getShape()[0] != matrixC.getShape()[0])
     return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                          << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
                          << " )";
@@ -532,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
   MemRefType dstMemrefType = getDstMemref().getType();
-  VectorType firstVtype = getMatrixD()
-                              .front()
-                              .getType()
-                              .cast<WarpgroupAccumulatorType>()
-                              .getFragmented();
-
-  int64_t totalFirstDimension = 0;
-  for (Value result : getMatrixD()) {
-    VectorType vtype =
-        result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
-    if (vtype != firstVtype)
-      return emitOpError() << "all fragmented types must be the same";
-    // Limitation
-    if (!vtype.getElementType().isF32()) {
-      return emitOpError()
-             << "hit a limitation: only f32 results for the time being";
-    }
-    totalFirstDimension += vtype.getDimSize(0);
+  VectorType vtype = getMatrixD().getType().getFragmented();
+
+  // Limitation
+  if (!vtype.getElementType().isF32()) {
+    return emitOpError()
+           << "hit a limitation: only f32 results for the time being";
   }
-  if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
-      firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
-    return emitOpError() << "results [" << totalFirstDimension << "]["
-                         << firstVtype.getDimSize(1)
+  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                          << "] values. However, destination memref["
                          << dstMemrefType.getDimSize(0) << "]["
                          << dstMemrefType.getDimSize(1)
@@ -563,6 +537,27 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaInitAccumulatorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
+
+  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+  int64_t sizeM = accType.getFragmented().getDimSize(0);
+  int64_t sizeN = accType.getFragmented().getDimSize(1);
+  Type elemType = accType.getFragmented().getElementType();
+
+  if (failed(isAllowedSizeM(sizeM)) ||
+      failed(isAllowedSizeN(sizeN, elemType))) {
+    return emitOpError() << "has type " << accType.getFragmented()
+                         << ". It does not fit into warp-group "
+                            "level (wgmma) matrix multiplication instruction "
+                            "(or not supported yet)";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index e54b62a06d4313a..bf660e2683158e5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -713,18 +713,18 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.
 }
 
 // CHECK-LABEL: @warpgroup_mma_128_128_64(  
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
 func.func @warpgroup_mma_128_128_64(
       %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
       %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
-      %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
-      %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) 
+      %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) 
 {
 // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
 // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[ar...
[truncated]

``````````
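
A usage note grounded in the verifier added above: `nvgpu.warpgroup.mma.init.accumulator` only accepts accumulator shapes whose M is a multiple of 64 (`kWgmmaSizeM`) and whose N is a size supported by `wgmma` (see `isAllowedSizeM`/`isAllowedSizeN`). A sketch of a valid and an invalid case:

```
// Valid: M = 128 is a multiple of 64, N = 128 is an allowed wgmma size.
%acc = nvgpu.warpgroup.mma.init.accumulator 
         -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// Rejected by the verifier: M = 100 is not a multiple of 64.
// %bad = nvgpu.warpgroup.mma.init.accumulator 
//          -> !nvgpu.warpgroup.accumulator<fragmented = vector<100x128xf32>>
```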

</details>


https://github.com/llvm/llvm-project/pull/68728


More information about the Mlir-commits mailing list