[Mlir-commits] [mlir] b74cfc1 - [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (#67325)

llvmlistbot at llvm.org
Thu Oct 5 01:17:03 PDT 2023


Author: Guray Ozen
Date: 2023-10-05T10:16:59+02:00
New Revision: b74cfc139a1bce527a02ed9a32da5b0cb9c955bf

URL: https://github.com/llvm/llvm-project/commit/b74cfc139a1bce527a02ed9a32da5b0cb9c955bf
DIFF: https://github.com/llvm/llvm-project/commit/b74cfc139a1bce527a02ed9a32da5b0cb9c955bf.diff

LOG: [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (#67325)

This PR substantially improves the readability and maintainability of
the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This
transformation is a crucial part of GEMM lowering: it manages complex
tasks such as generating multiple wgmma ops and iterating their matrix
descriptors. The prior code lacked clarity; this PR addresses that.

**This PR does the following:**
**Introduces a helper class:** The `WarpgroupGemm` class encapsulates
the necessary functionality, making the code cleaner and easier to
understand.

**Detailed Documentation:** Each function within the helper class is
thoroughly documented, clearly explaining its purpose and behavior.
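
For reference, the new lowering entry point reduces to a few lines. The
sketch below is assembled from the patch itself (C++, with the pattern
boilerplate elided):

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
    rewriter.replaceOp(op, wgmmaResults);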

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b0df2feae16b49f..cce525dcdcbe268 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -39,7 +39,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that need to be excluded when building a matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// Helper class that generates the NVVM Ops required for warp-group level
+  /// matrix multiplication.
+  /// When the given GEMM shape is larger than the shape of a single wgmma
+  /// instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp Op(s),
+  /// groups them, and executes them asynchronously. The class also handles
+  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
+  /// create a descriptor for each instruction.
+  ///
+  /// For example, the following sequence is generated when the GEMM shape is
+  /// 128x128x128:
+  ///
+  ///    nvvm.wgmma.fence.aligned
+  ///
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    iterate(descA, descB)
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    [6 more times]
+  ///
+  ///    nvvm.wgmma.group.sync.aligned
+  ///    nvvm.wgmma.wait.group.sync [groupId]
+  ///
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ImplicitLocOpBuilder b;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// Finds the shape of a single wgmma instruction, as defined in the PTX
+    /// programming guide:
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
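+    /// For example, f16 inputs with sizeN = 128 yield the shape m64n128k16.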
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates WGMMATypesAttr from MLIR Type
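+    /// For example, f16 maps to f16, and both f32 and tf32 map to tf32.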
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates layout attribute for the input matrix for wgmma instruction
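+    /// A transposed input is mapped to col-major; otherwise (including when
+    /// the transpose flag is absent) row-major is used.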
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    //  Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+    /// Helper that generates an LLVM add operation.
+    Value makeAdd(Value lhs, Value rhs) {
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// For a [128][64] matrix, it moves the pointer as shown below:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
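+    /// Worked numbers for the [128][64] example above, assuming f16 (2-byte)
+    /// elements so that wgmmaK = 16: steps k = 1, 2, 3 advance the descriptor
+    /// by 32, 64, 96 bytes, i.e. +2, +4, +6 after dropping the 4 LSBs, and
+    /// i = 1, k = 0 advances it by 64 * 64 * 2 = 8192 bytes, i.e. +512.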
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// For a [128][64] matrix, it moves the pointer as shown below:
+    /// descB     ---> +--+--+--+--+--+--+--+--+
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                +--+--+--+--+--+--+--+--+
+    ///
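+    /// Each step in k advances the descriptor by dimSize(0) * wgmmaK bytes
+    /// scaled by the element size (again with the 4 LSBs dropped); i and j
+    /// leave it unchanged.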
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
+
+    /// Generates a WgmmaMmaAsyncOp from the provided GMMA matrix descriptors,
+    /// advancing them according to the induction variables i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (i * wgmmaM) << ":"
+                        << (i * wgmmaM) + wgmmaM << "]["
+                        << (k * wgmmaK) << ":"
+                        << (k * wgmmaK + wgmmaK) << "] * B["
+                        << (k * wgmmaK) << ":"
+                        << (k * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return b.create<NVVM::WgmmaMmaAsyncOp>(
+          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    }
 
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
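+      // Note that the accumulator (matrixC) is threaded through the k-loop,
+      // so each WgmmaMmaAsyncOp accumulates onto the previous result.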
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts needed to cover the entire shape with the wgmma shape
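+      // For the 128x128x128 example in the class comment (f16 inputs), this
+      // gives iterationM = 2, iterationN = 1, iterationK = 8.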
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// emits a fence Op (WgmmaFenceAlignedOp) before the instructions, and
+    /// after them it synchronizes the group (WgmmaGroupSyncAlignedOp) and
+    /// waits for its completion (WgmmaWaitGroupSyncOp).
+    SmallVector<Value> generateWarpgroupMma() {
+      b.create<NVVM::WgmmaFenceAlignedOp>();
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+
+    // Step 2. Generate the Ops for the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace the op with the generated result structs
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };




More information about the Mlir-commits mailing list