[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 11 02:43:56 PDT 2023
================
@@ -1141,6 +1144,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};
+struct NVGPUWarpgroupMmaOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+ using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+ int &wgmmaShapeM, int &wgmmaShapeN,
+ int &wgmmaShapeK) const { // Fills the three shape out-params; K is selected from inputElemType.
+ wgmmaShapeM = 64; // M is fixed at 64; sizeM is never read in this function.
+ wgmmaShapeN = sizeN; // N is taken directly from sizeN (int64_t converted to int).
+ if (inputElemType.isTF32()) {
+ wgmmaShapeK = 8;
+ } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+ wgmmaShapeK = 16;
+ } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+ inputElemType.isInteger(16)) { // NOTE(review): 16-bit integers share K = 32 with the FP8 types; confirm an 8-bit integer check was not intended.
+ wgmmaShapeK = 32;
+ } else if (inputElemType.isInteger(1)) {
+ wgmmaShapeK = 256;
+ } else {
+ return failure(); // Any other element type is rejected; M and N have already been written at this point.
+ }
+ LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+ << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+ << "]\n"); // Debug trace of the chosen shape, emitted only on the success path.
+ return success();
+ }
+
+ Value generateNVVMWgmmaOp(MLIRContext *ctx,
+ ConversionPatternRewriter &rewriter, Location loc,
+ int m, int n, int k, Type resultStructType,
+ Value inout, Value descriptorA,
+ Value descriptorB) const {
+ TypeRange resultTypes = {resultStructType};
+ auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+ auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+ auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+ auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+ auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+ // todo input type
----------------
qcolombet wrote:
What does this todo entail?
https://github.com/llvm/llvm-project/pull/65440
More information about the Mlir-commits
mailing list