[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 07:27:23 PDT 2023

@@ -1141,11 +1150,246 @@ struct NVGPUTmaCreateDescriptorOpLowering
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+  /// Selects the shape [m, n, k] of the wgmma.mma.async instruction to emit.
+  ///
+  /// `wgmmaShapeM` is always 64 and `wgmmaShapeN` is taken from `sizeN`;
+  /// `wgmmaShapeK` depends only on `inputElemType`. `sizeM` is accepted but
+  /// not consulted here.
+  ///
+  /// \returns failure() when `inputElemType` is not one of the handled element
+  /// types. In that case `wgmmaShapeM` and `wgmmaShapeN` have already been
+  /// written but `wgmmaShapeK` is left untouched.
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(8)) {
+      // k = 32 is the shape used for 8-bit operands (the two FP8 formats and
+      // 8-bit integers); 16-bit integers are not a wgmma operand type.
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+  /// Creates one `NVVM::WgmmaMmaAsyncOp` at `loc` and returns its result.
+  ///
+  /// The op is built with shape (`m`, `n`, `k`), `resultStructType` as its
+  /// single result type, and `inout`, `descriptorA`, `descriptorB` as
+  /// operands. The remaining attributes are fixed here: scale-out `one`,
+  /// scale-in `one` (passed twice), layout `row` followed by `col`, overflow
+  /// `wrapped`, and an `f16` type attribute (passed twice).
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // todo input type
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    // TypeRange does not own its elements. Build it inside the call
+    // expression so the braced list's backing storage is still alive while
+    // the builder reads it; a named TypeRange initialized from a braced list
+    // would refer to storage that is gone once its declaration ends.
+    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, TypeRange{resultStructType}, inout, descriptorA, descriptorB,
+        shape, itype, itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+        overflow);
+    return res;
+  }
+  /// Returns a literal LLVM struct type whose fields are all `outElemType`.
+  ///
+  /// The field count is `sizeN / 4` for f16 and `sizeN / 2` for f32 or 32-bit
+  /// integers. Any other element type produces a struct with no fields.
+  static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+                                    int sizeN) {
+    int numFields = 0;
+    if (outElemType.isF16())
+      numFields = sizeN / 4;
+    else if (outElemType.isF32() || outElemType.isInteger(32))
+      numFields = sizeN / 2;
+    // A non-positive count yields an empty body.
+    SmallVector<Type> fields(numFields > 0 ? numFields : 0, outElemType);
+    return LLVM::LLVMStructType::getLiteral(ctx, fields);
+  }
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> wgmmaResults;
+    int64_t sizeM = op.getMatrixC().getType().getDimSize(0);
+    int64_t sizeN = op.getMatrixC().getType().getDimSize(1);
+    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+                      << sizeN << "] ---===\n");
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
+                             wgmmaShapeN, wgmmaShapeK))) {
+      return failure();
+    }
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+    //  Generate wgmma group
+    auto loc = op->getLoc();
+    Type outElemType = op.getMatrixC().getType().getElementType();
+    Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN);
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto iterateDescA = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int tileShapeA = typeTensorA.getDimSize(1);
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
+                        << iterK << "] [wgmma descriptors] Descriptor A + "
+                        << incrementVal << " | \t ");
+      return incrementVal
+                 ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
+                 : desc;
+    };
+    auto iterateDescB = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle row major
qcolombet wrote:

Isn't the code for row major already written in `iterateDescA`?
I.e., just a small refactoring and that should be supported too.
(Note: we can do that later, that's fine.)


More information about the Mlir-commits mailing list