[Mlir-commits] [mlir] cce3e8e - [MLIR][NVGPU] Introduction of wgmma.generate.descriptor Op

Guray Ozen llvmlistbot at llvm.org
Tue Aug 22 07:12:38 PDT 2023


Author: Guray Ozen
Date: 2023-08-22T16:12:25+02:00
New Revision: cce3e8ed895b2d4c1396929c363c071e15fdbf8b

URL: https://github.com/llvm/llvm-project/commit/cce3e8ed895b2d4c1396929c363c071e15fdbf8b
DIFF: https://github.com/llvm/llvm-project/commit/cce3e8ed895b2d4c1396929c363c071e15fdbf8b.diff

LOG: [MLIR][NVGPU] Introduction of wgmma.generate.descriptor Op

This work introduces a new Op, `wgmma.generate.descriptor`, which creates a wgmma descriptor for the matrix inputs of matrix multiply and accumulate operations performed with the `wgmma.mma_async` PTX instruction.
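For illustration, here is a minimal use of the new op, adapted from the lowering test added in this commit (the function name and the block-argument form of the operands are illustrative):

```mlir
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>,
    swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>

func.func @example(%lhsShmem : memref<128x64xf16, 3>,
                   %tensorMap : !tensorMap) -> i64 {
  // Build the 64-bit descriptor for the 128x64 f16 tile in shared memory.
  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap
      : memref<128x64xf16, 3>, !tensorMap
  return %descA : i64
}
```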

The specification of the descriptor format can be found at the following link:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor

Note that this op is in its initial phase and comes with limitations: it supports only 128B swizzling and does not handle interleaving. Support for the remaining swizzle and interleave modes will be added in follow-up work, expanding the capabilities of the op.
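As a sketch of the arithmetic the lowering below performs, the descriptor fields for the `memref<128x64xf16, 3>` / `swizzle_128b` case used in the test work out as follows (matching the `c64` and `c1024` constants the test checks for):

```
// layout   = 128                        (128B swizzle)
// stride   = (layout << 3) >> 4 = 64    -> bits [32,46)
// leadDim  = (dim0 * layout)  >> 4
//          = (128 * 128)      >> 4 = 1024  -> bits [16,30)
// swizzle  = 1                          -> bits [62,64)
// baseAddr : bits [4,18) of the shared memory address -> bits [0,14)
```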

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D157382

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 85c03d85a7d9aa..8d9237337401a3 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   let hasVerifier = 1;
 }
 
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+  let summary = "Generate a wgmma matrix descriptor";
+  let description = [{
+  This Op builds a wgmma descriptor that is used by the wgmma matrix multiply
+  and accumulate instruction.
+
+  The descriptor specifies the properties of the matrix in shared memory that
+  is a multiplicand in the matrix multiply and accumulate operation.
+
+  The descriptor is a 64-bit value held in a register, laid out as follows:
+  ```
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  ```
+   
+  See for more details:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+
+  }];  
+  let results = (outs I64:$descriptor);
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor, 
+                       NVGPU_TensorMapDescriptor:$tensorMap);
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 21c6780cc7887f..f826f20db69ebf 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -934,6 +934,76 @@ struct NVGPUTmaAsyncLoadOpLowering
     return success();
   }
 };
+struct NVGPUGenerateGmmaDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+
+    nvgpu::TensorMapSwizzleKind swizzleKind =
+        op.getTensorMap().getType().getSwizzle();
+
+    unsigned layout =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
+                                                                    : 1;
+    unsigned swizzle =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 1
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
+                                                                    : 0;
+
+    auto ti64 = rewriter.getIntegerType(64);
+    auto makeConst = [&](uint64_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, ti64, rewriter.getI64IntegerAttr(index));
+    };
+    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto shiftRight = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto insertBit = [&](Value desc, Value val, int startBit) {
+      return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
+                                         shiftLeft(val, startBit));
+    };
+
+    int ex4LSB = 4;
+    Value strideDim = makeConst((layout << 3) >> ex4LSB);
+    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
+    Value leadDim = makeConst((sizeN * layout) >> ex4LSB);
+    Value baseAddr = getStridedElementPtr(
+        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+        adaptor.getTensor(), {}, rewriter);
+    Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+    // Keep only 14 bits of the base address: bits [4, 18), i.e. addr >> 4.
+    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
+
+    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
+        startLeadBit = 16, startBaseAddrBit = 0;
+    Value dsc = makeConst(0);
+    // [62,64)  swizzle type
+    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
+    // [49,52)  base_offset
+    dsc = insertBit(dsc, makeConst(0), startOffsetBit);
+    // [32,46)  stride
+    dsc = insertBit(dsc, strideDim, startStrideBit);
+    // [16,30)  leading dimension
+    dsc = insertBit(dsc, leadDim, startLeadBit);
+    // [0,14)   start_address
+    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
+
+    rewriter.replaceOp(op, dsc);
+    return success();
+  }
+};
 
 static Value makeI64Const(RewriterBase &rewriter, Operation *op,
                           int32_t index) {
@@ -1064,6 +1134,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
+      NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 24c490568a4383..d832a983a132d6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@ LogicalResult TmaCreateDescriptorOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+  MemRefType memrefType = getTensor().getType();
+  MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+  if (memrefType != tensorMapType)
+    return emitError() << "memref and tensor map type mismatch";
+
+  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+    return emitError() << "supports only static shapes";
+
+  if (memrefType.getRank() != 2)
+    return emitError() << "supports only 2d memref is supported for now";
+
+  if (getTensorMap().getType().getSwizzle() !=
+      TensorMapSwizzleKind::SWIZZLE_128B) {
+    return emitError() << "supports only "
+                       << stringifyTensorMapSwizzleKind(
+                              TensorMapSwizzleKind::SWIZZLE_128B)
+                       << " is supported for the time being";
+  }
+
+  if (getTensorMap().getType().getInterleave() !=
+      TensorMapInterleaveKind::INTERLEAVE_NONE) {
+    return emitError() << "supports only "
+                       << stringifyTensorMapInterleaveKind(
+                              TensorMapInterleaveKind::INTERLEAVE_NONE)
+                       << " is supported for the time being";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 9721b371aa90bf..5604f5027fd3cd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -631,6 +631,47 @@ module @mymodule {
   }
 }
 
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+// CHECK-LABEL: func @create_wgmma_descriptor(
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+    // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+    // CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+    // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64    
+    // CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
+    // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
+    // CHECK: %[[S4:.+]] = llvm.mlir.constant(46 : i64) : i64
+    // CHECK: %[[S5:.+]] = llvm.shl %[[S3]], %[[S4]]  : i64
+    // CHECK: %[[S6:.+]] = llvm.mlir.constant(50 : i64) : i64
+    // CHECK: %[[S7:.+]] = llvm.lshr %[[S5]], %[[S6]]  : i64
+    // CHECK: %[[S8:.+]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: %[[S9:.+]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK: %[[S10:.+]] = llvm.mlir.constant(62 : i64) : i64
+    // CHECK: %[[S11:.+]] = llvm.shl %[[S9]], %[[S10]]  : i64
+    // CHECK: %[[S12:.+]] = llvm.or %[[S8]], %[[S11]]  : i64
+    // CHECK: %[[S13:.+]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: %[[S14:.+]] = llvm.mlir.constant(49 : i64) : i64
+    // CHECK: %[[S15:.+]] = llvm.shl %[[S13]], %[[S14]]  : i64
+    // CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]]  : i64
+    // CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
+    // CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]]  : i64
+    // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]]  : i64    
+    // CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]]  : i64
+    // CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]]  : i64
+    // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]]  : i64
+    // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]]  : i64
+    // CHECK: return %[[S27]] : i64
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
+  func.return %descA : i64
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 


        

