[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 02:43:56 PDT 2023


================
@@ -1141,6 +1144,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    TypeRange resultTypes = {resultStructType};
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // todo input type
----------------
qcolombet wrote:

What does this todo entail?

https://github.com/llvm/llvm-project/pull/65440


More information about the Mlir-commits mailing list