[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `warpgroup.init.accumulator` Op (PR #67530)

Guray Ozen llvmlistbot at llvm.org
Thu Oct 5 08:56:44 PDT 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/67530

>From 05ee1276bdebb73b87ba02fdcb4fd75ca8087c0c Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 5 Oct 2023 09:06:04 +0200
Subject: [PATCH 1/4] [MLIR][NVGPU] Introduce `warpgroup.init.accumulator` Op
 #67530

This Op generates and initializes the accumulator matrix for the `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).

Its associated transformation generates an `!llvm.struct<>` and fills it with the initial values. The size of the struct is the number of inout registers required by the `nvgpu.warpgroup.mma` op.
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 11 +++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 30 ++++++++
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 71 +++++++++++++++++++
 3 files changed, 112 insertions(+)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 337174d6367157a..6c14c6c8ac9ca4c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -728,4 +728,15 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {  
+  let summary = "Initialize accumulator matrix for `warppgroup.mma`";
+
+  let description = [{
+    This Op generates and initilizes the accumulator matrix for 
+    `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).
+  }];
+  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let assemblyFormat = "attr-dict `->` type($matrixC)";
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a35c3d9c069e709..4665a3ed2dc5c57 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1295,6 +1295,35 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    SmallVector<Value> results;
+    for (auto matrixD : op.getMatrixC()) {
+      nvgpu::WarpgroupAccumulatorType matrixDType =
+          matrixD.getType().cast<nvgpu::WarpgroupAccumulatorType>();
+      Type stype = getTypeConverter()->convertType(matrixDType);
+      Value undefStruct = rewriter.create<LLVM::UndefOp>(loc, stype);
+      Type elemType = matrixDType.getFragmented().getElementType();
+      int64_t elemSize = matrixDType.getFragmented().getDimSize(0);
+      Value zero = rewriter.create<LLVM::ConstantOp>(
+          loc, elemType, rewriter.getZeroAttr(elemType));
+      for (int64_t i = 0; i < elemSize; ++i) {
+        undefStruct = rewriter.create<LLVM::InsertValueOp>(
+            loc, stype, undefStruct, zero, ArrayRef<int64_t>({i}));
+      }
+      results.push_back(undefStruct);
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1311,6 +1340,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 3710b06288e2a7f..e1589af68f4f1c4 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -772,6 +772,77 @@ func.func @warpgroup_mma_128_128_64(
   return
 }
 
+func.func @warpgroup_mma_init(){
+  //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
+  //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[S0]][0] : !llvm.struct
+  //CHECK: %[[S3:.+]] = llvm.insertvalue %[[S1]], %[[S2]][1] : !llvm.struct
+  //CHECK: %[[S4:.+]] = llvm.insertvalue %[[S1]], %[[S3]][2] : !llvm.struct
+  //CHECK: %[[S5:.+]] = llvm.insertvalue %[[S1]], %[[S4]][3] : !llvm.struct
+  //CHECK: %[[S6:.+]] = llvm.insertvalue %[[S1]], %[[S5]][4] : !llvm.struct
+  //CHECK: %[[S7:.+]] = llvm.insertvalue %[[S1]], %[[S6]][5] : !llvm.struct
+  //CHECK: %[[S8:.+]] = llvm.insertvalue %[[S1]], %[[S7]][6] : !llvm.struct
+  //CHECK: %[[S9:.+]] = llvm.insertvalue %[[S1]], %[[S8]][7] : !llvm.struct
+  //CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][8] : !llvm.struct
+  //CHECK: %[[S11:.+]] = llvm.insertvalue %[[S1]], %[[S10]][9] : !llvm.struct
+  //CHECK: %[[S12:.+]] = llvm.insertvalue %[[S1]], %[[S11]][10] : !llvm.struct
+  //CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][11] : !llvm.struct
+  //CHECK: %[[S14:.+]] = llvm.insertvalue %[[S1]], %[[S13]][12] : !llvm.struct
+  //CHECK: %[[S15:.+]] = llvm.insertvalue %[[S1]], %[[S14]][13] : !llvm.struct
+  //CHECK: %[[S16:.+]] = llvm.insertvalue %[[S1]], %[[S15]][14] : !llvm.struct
+  //CHECK: %[[S17:.+]] = llvm.insertvalue %[[S1]], %[[S16]][15] : !llvm.struct
+  //CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][16] : !llvm.struct
+  //CHECK: %[[S19:.+]] = llvm.insertvalue %[[S1]], %[[S18]][17] : !llvm.struct
+  //CHECK: %[[S20:.+]] = llvm.insertvalue %[[S1]], %[[S19]][18] : !llvm.struct
+  //CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][19] : !llvm.struct
+  //CHECK: %[[S22:.+]] = llvm.insertvalue %[[S1]], %[[S21]][20] : !llvm.struct
+  //CHECK: %[[S23:.+]] = llvm.insertvalue %[[S1]], %[[S22]][21] : !llvm.struct
+  //CHECK: %[[S24:.+]] = llvm.insertvalue %[[S1]], %[[S23]][22] : !llvm.struct
+  //CHECK: %[[S25:.+]] = llvm.insertvalue %[[S1]], %[[S24]][23] : !llvm.struct
+  //CHECK: %[[S26:.+]] = llvm.insertvalue %[[S1]], %[[S25]][24] : !llvm.struct
+  //CHECK: %[[S27:.+]] = llvm.insertvalue %[[S1]], %[[S26]][25] : !llvm.struct
+  //CHECK: %[[S28:.+]] = llvm.insertvalue %[[S1]], %[[S27]][26] : !llvm.struct
+  //CHECK: %[[S29:.+]] = llvm.insertvalue %[[S1]], %[[S28]][27] : !llvm.struct
+  //CHECK: %[[S30:.+]] = llvm.insertvalue %[[S1]], %[[S29]][28] : !llvm.struct
+  //CHECK: %[[S31:.+]] = llvm.insertvalue %[[S1]], %[[S30]][29] : !llvm.struct
+  //CHECK: %[[S32:.+]] = llvm.insertvalue %[[S1]], %[[S31]][30] : !llvm.struct
+  //CHECK: %[[S33:.+]] = llvm.insertvalue %[[S1]], %[[S32]][31] : !llvm.struct
+  //CHECK: %[[S34:.+]] = llvm.insertvalue %[[S1]], %[[S33]][32] : !llvm.struct
+  //CHECK: %[[S35:.+]] = llvm.insertvalue %[[S1]], %[[S34]][33] : !llvm.struct
+  //CHECK: %[[S36:.+]] = llvm.insertvalue %[[S1]], %[[S35]][34] : !llvm.struct
+  //CHECK: %[[S37:.+]] = llvm.insertvalue %[[S1]], %[[S36]][35] : !llvm.struct
+  //CHECK: %[[S38:.+]] = llvm.insertvalue %[[S1]], %[[S37]][36] : !llvm.struct
+  //CHECK: %[[S39:.+]] = llvm.insertvalue %[[S1]], %[[S38]][37] : !llvm.struct
+  //CHECK: %[[S40:.+]] = llvm.insertvalue %[[S1]], %[[S39]][38] : !llvm.struct
+  //CHECK: %[[S41:.+]] = llvm.insertvalue %[[S1]], %[[S40]][39] : !llvm.struct
+  //CHECK: %[[S42:.+]] = llvm.insertvalue %[[S1]], %[[S41]][40] : !llvm.struct
+  //CHECK: %[[S43:.+]] = llvm.insertvalue %[[S1]], %[[S42]][41] : !llvm.struct
+  //CHECK: %[[S44:.+]] = llvm.insertvalue %[[S1]], %[[S43]][42] : !llvm.struct
+  //CHECK: %[[S45:.+]] = llvm.insertvalue %[[S1]], %[[S44]][43] : !llvm.struct
+  //CHECK: %[[S46:.+]] = llvm.insertvalue %[[S1]], %[[S45]][44] : !llvm.struct
+  //CHECK: %[[S47:.+]] = llvm.insertvalue %[[S1]], %[[S46]][45] : !llvm.struct
+  //CHECK: %[[S48:.+]] = llvm.insertvalue %[[S1]], %[[S47]][46] : !llvm.struct
+  //CHECK: %[[S49:.+]] = llvm.insertvalue %[[S1]], %[[S48]][47] : !llvm.struct
+  //CHECK: %[[S50:.+]] = llvm.insertvalue %[[S1]], %[[S49]][48] : !llvm.struct
+  //CHECK: %[[S51:.+]] = llvm.insertvalue %[[S1]], %[[S50]][49] : !llvm.struct
+  //CHECK: %[[S52:.+]] = llvm.insertvalue %[[S1]], %[[S51]][50] : !llvm.struct
+  //CHECK: %[[S53:.+]] = llvm.insertvalue %[[S1]], %[[S52]][51] : !llvm.struct
+  //CHECK: %[[S54:.+]] = llvm.insertvalue %[[S1]], %[[S53]][52] : !llvm.struct
+  //CHECK: %[[S55:.+]] = llvm.insertvalue %[[S1]], %[[S54]][53] : !llvm.struct
+  //CHECK: %[[S56:.+]] = llvm.insertvalue %[[S1]], %[[S55]][54] : !llvm.struct
+  //CHECK: %[[S57:.+]] = llvm.insertvalue %[[S1]], %[[S56]][55] : !llvm.struct
+  //CHECK: %[[S58:.+]] = llvm.insertvalue %[[S1]], %[[S57]][56] : !llvm.struct
+  //CHECK: %[[S59:.+]] = llvm.insertvalue %[[S1]], %[[S58]][57] : !llvm.struct
+  //CHECK: %[[S60:.+]] = llvm.insertvalue %[[S1]], %[[S59]][58] : !llvm.struct
+  //CHECK: %[[S61:.+]] = llvm.insertvalue %[[S1]], %[[S60]][59] : !llvm.struct
+  //CHECK: %[[S62:.+]] = llvm.insertvalue %[[S1]], %[[S61]][60] : !llvm.struct
+  //CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
+  //CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
+  //CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
+  %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+  return 
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 

>From 2781170f377e900b1e82cdbc6d680c9cf3f66215 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 5 Oct 2023 09:07:42 +0200
Subject: [PATCH 2/4] address comments

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td     | 2 +-
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 6c14c6c8ac9ca4c..ee141af3812fbff 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -729,7 +729,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
 }
 
 def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {  
-  let summary = "Initialize accumulator matrix for `warppgroup.mma`";
+  let summary = "Initializes accumulator matrix for `warppgroup.mma`";
 
   let description = [{
     This Op generates and initilizes the accumulator matrix for 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4665a3ed2dc5c57..270d7054e72eb91 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1304,7 +1304,7 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> results;
-    for (auto matrixD : op.getMatrixC()) {
+    for (Value matrixD : op.getMatrixC()) {
       nvgpu::WarpgroupAccumulatorType matrixDType =
           matrixD.getType().cast<nvgpu::WarpgroupAccumulatorType>();
       Type stype = getTypeConverter()->convertType(matrixDType);

>From d748caf00aa356ae6b428075034df090378e0cb6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 5 Oct 2023 09:59:29 +0200
Subject: [PATCH 3/4] add verifier

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  1 +
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 24 ++++++++---------
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 26 ++++++++++++++++++-
 3 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index ee141af3812fbff..43c948550ed623f 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -737,6 +737,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
   }];
   let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
   let assemblyFormat = "attr-dict `->` type($matrixC)";
+  let hasVerifier = 1;
 }
 
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 270d7054e72eb91..535b76774941052 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1302,20 +1302,20 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     SmallVector<Value> results;
-    for (Value matrixD : op.getMatrixC()) {
-      nvgpu::WarpgroupAccumulatorType matrixDType =
-          matrixD.getType().cast<nvgpu::WarpgroupAccumulatorType>();
-      Type stype = getTypeConverter()->convertType(matrixDType);
-      Value undefStruct = rewriter.create<LLVM::UndefOp>(loc, stype);
-      Type elemType = matrixDType.getFragmented().getElementType();
-      int64_t elemSize = matrixDType.getFragmented().getDimSize(0);
-      Value zero = rewriter.create<LLVM::ConstantOp>(
-          loc, elemType, rewriter.getZeroAttr(elemType));
+    for (OpResult m : op.getMatrixC()) {
+      nvgpu::WarpgroupAccumulatorType mType =
+          m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
+      Type stype = getTypeConverter()->convertType(mType);
+      Value undefStruct = b.create<LLVM::UndefOp>(stype);
+      Type elemType = mType.getFragmented().getElementType();
+      int64_t elemSize = mType.getFragmented().getDimSize(0);
+      Value zero =
+          b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
       for (int64_t i = 0; i < elemSize; ++i) {
-        undefStruct = rewriter.create<LLVM::InsertValueOp>(
-            loc, stype, undefStruct, zero, ArrayRef<int64_t>({i}));
+        undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
+                                                    ArrayRef<int64_t>({i}));
       }
       results.push_back(undefStruct);
     }
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index eb8fc4b65bc89ad..11f37c7f9cc94d6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -434,6 +434,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
   return failure();
 }
 
+LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
+
 LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                                72,  80,  88,  96,  104, 112, 120, 128,
@@ -442,7 +444,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
   SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
-  if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
       typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
     if (llvm::is_contained(allowedN, sizeN))
       return success();
@@ -529,6 +531,28 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaInitAccumulatorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
+  for (OpResult matrix : getMatrixC()) {
+    VectorType vectorType = matrix.getType()
+                                .cast<nvgpu::WarpgroupAccumulatorType>()
+                                .getFragmented();
+    // Check [M][N] shape
+    if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
+        failed(isAllowedSizeN(vectorType.getDimSize(1),
+                              vectorType.getElementType()))) {
+      return emitOpError() << "has type " << vectorType
+                           << ". It does not fit into warp-group "
+                              "level (wgmma) matrix multiplication instruction "
+                              "(or not supported yet)";
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//

>From 96194cdf4122e196c61cf3e3274a0585a2603cc7 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 5 Oct 2023 17:56:07 +0200
Subject: [PATCH 4/4] fix typo

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 43c948550ed623f..efe21211920cfe9 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -729,11 +729,11 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
 }
 
 def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {  
-  let summary = "Initializes accumulator matrix for `warppgroup.mma`";
+  let summary = "Initializes the accumulator matrix";
 
   let description = [{
-    This Op generates and initilizes the accumulator matrix for 
-    `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).
+    This Op generates and initializes the accumulator matrix for 
+    `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
   }];
   let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
   let assemblyFormat = "attr-dict `->` type($matrixC)";



More information about the Mlir-commits mailing list