[Mlir-commits] [mlir] b74cfc1 - [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (#67325)
llvmlistbot at llvm.org
Thu Oct 5 01:17:03 PDT 2023
Author: Guray Ozen
Date: 2023-10-05T10:16:59+02:00
New Revision: b74cfc139a1bce527a02ed9a32da5b0cb9c955bf
URL: https://github.com/llvm/llvm-project/commit/b74cfc139a1bce527a02ed9a32da5b0cb9c955bf
DIFF: https://github.com/llvm/llvm-project/commit/b74cfc139a1bce527a02ed9a32da5b0cb9c955bf.diff
LOG: [mlir][nvgpu] Improve nvgpu->nvvm transformation of `warpgroup.mma` Op (NFC) (#67325)
This PR substantially improves the readability and maintainability of the
`nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This
transformation plays a crucial role in GEMM lowering and manages complex
logic such as generating multiple wgmma Ops and iterating their
descriptors. The prior code lacked clarity; this PR addresses that.
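To make the tiling concrete, here is a minimal standalone sketch (not part of the patch) of the shape selection and iteration-count arithmetic the new helper performs, assuming f16 operands and a 128x128x128 GEMM:

```c++
// Sketch of the tiling arithmetic in WarpgroupGemm (findWgmmaShape and the
// constructor in the patch below); assumes f16 inputs and a 128x128x128 GEMM.
#include <cassert>
#include <cstdio>

int main() {
  // Whole GEMM: D[128][128] += A[128][128] * B[128][128].
  const long totalM = 128, totalN = 128, totalK = 128;

  // Shape of one wgmma instruction: M is fixed to 64, N follows the GEMM,
  // and K depends on the element type (16 for f16 per the PTX shape table).
  const int wgmmaM = 64, wgmmaN = static_cast<int>(totalN), wgmmaK = 16;

  // Iteration counts used to cover the whole GEMM with wgmma instructions.
  const long iterationM = totalM / wgmmaM; // 2
  const long iterationN = totalN / wgmmaN; // 1
  const long iterationK = totalK / wgmmaK; // 8
  assert(iterationM == 2 && iterationN == 1 && iterationK == 8);

  // 2 * 1 * 8 = 16 nvvm.wgmma.mma.async Ops, 8 per accumulator fragment.
  std::printf("wgmma ops: %ld\n", iterationM * iterationN * iterationK);
  return 0;
}
```

With these counts the lowering emits 8 `nvvm.wgmma.mma.async` Ops per 64-row accumulator fragment, bracketed by the fence, group-commit, and wait Ops shown in the class documentation below.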
**This PR does the following:**
**Introduces a helper class:** the `WarpgroupGemm` class encapsulates the
necessary functionality, making the code cleaner and easier to
understand (see the usage sketch after this list).
**Detailed Documentation:** Each function within the helper class is
thoroughly documented to provide clear insights into its purpose and
functionality.
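For reference, the rewritten `matchAndRewrite` reduces to a thin wrapper around the helper; the snippet below is condensed from the diff that follows:

```c++
// Condensed from the patch below: the pattern builds the helper, lets it
// emit the wgmma sequence, and replaces the op with the resulting fragments.
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  // The constructor derives the GEMM shape, the wgmma shape, and the
  // iteration counts.
  WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
  // Emits the fence, the wgmma group, the group commit, and the wait.
  SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
  rewriter.replaceOp(op, wgmmaResults);
  return success();
}
```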
Added:
Modified:
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b0df2feae16b49f..cce525dcdcbe268 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -39,7 +39,7 @@ namespace mlir {
using namespace mlir;
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
/// wgmma operations.
constexpr int exclude4LSB = 4;
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
- LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
- int &wgmmaShapeM, int &wgmmaShapeN,
- int &wgmmaShapeK) const {
- wgmmaShapeM = 64;
- wgmmaShapeN = sizeN;
- if (inputElemType.isTF32()) {
- wgmmaShapeK = 8;
- } else if (inputElemType.isF16() || inputElemType.isBF16()) {
- wgmmaShapeK = 16;
- } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
- inputElemType.isInteger(16)) {
- wgmmaShapeK = 32;
- } else if (inputElemType.isInteger(1)) {
- wgmmaShapeK = 256;
- } else {
- llvm_unreachable("msg: not supported K shape");
+ /// This is a helper class to generate required NVVM Ops for warp-group level
+ /// matrix multiplication.
+ /// When the given GEMM shape is larger than the shape of
+ /// a wgmma instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp
+ /// Ops, groups them, and executes them asynchronously. The class also
+ /// handles waiting for completion and iterates through
+ /// WarpgroupMatrixDescriptor to create a descriptor for each instruction.
+ ///
+ /// For example, this is the sequence generated when the GEMM shape is 128x128x128:
+ ///
+ /// nvvm.wgmma.fence.aligned
+ ///
+ /// nvvm.wgmma.mma.async descA, descB
+ /// iterate(descA, descB)
+ /// nvvm.wgmma.mma.async descA, descB
+ /// [6x times more]
+ ///
+ /// nvvm.wgmma.group.sync.aligned
+ /// nvvm.wgmma.wait.group.sync [groupId]
+ ///
+ class WarpgroupGemm {
+ nvgpu::WarpgroupMmaOp op;
+ ImplicitLocOpBuilder b;
+ OpAdaptor adaptor;
+ const LLVMTypeConverter &typeConverter;
+
+ // Entire shape of the given Op
+ int64_t totalM, totalN, totalK;
+
+ // Shape of one wgmma instruction
+ int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+ // Iteration counts for GEMM
+ int iterationM = 0, iterationN = 0, iterationK = 0;
+
+ /// Computes the shape of the wgmma instruction as defined in the
+ /// PTX programming guide:
+ /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+ void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+ wgmmaM = 64;
+ wgmmaN = sizeN;
+ if (inputElemType.isTF32()) {
+ wgmmaK = 8;
+ } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+ wgmmaK = 16;
+ } else if (inputElemType.isFloat8E4M3FN() ||
+ inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+ wgmmaK = 32;
+ } else if (inputElemType.isInteger(1)) {
+ wgmmaK = 256;
+ } else {
+ llvm_unreachable("msg: not supported K shape");
+ }
+ LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
}
- LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
- << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
- << "]\n");
- return success();
- }
- Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
- Type resultStructType, Value inout,
- Value descriptorA, Value descriptorB) const {
- MLIRContext *ctx = b.getContext();
- auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
- auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
- auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
- auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
- auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
- // todo: handle other input and output types
- auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
- auto overflow =
- NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
- Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
- resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
- scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
- return res;
- }
-
- LogicalResult
- matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
- int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
- int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
- << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
- << sizeN << "] ---===\n");
+ /// Generates WGMMATypesAttr from MLIR Type
+ NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+ auto getWgmmaType = [](Type elemType) {
+ if (elemType.isF32() || elemType.isTF32())
+ return NVVM::WGMMATypes::tf32;
+ if (elemType.isF16())
+ return NVVM::WGMMATypes::f16;
+ if (elemType.isBF16())
+ return NVVM::WGMMATypes::bf16;
+ if (elemType.isFloat8E4M3FN())
+ return NVVM::WGMMATypes::e4m3;
+ if (elemType.isFloat8E5M2())
+ return NVVM::WGMMATypes::e5m2;
+ if (elemType.isInteger(1))
+ return NVVM::WGMMATypes::b1;
+ if (elemType.isInteger(8))
+ return NVVM::WGMMATypes::s8;
+ if (elemType.isUnsignedInteger(8))
+ return NVVM::WGMMATypes::u8;
+ llvm_unreachable("unsupported type");
+ };
+ return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+ }
- int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
- if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
- wgmmaShapeN, wgmmaShapeK))) {
- return failure();
+ /// Generates layout attribute for the input matrix for wgmma instruction
+ NVVM::MMALayoutAttr
+ generateWgmmaLayout(std::optional<bool> transpose) const {
+ if (transpose.value_or(false))
+ return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+ return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
}
- Value descriptorA = adaptor.getDescriptorA();
- Value descriptorB = adaptor.getDescriptorB();
+ /// Generates shape attribute for wgmma instruction
+ NVVM::MMAShapeAttr generateWgmmaShape() const {
+ return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+ }
- // Generate wgmma group
- MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
- MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+ /// Generates scale attributes of output matrix for wgmma instruction
+ NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+ return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+ NVVM::WGMMAScaleOut::one);
+ }
+ /// Generates scale attributes of input matrix for wgmma instruction
+ NVVM::WGMMAScaleInAttr generateScaleIn() const {
+ return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+ NVVM::WGMMAScaleIn::one);
+ }
- auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+ /// Basic function to generate Add
+ Value makeAdd(Value lhs, Value rhs) {
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};
- auto iterateDescA = [&](Value desc, int iterM, int iterN,
- int iterK) -> Value {
- // todo : Handle column major
- int byte = typeTensorA.getElementTypeBitWidth() / 8;
- int tileShapeA = typeTensorA.getDimSize(1);
- int incrementVal =
- ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+ /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+ /// Currently, it only handles row-major.
+ ///
+ /// It moves the pointer as shown below for a [128][64] matrix:
+ /// +2 +4 +6
+ /// ↓ ↓ ↓
+ /// descA ---> +--+--+--+--+
+ /// |->|->|->|->|
+ /// | | | | |
+ /// | | | | |
+ /// | | | | |
+ /// descA+512---> +-----------+
+ /// | | | | |
+ /// | | | | |
+ /// | | | | |
+ /// | | | | |
+ /// +-----------+
+ ///
+ Value iterateDescriptorA(Value desc, int i, int j, int k) {
+ MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+ Type elemA = matrixTypeA.getElementType();
+ int byte = elemA.getIntOrFloatBitWidth() / 8;
+ int tileShapeA = matrixTypeA.getDimSize(1);
+ int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
- << iterK << "] [wgmma descriptors] Descriptor A + "
+ LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
- };
+ }
- auto iterateDescB = [&](Value desc, int iterM, int iterN,
- int iterK) -> Value {
- // todo : Handle row major
- int byte = typeTensorB.getElementTypeBitWidth() / 8;
- int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+ /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+ /// Currently, it only handles column-major.
+ ///
+ /// It moves the pointer as shown below for a [128][64] matrix:
+ /// descB ---> +--+--+--+--+--+--+--+--+
+ /// |↓ | | | | | | | |
+ /// |↓ | | | | | | | |
+ /// |↓ | | | | | | | |
+ /// |↓ | | | | | | | |
+ /// +--+--+--+--+--+--+--+--+
+ ///
+ Value iterateDescriptorB(Value desc, int i, int j, int k) {
+ MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+ Type elemB = matrixTypeB.getElementType();
+ int byte = elemB.getIntOrFloatBitWidth() / 8;
+ int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
- };
+ }
+
+ /// Generates a WgmmaMmaAsyncOp from the provided GMMA matrix descriptors,
+ /// advancing them according to the induction variables i, j, and k.
+ Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+ LLVM_DEBUG(DBGS() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+ << "(A[" << (iterationM * wgmmaM) << ":"
+ << (iterationM * wgmmaM) + wgmmaM << "]["
+ << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+ << wgmmaN << "])\n");
+
+ Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+ Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+ Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+ NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+ Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+ NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+ NVVM::MMAShapeAttr shape = generateWgmmaShape();
+ NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+ NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+ NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+ NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+ auto overflow = NVVM::MMAIntOverflowAttr::get(
+ op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+ Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+ return b.create<NVVM::WgmmaMmaAsyncOp>(
+ resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+ itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+ }
- b.create<NVVM::WgmmaFenceAlignedOp>();
-
- SmallVector<Value> wgmmaResults;
- for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
- Value matrixC = adaptor.getMatrixC()[iterM];
- Value matrixD = op.getMatrixD()[iterM];
- Type structType = getTypeConverter()->convertType(matrixD.getType());
- LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
- << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
- << ":" << wgmmaShapeN << "] += \n");
- for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
- Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
- Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
- << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
- << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
- << (iterK * wgmmaShapeK) << ":"
- << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
- << " B[" << (iterK * wgmmaShapeK) << ":"
- << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
- << ":" << wgmmaShapeN << "])\n");
- matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
- structType, matrixC, descA, descB);
+ /// Generates multiple wgmma instructions to complete the given GEMM shape
+ SmallVector<Value> generateWgmmaGroup() {
+ SmallVector<Value> wgmmaResults;
+
+ // Perform GEMM
+ for (int i = 0; i < iterationM; ++i) {
+ Value matrixC = adaptor.getMatrixC()[i];
+ Value matrixD = op.getMatrixD()[i];
+ for (int j = 0; j < iterationN; ++j)
+ for (int k = 0; k < iterationK; ++k)
+ matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+ wgmmaResults.push_back(matrixC);
}
- wgmmaResults.push_back(matrixC);
+
+ return wgmmaResults;
+ }
+
+ public:
+ WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
+ OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+ : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+ // Find the entire GEMM Shape
+ totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+ totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+ totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+ LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+ << "] += A[" << totalM << "][" << totalK << "] * B["
+ << totalK << "][" << totalN << "] ---===\n");
+
+ // Find the shape for one wgmma instruction
+ findWgmmaShape(
+ totalM, totalN,
+ op.getDescriptorA().getType().getTensor().getElementType());
+
+ // Iteration counts needed to cover the given shape with the wgmma shape
+ iterationM = totalM / wgmmaM;
+ iterationN = totalN / wgmmaN;
+ iterationK = totalK / wgmmaK;
}
- b.create<NVVM::WgmmaGroupSyncAlignedOp>();
- b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
- ValueRange myres(wgmmaResults);
- rewriter.replaceOp(op, myres);
+ /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+ /// emits a fence (WgmmaFenceAlignedOp) before the instructions, then
+ /// commits the group (WgmmaGroupSyncAlignedOp) and waits for its
+ /// completion (WgmmaWaitGroupSyncOp) after them.
+ SmallVector<Value> generateWarpgroupMma() {
+ b.create<NVVM::WgmmaFenceAlignedOp>();
+ SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+ b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+ b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+ return wgmmaResults;
+ }
+ };
+
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ // Step 1. Build a helper class
+ WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+
+ // Step 2. Generate the wgmma Ops covering the entire GEMM shape
+ SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+ // Step 3. Replace fragmented result struct with the op results
+ rewriter.replaceOp(op, wgmmaResults);
return success();
}
};
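As a worked example of the descriptor arithmetic above, the following standalone sketch (not part of the patch) reproduces the byte offsets computed by `iterateDescriptorA` and `iterateDescriptorB`, assuming f16 operands and the 128x128x128 GEMM used earlier; `exclude4LSB` drops the 4 least-significant bits when building the wgmma matrix descriptor:

```c++
// Sketch of the offset arithmetic in iterateDescriptorA / iterateDescriptorB;
// assumes f16 operands (2 bytes) and D[128][128] += A[128][128] * B[128][128].
#include <cassert>

int main() {
  const int exclude4LSB = 4;   // low bits excluded from the descriptor
  const int byteF16 = 2;       // bytes per f16 element
  const int wgmmaK = 16;       // instruction K for f16
  const long totalK = 128;     // K of the whole GEMM
  const long tileShapeA = 128; // descriptor-A dim 1
  const long dimB0 = 128;      // descriptor-B dim 0

  // A: ((wgmmaK * k) + (totalK * tileShapeA * i)) * bytes, then >> 4.
  auto offsetA = [&](int i, int k) {
    return (((wgmmaK * k) + (totalK * tileShapeA * i)) * byteF16) >> exclude4LSB;
  };
  // B: dimB0 * wgmmaK * k * bytes, then >> 4.
  auto offsetB = [&](int k) {
    return (dimB0 * wgmmaK * k * byteF16) >> exclude4LSB;
  };

  assert(offsetA(/*i=*/0, /*k=*/1) == 2);    // one K step along A
  assert(offsetA(/*i=*/1, /*k=*/0) == 2048); // one M step along A
  assert(offsetB(/*k=*/1) == 256);           // one K step along B
  return 0;
}
```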