[Mlir-commits] [mlir] [MLIR][NVGPU] Test warpgroup matrix multiply 128x128x64 (PR #68817)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 11 09:51:41 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

Add a test that performs a 128x128x64 warpgroup matrix multiply. The test exercises the three ops involved: `nvgpu.warpgroup.mma.init.accumulator`, `nvgpu.warpgroup.mma`, and `nvgpu.warpgroup.mma.store`.
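
Condensed, with types abbreviated, the sequence under test looks like this (a sketch only; the full diff below is authoritative):

```mlir
// Sketch: types elided for readability; see the diff for the exact signatures.
%acc1, %acc2 = nvgpu.warpgroup.mma.init.accumulator -> ..., ...                                // Init
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB} : ... -> ..., ...     // GEMM
nvgpu.warpgroup.mma.store [%r1, %r2], %matrixD : ... to memref<128x128xf32, 3>                 // Epilogue
```

In the lowered output, each of the two 64x128 accumulator fragments is produced by four `nvvm.wgmma.mma_async` steps with k = 16 (K = 64 / 16), which is why eight `wgmma.mma_async` calls appear in the CHECK lines.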

---
Full diff: https://github.com/llvm/llvm-project/pull/68817.diff


1 File Affected:

- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+63) 


``````````diff
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index ca030575e5e961e..389158fb79303ea 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -972,6 +972,69 @@ func.func @warpgroup_mma_init() {
   return 
 }
 
+// CHECK-LABEL: @warpgroup_matrix_multiply_m128n128k64(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_matrix_multiply_m128n128k64(
+      %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
+      %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
+      %matrixD: memref<128x128xf32, 3>) 
+{
+  // Init
+  %m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->  
+                      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+
+  // GEMM
+  %r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}: 
+        !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
+        !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
+        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> 
+        -> 
+        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>  
+
+
+  // Epilogue 
+  nvgpu.warpgroup.mma.store [%r1, %r2], %matrixD : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+    to memref<128x128xf32,3>
+
+
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[S5:.+]] = llvm.insertvalue %[[S4]], %[[S3]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: %[[S68:.+]] = llvm.insertvalue %[[S4]], %{{.*}}[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: %[[S69:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S70:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[S134:.+]] = llvm.insertvalue %[[S70]], %{{.*}}[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[S135:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S68]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: %[[S150:.+]] = nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.mma_async
+// CHECK: %[[S168:.+]] = nvvm.wgmma.mma_async
+// CHECK: nvvm.wgmma.commit.group.sync.aligned
+// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
+// CHECK: %[[S193:.+]] = llvm.extractvalue %[[S150]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: %[[S194:.+]] = llvm.extractvalue %[[S150]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: memref.store %[[S193]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S194]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: %[[S503:.+]] = llvm.extractvalue %[[S168]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: %[[S504:.+]] = llvm.extractvalue %[[S168]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+// CHECK: memref.store %[[S503]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S504]], %[[arg2]][%{{.*}}, %{{.*}}] : memref<128x128xf32, 3>
+  return 
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 

``````````

</details>


https://github.com/llvm/llvm-project/pull/68817


More information about the Mlir-commits mailing list