[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 15 23:18:10 PST 2023


================
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
+  Arguments<(ins  LLVM_PointerShared:$dstMem,
+                  LLVM_AnyPointer:$tmaDescriptor,
+                  LLVM_PointerShared:$mbar,
+                  I16:$multicastMask,
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    $multicastMask `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int dim = getCoordinates().size();
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+      if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+      if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+      if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+      if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+      if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
----------------
durga4github wrote:

+1 for extending the existing Op.

I am not sure we can make the mask an attribute, since it may not always be a compile-time constant.

Instead, we can keep the same Op and make the mask an optional operand. That way, we generate the multicast variant when the mask is present and the existing non-multicast variant otherwise.
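To make the optional-operand suggestion concrete, here is a minimal standalone C++ sketch of how the PTX string could branch on whether a mask is present. It is not the upstream MLIR code. `buildTmaLoadPtx` and its `hasMulticastMask` flag are hypothetical. The sketch also assumes the optional mask is declared between `$mbar` and `$coordinates`, so omitting it shifts the coordinate operands down by one. A loop replaces the per-dimension `if` chain. For `dim == 2` with a mask, the sketch produces the same string as the patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not the upstream MLIR implementation): builds the PTX
// for a TMA global -> shared::cluster tensor load, emitting the multicast
// variant only when a multicast mask operand is present.
//
// Assumed operand numbering:
//   %0 = dstMem, %1 = tmaDescriptor, %2 = mbar,
//   %3 = multicastMask (only when present),
//   followed by one %N per coordinate.
std::string buildTmaLoadPtx(int dim, bool hasMulticastMask) {
  // The patch handles 1 to 5 tensor dimensions.
  assert(dim >= 1 && dim <= 5 && "tensor dimension must be in [1, 5]");

  std::string ptx = "cp.async.bulk.tensor." + std::to_string(dim) + "d.";
  ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
  if (hasMulticastMask)
    ptx += ".multicast::cluster";

  // Coordinates follow the fixed operands; a present mask shifts them by one.
  int firstCoord = hasMulticastMask ? 4 : 3;
  ptx += " [%0], [%1, {";
  for (int i = 0; i < dim; ++i) {
    if (i > 0)
      ptx += ", ";
    ptx += "%" + std::to_string(firstCoord + i);
  }
  ptx += "} ], [%2]";

  // The 16-bit CTA mask is the trailing operand of the multicast form.
  if (hasMulticastMask)
    ptx += ", %3";
  ptx += ";";
  return ptx;
}
```

In the op itself, `hasMulticastMask` would presumably be derived from whether the optional `$multicastMask` operand is set, so a single `getPtx()` could serve both variants.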

https://github.com/llvm/llvm-project/pull/72429

