[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 22 07:56:19 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-nvgpu

<details>
<summary>Changes</summary>

[MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs

This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to a given memref.

An example of the fragmented layout is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates the indices for each thread in the warp group and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs the matmul; the results are fragmented and held in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented results to the given memref
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>
```

Depends on #<!-- -->65440



---
Full diff: https://github.com/llvm/llvm-project/pull/65441.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+20) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+81-2) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+32) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let summary = "Store the fragmented results of a warpgroup MMA to a memref";
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op stores the fragmented results held in
+    `$matrixD` into the given memref `$dstMemref`. When more than one
+    accumulator is given, each one is stored below the previous one, i.e. at
+    an increasing offset along the first dimension of `$dstMemref`.
+
+    See the [register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+    in the PTX ISA for details.
+
+    All accumulators must have the same type, and only `f32` results are
+    currently supported.
+
+    Note that the op must be executed by a whole warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>(); // index_cast in mma.store lowering
   }
 
   void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,6 +1301,82 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  /// Number of destination rows written per accumulator. The row index of a
+  /// thread is `laneId / 4 + warpId * 16 + i * 8` with `i` in {0, 1}, so the
+  /// four warps of a warp group cover rows [0, 64) of each fragment,
+  /// independently of the number of columns.
+  static constexpr int kRowsPerFragment = 64;
+
+  /// Number of struct elements a thread holds per 8-column slab: columns
+  /// (c, c + 1) on row r and the same two columns on row r + 8.
+  static constexpr int kElementsPerColumnSlab = 4;
+
+  /// Stores one fragmented accumulator `wgmmaResult` (already converted to an
+  /// LLVM struct) into rows [rowOffset, rowOffset + 64) of `dstMemref`,
+  /// following the wgmma register fragment layout of matrix D. The number of
+  /// 8-column slabs is derived from the struct size instead of being fixed,
+  /// so accumulator widths other than N = 128 are handled.
+  void storeFragmentedMatrix(Value wgmmaResult,
+                             TypedValue<::mlir::MemRefType> dstMemref,
+                             ConversionPatternRewriter &rewriter, Location loc,
+                             int rowOffset) const {
+    Type i32 = rewriter.getI32Type();
+    Type indexType = rewriter.getIndexType();
+    auto structType = cast<LLVM::LLVMStructType>(wgmmaResult.getType());
+    int numColumnSlabs = static_cast<int>(structType.getBody().size()) /
+                         kElementsPerColumnSlab;
+
+    auto makeConst = [&](int32_t value) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(value));
+    };
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value c1 = makeConst(1);
+    Value c2 = makeConst(2);
+    Value c4 = makeConst(4);
+    Value c8 = makeConst(8);
+    Value c16 = makeConst(16);
+    Value warpSize = makeConst(kWarpSize);
+
+    // Position of this thread within its warp and within the warp group.
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, warpSize);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, warpSize);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    // First (row, column) owned by this thread in the fragment.
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (rowOffset)
+      ti = makeAdd(ti, makeConst(rowOffset));
+
+    // Stores struct elements `pos` and `pos + 1` to (row, col) and
+    // (row, col + 1) of the destination.
+    auto extractAndStorePair = [&](int pos, Value row, Value col) {
+      Value rowIdx = rewriter.create<arith::IndexCastOp>(loc, indexType, row);
+      Value colIdx0 = rewriter.create<arith::IndexCastOp>(loc, indexType, col);
+      Value colIdx1 = rewriter.create<arith::IndexCastOp>(loc, indexType,
+                                                          makeAdd(col, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, pos);
+      Value d1 =
+          rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, pos + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, dstMemref,
+                                       ValueRange{rowIdx, colIdx0});
+      rewriter.create<memref::StoreOp>(loc, d1, dstMemref,
+                                       ValueRange{rowIdx, colIdx1});
+    };
+
+    // `i` selects row r or row r + 8; `j` walks the 8-column slabs.
+    for (int i = 0; i < 2; ++i) {
+      Value row = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < numColumnSlabs; ++j) {
+        Value col = makeAdd(tj, makeMul(makeConst(j), c8));
+        extractAndStorePair(i * 2 + j * kElementsPerColumnSlab, row, col);
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check every converted accumulator before creating any op, so that a
+    // match failure leaves the IR unchanged.
+    for (Value result : adaptor.getMatrixD()) {
+      auto structType = dyn_cast<LLVM::LLVMStructType>(result.getType());
+      if (!structType)
+        return rewriter.notifyMatchFailure(
+            op, "expected accumulators to be converted to LLVM structs");
+      if (structType.getBody().size() % kElementsPerColumnSlab != 0)
+        return rewriter.notifyMatchFailure(
+            op, "expected accumulator struct size to be a multiple of 4");
+    }
+
+    // Accumulators are stacked along the rows of the destination; each one
+    // covers kRowsPerFragment rows (the element count only equals that for
+    // N = 128, so it must not be used as the row offset).
+    int rowOffset = 0;
+    for (Value result : adaptor.getMatrixD()) {
+      storeFragmentedMatrix(result, op.getDstMemref(), rewriter, op->getLoc(),
+                            rowOffset);
+      rowOffset += kRowsPerFragment;
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1315,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store`
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,37 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  // The lowering derives its store pattern from the fragmented LLVM struct of
+  // each accumulator and emits 2-D memref.store ops, so validate both here.
+  if (getMatrixD().empty())
+    return emitOpError() << "expects at least one accumulator";
+
+  Type stype =
+      cast<WarpgroupAccumulatorType>(getMatrixD().front().getType())
+          .getFragmented();
+
+  for (Value result : getMatrixD()) {
+    auto resultStype = dyn_cast<LLVM::LLVMStructType>(
+        cast<WarpgroupAccumulatorType>(result.getType()).getFragmented());
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but must keep type of llvm struct";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+    // Guards the front() access below.
+    if (resultStype.getBody().empty())
+      return emitOpError() << "result struct must have at least one element";
+    // TODO: Support element types other than f32.
+    if (!resultStype.getBody().front().isF32())
+      return emitOpError() << "supports only f32 results for the time being";
+  }
+
+  if (!llvm::all_equal(cast<LLVM::LLVMStructType>(stype).getBody()))
+    return emitOpError() << "all element types must be equal";
+
+  // The lowering stores the f32 struct elements with two indices.
+  MemRefType dstType = getDstMemref().getType();
+  if (dstType.getRank() != 2)
+    return emitOpError() << "destination memref must be 2-D, but got "
+                         << dstType;
+  if (!dstType.getElementType().isF32())
+    return emitOpError() << "destination memref element type must be f32, "
+                            "but got "
+                         << dstType.getElementType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/65441


More information about the Mlir-commits mailing list