[Mlir-commits] [mlir] [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (PR #67325)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 25 06:15:13 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This PR improves the readability and maintainability of the nvgpu->nvvm lowering of the `nvgpu.warpgroup.mma` Op. This lowering plays a crucial role in GEMM: it handles complex tasks such as generating multiple wgmma ops and iterating their descriptors. The prior code was hard to follow; this PR restructures it to address that.

**This PR does the following:**
**Introduces a helper class:** the `WarpgroupGemm` class encapsulates the necessary functionality, making the code cleaner and easier to understand.

**Detailed Documentation:** Each function within the helper class is thoroughly documented to provide clear insights into its purpose and functionality.

---
Full diff: https://github.com/llvm/llvm-project/pull/67325.diff


1 Files Affected:

- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+238-111) 


``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..3bbee8934a1d4ae 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -38,7 +38,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1168,140 +1168,267 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
+  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
+  /// instrution, it can generate multiple wgmma instructions, group and execute
+  /// them asynchronously. The class also handles waiting for instruction
+  /// completion and iterates through GenerateGmmaDescriptor to create
+  /// descriptors for each instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// The function returns the shape of wgmma instruction that is defined in
+    /// PTX programming guide.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(MLIRContext *ctx,
-                            ConversionPatternRewriter &rewriter, Location loc,
-                            int m, int n, int k, Type resultStructType,
-                            Value inout, Value descriptorA,
-                            Value descriptorB) const {
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
-        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
-
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
+    }
 
-    //  Generate wgmma group
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    auto loc = op->getLoc();
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB     ---> +--+--+--+--+--+--+--+--+
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }
 
-    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
-                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iterations counts to complete the given shape with wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// emits a fence Op (WgmmaFenceAlignedOp) before the instructions and,
+    /// after the instructions, a group synchronization Op
+    /// (WgmmaGroupSyncAlignedOp) followed by a wait for the group
+    /// (WgmmaWaitGroupSyncOp).
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Get the entire GEMM Shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };

``````````

</details>


https://github.com/llvm/llvm-project/pull/67325


More information about the Mlir-commits mailing list