[Mlir-commits] [mlir] 50ab427 - [MLIR][NVGPU] Introduce Warpgroup Matrix Descriptor Type
Guray Ozen
llvmlistbot at llvm.org
Tue Aug 22 08:02:41 PDT 2023
Author: Guray Ozen
Date: 2023-08-22T17:02:37+02:00
New Revision: 50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24
URL: https://github.com/llvm/llvm-project/commit/50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24
DIFF: https://github.com/llvm/llvm-project/commit/50ab427a29f09ae0dcbe8a1eb4391672e0ff2c24.diff
LOG: [MLIR][NVGPU] Introduce Warpgroup Matrix Descriptor Type
The Warpgroup Matrix Descriptor is a 64-bit integer that holds information about a matrix used by the wgmma instruction.
In this work, a new type is introduced for the descriptor. This enhances the readability of the IR and makes the descriptor easier to check with MLIR's verification tools.
The type carries the 'memref' of the matrix that the descriptor describes, so this information is preserved in the IR and conveyed through the lowering pipeline.
Depends on D157382
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D158403
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 8d9237337401a3..34d4b349cca155 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -169,6 +169,30 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
let assemblyFormat = "`<` struct(params) `>`";
}
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+ let summary = "Warpgroup matrix descriptor type";
+ let description = [{
+ The descriptor specifies the properties of the matrix in shared memory that
+ is a multiplicand in the matrix multiply and accumulate operation.
+
+ The descriptor is a 64-bit value contained in a register with the following:
+ ```
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ | 0-13 |14-15| 16-29 |30-31| 32-45 |46-48|49-51| 52-61 |62-63|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ | 14bits |2bits| 14bits |2bits| 14bits |2bits|3bits| 10bits |2bits|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ | BaseAddr| 0 | LeadingDim| 0 | Stride | 0 |Offst| 0 |Swzle|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ ```
+
+ [See for more details in PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor)
+
+ }];
+ let parameters = (ins "MemRefType":$tensor);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// NVGPU Op Definitions
//===----------------------------------------------------------------------===//
@@ -628,32 +652,18 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
let summary = "Generate a wgmma matrix descriptor";
let description = [{
- This Op builds a wgmma descriptor that is used by wgmma matrix multiply
- and accumulate.
+ This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level
+ matrix multiply and accumulate.
The descriptor specifies the properties of the matrix in shared memory that
is a multiplicand in the matrix multiply and accumulate operation.
-
- The descriptor is a 64-bit value contained in a register with the following
- ```
- +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
- | 0-13 |14-15| 16-29 |30-31| 32-45 |46-48|49-51| 52-61 |62-63|
- +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
- | 14bits |2bits| 14bits |2bits| 14bits |2bits|3bits| 10bits |2bits|
- +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
- | BaseAddr| 0 | LeadingDim| 0 | Stride | 0 |Offst| 0 |Swzle|
- +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
- ```
-
- See for more details:
- https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
-
}];
- let results = (outs I64:$descriptor);
+ let results = (outs NVGPU_WarpgroupMatrixDescriptor:$descriptor);
let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
NVGPU_TensorMapDescriptor:$tensorMap);
- let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+ let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap) `->` type($descriptor)}];
let hasVerifier = 1;
}
+
#endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f826f20db69ebf..c315c23aeac24f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -417,6 +417,10 @@ struct ConvertNVGPUToNVVMPass
converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 64));
});
+ converter.addConversion(
+ [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+ return converter.convertType(IntegerType::get(type.getContext(), 64));
+ });
converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
return converter.convertType(
nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 5604f5027fd3cd..0d7ace52ccb36c 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -634,13 +634,13 @@ module @mymodule {
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
memref.global "private" @dynamicShmem : memref<0xf16,3>
// CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
// CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
// CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64
+ // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64
// CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
@@ -659,19 +659,19 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
// CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]] : i64
// CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
// CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]] : i64
- // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64
+ // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64
// CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
// CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]] : i64
// CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]] : i64
// CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
// CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
- // CHECK: return %[[S27]] : i64
- %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
- func.return %descA : i64
+ // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+ // CHECK: return %[[ret]]
+ %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+ func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
}
-
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
@@ -682,4 +682,4 @@ transform.sequence failures(propagate) {
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
{use_opaque_pointers = true}
} {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
-}
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list