[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 07:27:23 PDT 2023


================
@@ -1141,11 +1150,246 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  /// Computes the wgmma.mma_async instruction shape (M, N, K) used to lower a
+  /// warpgroup matrix multiply whose operands have element type
+  /// `inputElemType`.
+  ///
+  /// M is always 64 and N is taken directly from `sizeN`. K depends on the
+  /// input element type, following the PTX `wgmma.mma_async` shape table:
+  ///   tf32 -> 8, f16/bf16 -> 16, e4m3/e5m2/s8/u8 -> 32, b1 -> 256.
+  /// `sizeM` is currently unused because the instruction M is fixed at 64.
+  ///
+  /// Returns failure() for unsupported element types. In that case
+  /// wgmmaShapeM and wgmmaShapeN are still written but wgmmaShapeK is not.
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(8)) {
+      // wgmma has no 16-bit integer form; the K=32 integer variants are s8/u8.
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  /// Builds an `NVVM::WgmmaMmaAsyncOp` of shape (m, n, k).
+  ///
+  /// The op reads its A and B operands through `descriptorA` and
+  /// `descriptorB`, takes `inout` as its accumulator operand, and produces a
+  /// single result of type `resultStructType`, which is returned.
+  ///
+  /// A uses a row-major layout and B a column-major layout. The scale-out and
+  /// both scale-in attributes are `one`, and integer overflow is `wrapped`.
+  /// NOTE(review): the WGMMA input type is hard-coded to f16 for both A and B,
+  /// even though getWgmmaShape also accepts tf32/bf16/fp8/integer inputs.
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // TODO: derive the WGMMA input type from the operand element type instead
+    // of hard-coding f16.
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    // Pass the single result type directly. A `TypeRange` built from a braced
+    // list is a non-owning view of a temporary array that is destroyed at the
+    // end of its declaration statement, so reusing it afterwards dangles.
+    return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
+        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+  }
+
+  static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+                                    int sizeN) {
+    int outputElements = 0;
+    if (outElemType.isF32() || outElemType.isInteger(32))
+      outputElements = sizeN / 2;
----------------
qcolombet wrote:

Please assert that `sizeN` is a multiple of 2; otherwise `sizeN / 2` silently truncates.

https://github.com/llvm/llvm-project/pull/65441


More information about the Mlir-commits mailing list