[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)

Guray Ozen llvmlistbot at llvm.org
Fri Sep 22 07:55:13 PDT 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/65441

>From 7b71da55fca8fe2a7dbe4982b1959be6a6175fa1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 7 Sep 2023 11:52:38 +0200
Subject: [PATCH 1/2] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op
 for Hopper GPUs

This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to the given memref.

An example of fragmentation is given here :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates the indices for each thread in the warp group and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented results to the given memref
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>
```

Depends on #65440
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 19 +++++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 83 ++++++++++++++++++-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 29 +++++++
 3 files changed, 129 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..e102ae0dc581013 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
+    in $matrixD to give memref. 
+
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+
+    Note that, the op must be run with warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+  
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..4f1a0bc651e81b7 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,11 +1301,88 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..fc85df1654198d5 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,34 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype =
+      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupResultType>()
+                           .getTensor()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << "  but must keep type of llvm struct";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // todo improve this limitation
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supporst only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal  ";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//

>From 4a1824f3e6ae955b78f7262178fa1b8e4608e3da Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Sep 2023 16:53:21 +0200
Subject: [PATCH 2/2] use new type `WarpgroupAccumulator`

---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td     |  5 +++--
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp |  2 +-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp      | 11 +++++++----
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index e102ae0dc581013..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -726,12 +726,13 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
     in $matrixD to give memref. 
 
-    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
 
     Note that, the op must be run with warp group.
   }];
 
-  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
                        Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
   
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4f1a0bc651e81b7..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1382,7 +1382,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
-      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
@@ -1394,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index fc85df1654198d5..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -531,13 +531,16 @@ LogicalResult WarpgroupMmaOp::verify() {
 }
 
 LogicalResult WarpgroupMmaStoreOp::verify() {
-  Type stype =
-      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+  Type stype = getMatrixD()
+                   .front()
+                   .getType()
+                   .cast<WarpgroupAccumulatorType>()
+                   .getFragmented();
 
   for (auto result : getMatrixD()) {
     auto resultStype = result.getType()
-                           .cast<WarpgroupResultType>()
-                           .getTensor()
+                           .cast<WarpgroupAccumulatorType>()
+                           .getFragmented()
                            .dyn_cast<LLVM::LLVMStructType>();
     if (!resultStype)
       return emitOpError() << "result is " << result.getType()



More information about the Mlir-commits mailing list