[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.wargroup.mma.store` Op for Hopper GPUs (PR #65441)
llvmlistbot at llvm.org
Mon Sep 18 07:27:24 PDT 2023
================
@@ -1141,11 +1150,246 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};
+struct NVGPUWarpgroupMmaOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+ using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+ int &wgmmaShapeM, int &wgmmaShapeN,
+ int &wgmmaShapeK) const {
+ wgmmaShapeM = 64;
+ wgmmaShapeN = sizeN;
+ if (inputElemType.isTF32()) {
+ wgmmaShapeK = 8;
+ } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+ wgmmaShapeK = 16;
+ } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+ inputElemType.isInteger(8)) {
+ wgmmaShapeK = 32;
+ } else if (inputElemType.isInteger(1)) {
+ wgmmaShapeK = 256;
+ } else {
+ return failure();
+ }
+ LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+ << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+ << "]\n");
+ return success();
+ }
+
+ Value generateNVVMWgmmaOp(MLIRContext *ctx,
+ ConversionPatternRewriter &rewriter, Location loc,
+ int m, int n, int k, Type resultStructType,
+ Value inout, Value descriptorA,
+ Value descriptorB) const {
+ TypeRange resultTypes = {resultStructType};
+ auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+ auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+ auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+ auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+ auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+ // TODO: Derive the input element type from the op instead of assuming f16.
+ auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+ auto overflow =
+ NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+ Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+ loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
+ scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+ return res;
+ }
+
+ static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+ int sizeN) {
+ int outputElements = 0;
+ if (outElemType.isF32() || outElemType.isInteger(32))
+ outputElements = sizeN / 2;
+ if (outElemType.isF16())
+ outputElements = sizeN / 4;
+ SmallVector<Type> structBody;
+ for (int i = 0; i < outputElements; i++)
+ structBody.push_back(outElemType);
+ return LLVM::LLVMStructType::getLiteral(ctx, structBody);
+ }
+
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> wgmmaResults;
+
+ int64_t sizeM = op.getMatrixC().getType().getDimSize(0);
+ int64_t sizeN = op.getMatrixC().getType().getDimSize(1);
+ int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+
+ LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+ << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+ << sizeN << "] ---===\n");
+
+ int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+ if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
----------------
qcolombet wrote:
Either get the type from the op, or don't match non-f16 types.
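For illustration, a minimal sketch of the first alternative, reusing the
descriptor accessor the patch already uses for sizeK; the accessor chain,
variable names, and error message here are assumptions, not the final fix:

    // Hypothetical sketch: derive the input element type from the A
    // descriptor instead of hardcoding f16.
    Type inputElemType =
        op.getDescriptorA().getType().getTensor().getElementType();
    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
    if (failed(getWgmmaShape(sizeM, sizeN, inputElemType, wgmmaShapeM,
                             wgmmaShapeN, wgmmaShapeK)))
      return rewriter.notifyMatchFailure(op, "unsupported input element type");

The other alternative is simply to bail out of the pattern (e.g. via
notifyMatchFailure) whenever the operand element type is not f16.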
https://github.com/llvm/llvm-project/pull/65441