[Mlir-commits] [mlir] [MLIR][NVGPU] Test warpgroup matrix multiply 128x128x64 (PR #68817)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 11 09:51:41 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
Add a test that performs a 128x128x64 warpgroup matrix multiply. The test uses three ops to do that: `nvgpu.warpgroup.mma.init.accumulator`, `nvgpu.warpgroup.mma`, and `nvgpu.warpgroup.mma.store`.
---
Full diff: https://github.com/llvm/llvm-project/pull/68817.diff
1 File Affected:
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+63)
``````````diff
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index ca030575e5e961e..389158fb79303ea 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -972,6 +972,69 @@ func.func @warpgroup_mma_init() {
return
}
+// CHECK-LABEL: @warpgroup_matrix_multiply_m128n128k64(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_matrix_multiply_m128n128k64(
+ %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
+ %matrixD: memref<128x128xf32, 3>)
+{
+ // Init
+ %m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+
+ // GEMM
+ %r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}:
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+ ->
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+
+
+ // Epilogue
+ nvgpu.warpgroup.mma.store [%r1, %r2], %matrixD :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<128x128xf32,3>
+
+
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[S5:.+]] = llvm.insertvalue %[[S4]], %[[S3]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S68:.+]] = llvm.insertvalue %[[S4]], %{{.*}}[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S69:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S70:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[S134:.+]] = llvm.insertvalue %70, %133[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[S135:.+]] = nvvm.wgmma.mma_async %0, %1, <m = 64, n = 128, k = 16>, D[%[[S68]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: %[[S150:.+]] = nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: %[[S168:.+]] = nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.commit.group.sync.aligned
+// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
+// CHECK: %[[S193:.+]] = llvm.extractvalue %[[S150]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S194:.+]] = llvm.extractvalue %[[S150]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: memref.store %[[S193]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S194]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: %[[S503:.+]] = llvm.extractvalue %[[S168]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S504:.+]] = llvm.extractvalue %[[S168]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: memref.store %[[S503]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S504]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+ return
+}
+
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
``````````
</details>
https://github.com/llvm/llvm-project/pull/68817
More information about the Mlir-commits
mailing list