[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)
llvmlistbot at llvm.org
Mon Sep 11 02:43:55 PDT 2023
================
@@ -1141,6 +1144,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
----------------
qcolombet wrote:
Can this be turned into an `llvm_unreachable`?
Put differently, is it possible to reach that point if the input wgmma passed the verifier?
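A minimal standalone sketch of the pattern the question points at, under stated assumptions: a hypothetical `ElemKind` enum stands in for `mlir::Type`, a hypothetical `verifyWgmmaElemKind` stands in for the op verifier, and a hypothetical `unreachable` helper stands in for `llvm_unreachable`. Once the verifier rejects every element type that has no wgmma shape, the lowering can treat the fall-through as a compiler bug instead of returning `failure()`. For the integer case the sketch follows the PTX ISA, which defines k32 wgmma for 8-bit integer operands (`.s8`/`.u8`).

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for the element types an mlir::Type can hold.
// F64 is an example of a type wgmma has no shape for.
enum class ElemKind { TF32, F16, BF16, F8E4M3FN, F8E5M2, S8, B1, F64 };

// Hypothetical stand-in for llvm_unreachable: report and abort.
[[noreturn]] static void unreachable(const char *msg) {
  std::fprintf(stderr, "UNREACHABLE: %s\n", msg);
  std::abort();
}

// Hypothetical verifier: accepts exactly the types the lowering handles.
static bool verifyWgmmaElemKind(ElemKind kind) {
  return kind != ElemKind::F64;
}

// Lowering-side helper: only called on ops that passed the verifier,
// so an unsupported type here indicates a compiler bug, not bad input.
static int getWgmmaShapeK(ElemKind kind) {
  switch (kind) {
  case ElemKind::TF32:
    return 8;
  case ElemKind::F16:
  case ElemKind::BF16:
    return 16;
  case ElemKind::F8E4M3FN:
  case ElemKind::F8E5M2:
  case ElemKind::S8:
    return 32;
  case ElemKind::B1:
    return 256;
  case ElemKind::F64:
    break;
  }
  unreachable("element type should have been rejected by the verifier");
}
```

The trade-off: `failure()` lets the pattern decline gracefully, while an unreachable documents the invariant and fails loudly in assertion-enabled builds (without assertions, `llvm_unreachable` may instead become an optimizer hint). It is only safe if the verifier's accepted set and the lowering's switch stay in sync.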
https://github.com/llvm/llvm-project/pull/65440