[Mlir-commits] [mlir] 50ab427 - [MLIR][NVGPU] Introduce Warpgroup Matrix Descriptor Type

Guray Ozen llvmlistbot at llvm.org
Tue Aug 22 08:02:41 PDT 2023


Author: Guray Ozen
Date: 2023-08-22T17:02:37+02:00
New Revision: 50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24

URL: https://github.com/llvm/llvm-project/commit/50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24
DIFF: https://github.com/llvm/llvm-project/commit/50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24.diff

LOG: [MLIR][NVGPU] Introduce Warpgroup Matrix Descriptor Type

The Warpgroup Matrix Descriptor is a 64-bit integer that holds information about a matrix used by the wgmma instruction.

This work introduces a dedicated type for the descriptor. It improves the readability of the IR and lets the descriptor be checked by MLIR's verifiers.

The type carries the 'memref' of the matrix the descriptor refers to as a parameter, so the shape and element-type information is preserved and conveyed through the IR.
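
For illustration, here is a minimal sketch of the new syntax, mirroring the updated nvgpu-to-nvvm.mlir test (the 128x64xf16 shared-memory tile and the tensor-map parameters are taken from that test; %lhsShmem and %tensorMap are assumed to be defined earlier):

```
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>,
                swizzle = swizzle_128b, l2promo = none, oob = zero,
                interleave = none>

// The op now returns the typed descriptor instead of a plain i64.
%descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap
           : memref<128x64xf16, 3>, !tensorMap
           -> !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
```

When lowered with the NVGPUToNVVM pass, the descriptor type converts back to a plain i64, so only the NVGPU-level IR carries the extra type information.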

Depends on D157382

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D158403

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 8d9237337401a3..34d4b349cca155 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -169,6 +169,30 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+  let summary = "Warpgroup matrix descriptor type";
+  let description = [{
+  The descriptor specifies the properties of the matrix in shared memory that 
+  is a multiplicand in the matrix multiply and accumulate operation. 
+  
+  The descriptor is a 64-bit value contained in a register with the following:
+  ```
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  |  14bits |2bits|   14bits  |2bits|   14bits  |2bits|3bits|   10bits  |2bits|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+  ```
+   
+  [See for more details in PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor) 
+  
+  }];  
+  let parameters = (ins "MemRefType":$tensor);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -628,32 +652,18 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
 def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let summary = "Generate a wgmma matrix descriptor";
   let description = [{
-  This Op builds a wgmma descriptor that is used by wgmma matrix multiply 
-  and accumulate.
+  This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level 
+  matrix multiply and accumulate.
 
   The descriptor specifies the properties of the matrix in shared memory that 
   is a multiplicand in the matrix multiply and accumulate operation. 
-  
-  The descriptor is a 64-bit value contained in a register with the following
-  ```
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  |  14bits |2bits|   14bits  |2bits|   14bits  |2bits|3bits|   10bits  |2bits|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  ```
-   
-  See for more details:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
-
   }];  
-  let results = (outs I64:$descriptor);
+  let results = (outs NVGPU_WarpgroupMatrixDescriptor:$descriptor);
   let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor, 
                        NVGPU_TensorMapDescriptor:$tensorMap);
-  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap) `->` type($descriptor)}];
   let hasVerifier = 1;
 }
 
+
 #endif // NVGPU

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f826f20db69ebf..c315c23aeac24f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -417,6 +417,10 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
+    converter.addConversion(
+        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+          return converter.convertType(IntegerType::get(type.getContext(), 64));
+        });
     converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
       return converter.convertType(
           nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 5604f5027fd3cd..0d7ace52ccb36c 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -634,13 +634,13 @@ module @mymodule {
 !tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
 memref.global "private" @dynamicShmem : memref<0xf16,3>
 // CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
     // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
     // CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
     // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64    
+    // CHECK: %[[c64:.+]] =  llvm.mlir.constant(64 : i64) : i64
     // CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
     // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
     // CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
@@ -659,19 +659,19 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
     // CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]]  : i64
     // CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
     // CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]]  : i64
-    // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]]  : i64    
+    // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]]  : i64
     // CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
     // CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]]  : i64
     // CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]]  : i64
     // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
     // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]]  : i64
     // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]]  : i64
-    // CHECK: return %[[S27]] : i64
-  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
-  func.return %descA : i64
+    // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> 
+    // CHECK: return %[[ret]]
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+  func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
 }
 
-
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 
@@ -682,4 +682,4 @@ transform.sequence failures(propagate) {
     transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
       {use_opaque_pointers = true}
   } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
-}
+}
\ No newline at end of file


        

