[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 02:43:55 PDT 2023


================
@@ -1141,6 +1144,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
----------------
qcolombet wrote:

Can this be turned into an `llvm_unreachable`?

Put differently, is it possible to reach that point if the input wgmma op passed the verifier?

https://github.com/llvm/llvm-project/pull/65440


More information about the Mlir-commits mailing list