[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 18 07:27:23 PDT 2023
================
@@ -1141,11 +1150,246 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};
+struct NVGPUWarpgroupMmaOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+ using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+ int &wgmmaShapeM, int &wgmmaShapeN,
+ int &wgmmaShapeK) const { // Fills the m/n/k out-params from sizeN and the element type; returns failure() for unlisted element types.
+ wgmmaShapeM = 64; // M is always 64 here; sizeM is not read anywhere in this function.
+ wgmmaShapeN = sizeN; // N is taken verbatim from sizeN (implicit int64_t -> int conversion).
+ if (inputElemType.isTF32()) { // K is selected purely from the input element type.
+ wgmmaShapeK = 8;
+ } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+ wgmmaShapeK = 16;
+ } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+ inputElemType.isInteger(16)) { // NOTE(review): i16 is grouped with the 8-bit float types -- possibly meant isInteger(8); confirm.
+ wgmmaShapeK = 32;
+ } else if (inputElemType.isInteger(1)) {
+ wgmmaShapeK = 256;
+ } else {
+ return failure(); // Unlisted element type: M and N were already written above, K is left untouched.
+ }
+ LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+ << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+ << "]\n"); // Debug-only trace of the chosen shape.
+ return success();
+ }
+
+ Value generateNVVMWgmmaOp(MLIRContext *ctx,
+ ConversionPatternRewriter &rewriter, Location loc,
+ int m, int n, int k, Type resultStructType,
+ Value inout, Value descriptorA,
+ Value descriptorB) const { // Creates one NVVM::WgmmaMmaAsyncOp at loc and returns its result as a Value.
+ TypeRange resultTypes = {resultStructType}; // The created op has a single result of type resultStructType.
+ auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k); // Shape attribute built from the caller-supplied m, n, k.
+ auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one); // Hard-coded to `one`.
+ auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one); // Hard-coded to `one`; passed twice below.
+ auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row); // A layout is fixed to `row`.
+ auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col); // B layout is fixed to `col`.
+ // todo input type
+ auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16); // Input type is hard-coded to f16 regardless of the actual operands (see todo above).
+ auto overflow =
+ NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped); // Overflow attribute is fixed to `wrapped`.
+ Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+ loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
+ scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); // itype and scaleIn are each passed twice -- presumably once per input operand (A, B); confirm against the op's builder.
+ return res;
+ }
+
+ static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+ int sizeN) {
+ int outputElements = 0;
+ if (outElemType.isF32() || outElemType.isInteger(32))
+ outputElements = sizeN / 2;
----------------
qcolombet wrote:
Assert that sizeN is a multiple of 2.
https://github.com/llvm/llvm-project/pull/65441
More information about the Mlir-commits
mailing list