[Mlir-commits] [mlir] [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (PR #67325)

Guray Ozen llvmlistbot at llvm.org
Mon Sep 25 06:14:08 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/67325

This PR introduces substantial improvements to the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation plays a crucial role in GEMM and manages complex operations such as generating multiple wgmma ops and iterating their descriptors. The prior code lacked clarity, but this PR addresses that issue effectively.

**This PR does the following:**
**Introduces a helper class:** The `WarpgroupGemm` class encapsulates the necessary functionality, making the code cleaner and more understandable.

**Detailed Documentation:** Each function within the helper class is thoroughly documented to provide clear insights into its purpose and functionality.

>From 18407498c4b43ccf1c2c39109efc87b84f366bf4 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 25 Sep 2023 15:09:51 +0200
Subject: [PATCH] [mlir][nvgpu] Improve nvgpu->nvvm transformation of
 `nvgpu.warpgroup.mma` Op (NFC)

This PR introduces substantial improvements to the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation plays a crucial role in GEMM and manages complex operations such as generating multiple wgmma ops and iterating their descriptors. The prior code lacked clarity, but this PR addresses that issue effectively.

This PR introduces a helper class, `WarpgroupGemm`. This class encapsulates the necessary functionality, making the code cleaner and more understandable. Each function within the helper class is thoroughly documented to provide clear insights into its purpose and functionality.
---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 349 ++++++++++++------
 1 file changed, 238 insertions(+), 111 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..3bbee8934a1d4ae 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -38,7 +38,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that need to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1168,140 +1168,267 @@ struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
+  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
+  /// instruction, it can generate multiple wgmma instructions, group and
+  /// execute them asynchronously. The class also handles waiting for
+  /// instruction completion and advances the GMMA matrix descriptors for each
+  /// instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// Computes the shape of one wgmma instruction as defined in the PTX
+    /// programming guide and stores it in wgmmaM, wgmmaN and wgmmaK.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(8)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(MLIRContext *ctx,
-                            ConversionPatternRewriter &rewriter, Location loc,
-                            int m, int n, int k, Type resultStructType,
-                            Value inout, Value descriptorA,
-                            Value descriptorB) const {
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
-        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
-
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
+    }
 
-    //  Generate wgmma group
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    auto loc = op->getLoc();
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    /// Emits an LLVM add of `lhs` and `rhs`.
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
-    };
+    }
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB     ---> +--+--+--+--+--+--+--+--+
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
+
+    /// Generates one WgmmaMmaAsyncOp for iteration (i, j, k): the GMMA matrix
+    /// descriptors are advanced accordingly and `matrixC` is the accumulator.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (i * wgmmaM) << ":"
+                        << (i * wgmmaM) + wgmmaM << "]["
+                        << (k * wgmmaK) << ":"
+                        << (k * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (k * wgmmaK) << ":"
+                        << (k * wgmmaK + wgmmaK) << "][" << (j * wgmmaN) << ":"
+                        << (j * wgmmaN + wgmmaN) << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }
 
-    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
-                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts needed to cover the GEMM shape with wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsyncOp ops to complete the specified GEMM shape. It
+    /// emits a fence (WgmmaFenceAlignedOp) before the wgmma instructions, then
+    /// commits them as a group (WgmmaGroupSyncAlignedOp) and waits on it
+    /// (WgmmaWaitGroupSyncOp, using the op's wait-group count) afterwards.
+    /// Returns one accumulator value per M-iteration.
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Generate the wgmma ops covering the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace the op with the accumulated wgmma results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };



More information about the Mlir-commits mailing list