[Mlir-commits] [mlir] [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (PR #67325)

Guray Ozen llvmlistbot at llvm.org
Thu Oct 5 01:14:27 PDT 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/67325

>From a1195786930afc18f79958068fa58caf4ab8b867 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 25 Sep 2023 15:09:51 +0200
Subject: [PATCH 1/2] [mlir][nvgpu] Improve nvgpu->nvvm transformation of
 `nvgpu.warpgroup.mma` Op (NFC)

This PR substantially improves the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation is central to GEMM lowering: it generates multiple wgmma Ops and iterates their matrix descriptors, logic that the prior code expressed in a way that was hard to follow.

The PR introduces a helper class, `WarpgroupGemm`, that encapsulates this functionality and makes the code easier to follow. Each function in the helper class is documented with its purpose.
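
With the helper in place, the lowering entry point reduces to three small
steps. The sketch below mirrors the matchAndRewrite in the diff; it is a
summary of the patch, not new code:

    // Step 1. Build the helper around the matched Op.
    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor, *getTypeConverter());
    // Step 2. Emit fence + wgmma group + sync/wait and collect the results.
    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
    // Step 3. Replace the nvgpu Op with the generated values.
    rewriter.replaceOp(op, wgmmaResults);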
---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 350 ++++++++++++------
 1 file changed, 240 insertions(+), 110 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a35c3d9c069e709..303462c22a6859c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -39,7 +39,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1160,137 +1160,267 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
+  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
+  /// instruction, it can generate multiple wgmma instructions, group and execute
+  /// them asynchronously. The class also handles waiting for instruction
+  /// completion and iterates through GenerateGmmaDescriptor to create
+  /// descriptors for each instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// Returns the shape of the wgmma instruction that is defined in the
+    /// PTX programming guide:
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("unsupported input element type for wgmma K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates a WGMMATypesAttr from the given MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates the layout attribute for an input matrix of the wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    //  Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB     ---> +--+--+--+--+--+--+--+--+
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }
+
+    /// Generates a WgmmaMmaAsyncOp from the given GMMA matrix descriptors,
+    /// advancing them according to the induction variables i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }
 
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts needed to cover the given shape with the wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
+    /// instructions, as well as synchronizing the group
+    /// (WgmmaGroupSyncAlignedOp) and waiting for its completion
+    /// (WgmmaWaitGroupSyncOp) after the instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Generate wgmma Ops covering the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace the op with the generated fragmented result structs
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
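
A note on the descriptor math above: the increment computed in
iterateDescriptorA is a plain byte offset, right-shifted by exclude4LSB
because wgmma descriptors drop the 4 least-significant address bits. A
standalone C++ sketch of that arithmetic (the f16 element type and the
[128][64] A tile are assumptions taken from the ASCII diagram, not extra
code in the patch):

    // Sketch of iterateDescriptorA's offset computation (illustrative only).
    constexpr int exclude4LSB = 4; // descriptors encode address >> 4
    int descriptorAOffset(int i, int k, int wgmmaK, int totalK, int tileShapeA,
                          int elemBytes) {
      // k steps of wgmmaK elements along a row, plus i full 64-row blocks.
      int bytes = ((wgmmaK * k) + (totalK * tileShapeA * i)) * elemBytes;
      return bytes >> exclude4LSB;
    }
    // For f16 (2 bytes) with totalK = tileShapeA = 64 and wgmmaK = 16:
    //   descriptorAOffset(0, 1, 16, 64, 64, 2) == 2   -> the "+2" column step
    //   descriptorAOffset(1, 0, 16, 64, 64, 2) == 512 -> the "descA+512" row step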

>From 2ca1556b4d7531080cf27734f79a62bcbe77ebc6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 5 Oct 2023 10:14:08 +0200
Subject: [PATCH 2/2] rebase and add more comments
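
The main cleanup in this revision swaps the stored ConversionPatternRewriter
for an ImplicitLocOpBuilder, which remembers the Location so the create<>
call sites no longer repeat it. Both lines below are taken from the diff:

    // Before: the location is passed at every call site.
    rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs, rhs);
    // After: the builder was constructed with the Op's location.
    b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);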

---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 59 +++++++++++--------
 1 file changed, 34 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 303462c22a6859c..e4705bc474b4ef1 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1160,15 +1160,29 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
-  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
-  /// instruction, it can generate multiple wgmma instructions, group and execute
-  /// them asynchronously. The class also handles waiting for instruction
-  /// completion and iterates through GenerateGmmaDescriptor to create
-  /// descriptors for each instruction.
+  /// This is a helper class to generate the NVVM Ops required for warp-group
+  /// level matrix multiplication.
+  /// When the given GEMM shape is larger than the shape of a single wgmma
+  /// instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp Ops,
+  /// groups them, and executes them asynchronously. The class also handles
+  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
+  /// create descriptors for each instruction.
+  ///
+  /// For example, this is the sequence generated when the GEMM shape is 128x128x128:
+  ///
+  ///    nvvm.wgmma.fence.aligned
+  ///
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    iterate(descA, descB)
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    [6 more times]
+  ///
+  ///    nvvm.wgmma.group.sync.aligned
+  ///    nvvm.wgmma.wait.group.sync [groupId]
+  ///
   class WarpgroupGemm {
     nvgpu::WarpgroupMmaOp op;
-    ConversionPatternRewriter &rewriter;
+    ImplicitLocOpBuilder b;
     OpAdaptor adaptor;
     const LLVMTypeConverter &typeConverter;
 
@@ -1253,8 +1267,7 @@ struct NVGPUWarpgroupMmaOpLowering
 
     /// Basic function to generate Add
     Value makeAdd(Value lhs, Value rhs) {
-      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
-                                          rhs);
+      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
     /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1287,7 +1300,7 @@ struct NVGPUWarpgroupMmaOpLowering
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     }
 
     /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
@@ -1310,7 +1323,7 @@ struct NVGPUWarpgroupMmaOpLowering
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     }
 
     /// Generates a WgmmaMmaAsyncOp from the given GMMA matrix descriptors,
@@ -1346,10 +1359,9 @@ struct NVGPUWarpgroupMmaOpLowering
 
       Type resultStructType = typeConverter.convertType(matrixD.getType());
 
-      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
-          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
-          overflow);
+      return b.create<NVVM::WgmmaMmaAsyncOp>(
+          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     }
 
     /// Generates multiple wgmma instructions to complete the given GEMM shape
@@ -1370,10 +1382,9 @@ struct NVGPUWarpgroupMmaOpLowering
     }
 
   public:
-    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                   OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
-        : op(op), rewriter(rewriter), adaptor(adaptor),
-          typeConverter(typeConverter) {
+        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
       // Find the entire GEMM Shape
       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1399,12 +1410,10 @@ struct NVGPUWarpgroupMmaOpLowering
     /// (WgmmaGroupSyncAlignedOp) and waiting for its completion
     /// (WgmmaWaitGroupSyncOp) after the instructions.
     SmallVector<Value> generateWarpgroupMma() {
-      Location loc = op->getLoc();
-      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      b.create<NVVM::WgmmaFenceAlignedOp>();
       SmallVector<Value> wgmmaResults = generateWgmmaGroup();
-      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
-
+      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
       return wgmmaResults;
     }
   };
@@ -1412,9 +1421,9 @@ struct NVGPUWarpgroupMmaOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     // Step 1. Build a helper class
-    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
-                                *this->getTypeConverter());
+    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
 
     // Step 2. Generate wgmma Ops covering the entire GEMM shape
     SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
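
For a quick sanity check of the constructor's iteration math, plug the
128x128x128 f16 GEMM from the new class comment into the code above. This is
a back-of-the-envelope sketch, not code from the patch:

    // Iteration counts for a 128x128x128 f16 GEMM (illustrative).
    int totalM = 128, totalN = 128, totalK = 128;
    int wgmmaM = 64, wgmmaN = totalN, wgmmaK = 16; // findWgmmaShape for f16
    int iterationM = totalM / wgmmaM; // 2
    int iterationN = totalN / wgmmaN; // 1
    int iterationK = totalK / wgmmaK; // 8
    // 2 * 1 * 8 = 16 WgmmaMmaAsyncOps in total, 8 per accumulator fragment:
    // the "2 shown + 6 more" sequence in the class comment above.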


