[Mlir-commits] [mlir] bf62748 - [mlir][nvvm] Introduce Synchronization Ops for WGMMA

Guray Ozen llvmlistbot at llvm.org
Wed Jul 19 02:45:09 PDT 2023


Author: Guray Ozen
Date: 2023-07-19T11:45:04+02:00
New Revision: bf62748342438d7136ca78ef3875b31442b1ccd3

URL: https://github.com/llvm/llvm-project/commit/bf62748342438d7136ca78ef3875b31442b1ccd3
DIFF: https://github.com/llvm/llvm-project/commit/bf62748342438d7136ca78ef3875b31442b1ccd3.diff

LOG: [mlir][nvvm] Introduce Synchronization Ops for WGMMA

This work introduces the `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned` ops. They are used to synchronize warpgroup-level matrix multiply-accumulate instructions, also known as WGMMA.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D155676

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index efdd3d691f9e30..ef17a6c2ac4ff5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1419,4 +1419,51 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM Wgmma Ops
+//===----------------------------------------------------------------------===//
+
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+  let description = [{
+    Enforces an ordering of register accesses between warpgroup level
+    matrix multiply-accumulate (wgmma) operations and other operations.
+    For more information, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence
+  }];
+  let arguments = (ins);
+  let assemblyFormat = "attr-dict";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
+  }];
+}
+
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+  let arguments = (ins);
+  let assemblyFormat = "attr-dict";
+  let description = [{
+    Commits all prior uncommitted warpgroup level matrix multiply-accumulate
+    (wgmma) operations into a new wgmma-group.
+    See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
+  }];
+}
+
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+  let arguments = (ins ConfinedAttr<I32Attr, [IntNonNegative]>:$group);
+  let assemblyFormat = "attr-dict $group";
+  let description = [{
+    Waits until at most `$group` of the most recent wgmma-groups are pending
+    and all prior wgmma-groups committed by the executing threads are complete.
+    See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
+  }];
+}
+
 #endif // NVVMIR_OPS

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 4201c7b81e0ef3..5d3218ef1c7f5d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -80,3 +80,23 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
   return
 }
+
+
+// CHECK-LABEL: @wgmma_execute
+func.func @wgmma_execute() {
+  nvvm.wgmma.fence.aligned
+  nvvm.wgmma.commit.group.sync.aligned
+  nvvm.wgmma.wait.group.sync.aligned 0
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
+  // CHECK: %[[G0:.+]] = llvm.mlir.constant(0 : i32)
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned {{[$%]}}0;", "n" %[[G0]] : (i32)
+  nvvm.wgmma.fence.aligned
+  nvvm.wgmma.commit.group.sync.aligned
+  nvvm.wgmma.wait.group.sync.aligned 1
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
+  // CHECK: %[[G1:.+]] = llvm.mlir.constant(1 : i32)
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned {{[$%]}}0;", "n" %[[G1]] : (i32)
+  return
+}

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index bbc7676b45eafc..b26f3b02658ffd 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -407,3 +407,25 @@ llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i6
   %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
   llvm.return
 }
+
+// CHECK-LABEL: @wgmma_fence_aligned
+func.func @wgmma_fence_aligned() {
+  // CHECK: nvvm.wgmma.fence.aligned
+  nvvm.wgmma.fence.aligned
+  return
+}
+
+// CHECK-LABEL: @wgmma_commit_group_sync_aligned
+func.func @wgmma_commit_group_sync_aligned() {
+  // CHECK: nvvm.wgmma.commit.group.sync.aligned
+  nvvm.wgmma.commit.group.sync.aligned
+  return
+}
+
+
+// CHECK-LABEL: @wgmma_wait_group_sync_aligned
+func.func @wgmma_wait_group_sync_aligned() {
+  // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
+  nvvm.wgmma.wait.group.sync.aligned 0
+  return
+}


        


More information about the Mlir-commits mailing list