[Mlir-commits] [mlir] [MLIR][NVGPU] Change name `wgmma.descriptor` to `warpgroup.descriptor` (NFC) (PR #67526)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 27 01:06:38 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
The NVGPU dialect is gaining extensive support for warpgroup-level operations, and their names always start with `warpgroup.`.
For the sake of consistency, this PR renames the type `wgmma.descriptor` to `warpgroup.descriptor`, and the op `wgmma.generate.descriptor` to `warpgroup.generate.descriptor`.
---
Full diff: https://github.com/llvm/llvm-project/pull/67526.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+9-8)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+7-7)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+2-2)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+11-11)
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+8-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 31b137160545772..3e657da52be5f72 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -174,7 +174,7 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
let assemblyFormat = "`<` struct(params) `>`";
}
-def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> {
let summary = "Warpgroup matrix descriptor type";
let description = [{
The descriptor specifies the properties of the matrix in shared memory that
@@ -667,11 +667,12 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let hasVerifier = 1;
}
-def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
- let summary = "Generate a wgmma matrix descriptor";
+def NVGPU_GenerateWarpgroupDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
+ let summary = "Generate a warpgroup matrix descriptor";
let description = [{
- This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level
- matrix multiply and accumulate.
+ This Op builds a `nvgpu.warpgroup.descriptor` that is used by
+ `nvgpu.warpgroup.mma` to perform warpgroup-level matrix multiply and
+ accumulate.
The descriptor specifies the properties of the matrix in shared memory that
is a multiplicand in the matrix multiply and accumulate operation.
@@ -702,9 +703,9 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
Example:
```mlir
- %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2:
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
->
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..d4bca1d8c846576 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -967,13 +967,13 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};
-struct NVGPUGenerateGmmaDescriptorLowering
- : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+struct NVGPUGenerateWarpgroupDescriptorLowering
+ : public ConvertOpToLLVMPattern<nvgpu::GenerateWarpgroupDescriptorOp> {
using ConvertOpToLLVMPattern<
- nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+ nvgpu::GenerateWarpgroupDescriptorOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+ matchAndRewrite(nvgpu::GenerateWarpgroupDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
@@ -1037,7 +1037,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: "
+ LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
<< "leading_off:" << leadDimVal << "\t"
<< "stride_off :" << strideDimVal << "\t"
<< "base_offset:" << offsetVal << "\t"
@@ -1320,8 +1320,8 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
- NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
- NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
+ NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
+ NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index dfec17986800417..eb8fc4b65bc89ad 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -367,10 +367,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
}
//===----------------------------------------------------------------------===//
-// NVGPU_GenerateGmmaDescriptorOp
+// NVGPU_GenerateWarpgroupDescriptorOp
//===----------------------------------------------------------------------===//
-LogicalResult GenerateGmmaDescriptorOp::verify() {
+LogicalResult GenerateWarpgroupDescriptorOp::verify() {
MemRefType memrefType = getTensor().getType();
MemRefType tensorMapType = getTensorMap().getType().getTensor();
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8c2f8dbbd5ad9a3..3710b06288e2a7f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -674,7 +674,7 @@ module @mymodule {
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
memref.global "private" @dynamicShmem : memref<0xf16,3>
// CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>{
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
// CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
@@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
// CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
// CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
- // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+ // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
// CHECK: return %[[ret]]
- %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
- func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+ %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
+ func.return %descA : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
}
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
func.func @warpgroup_mma_128_128_64(
- %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
%acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
{
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.fence.aligned
@@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
->
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index ff391e469815d74..66652070ec15f34 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -225,8 +225,8 @@ func.func @async_cp_size_invalid_f64(
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x121xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
@@ -237,8 +237,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -247,8 +247,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf32, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -258,8 +258,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x512xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
``````````
</details>
https://github.com/llvm/llvm-project/pull/67526
More information about the Mlir-commits mailing list