[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 15 12:17:52 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

This PR introduces the `cp.async.bulk.tensor.shared.cluster.global.multicast` Op in the NVVM dialect. It uses TMA to load data from global memory into the shared memory of multiple CTAs in the cluster.

It resolves #<!-- -->72368

---
Full diff: https://github.com/llvm/llvm-project/pull/72429.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+5) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+45) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ffe6f25fcd944b6..c4d61492083bfc9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
+  Arguments<(ins  LLVM_PointerShared:$dstMem,
+                  LLVM_AnyPointer:$tmaDescriptor,
+                  LLVM_PointerShared:$mbar,
+                  I16:$multicastMask,
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    $multicastMask `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int dim = getCoordinates().size();
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+      if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+      if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+      if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+      if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+      if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3736978505707e3..1c4e2dc98bda602 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -80,6 +80,11 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
     return emitError("Maximum 5 coordinates and dimension is supported.");
   return success();
 }
+LogicalResult CpAsyncBulkTensorGlobalToSharedMulticastClusterOp::verify() {
+  if (getCoordinates().size() > 5)
+    return emitError("Maximum 5 coordinates and dimension is supported.");
+  return success();
+}
 
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   if (getCoordinates().size() > 5)
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b907a86ebc48072..7160a612f25d14b 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
   return
 }
 
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32,i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor,  %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
+  return
+}
+
 // CHECK-LABEL: @tma_store_1d
 func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"

``````````

</details>


https://github.com/llvm/llvm-project/pull/72429


More information about the Mlir-commits mailing list