[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 16 03:15:34 PST 2023


================
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
+  Arguments<(ins  LLVM_PointerShared:$dstMem,
+                  LLVM_AnyPointer:$tmaDescriptor,
+                  LLVM_PointerShared:$mbar,
+                  I16:$multicastMask,
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    $multicastMask `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int dim = getCoordinates().size();
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+      if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+      if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+      if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+      if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+      if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
----------------
durga4github wrote:

> I have updated the PR to use the existing Op.

The updated version looks good to me.

> 
> Yet, once we add other features such as the L2 cache hint and im2col, the op will grow. Personally, I find this consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.

I do not see any concerns. We can extend it the same way for the cache hint.

I believe im2col itself will be a variadic operand (since it can have 1, 2, or 3 offsets). So, as long as we can have an operand that is both variadic and optional, this direction works for us.
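To make the "variadic + optional" point concrete: in ODS a `Variadic<...>` operand group is allowed to be empty, so an empty im2col group could simply mean tile mode. What remains is a count check, which, as I read the PTX ISA description of `cp.async.bulk.tensor`, is that im2col mode needs a 3d, 4d, or 5d box and exactly (dimensions - 2) offsets. Below is a minimal, MLIR-independent sketch of the checks a `hasVerifier` implementation might perform; `checkTmaOperandCounts` is a hypothetical name, and the rules should be double-checked against the PTX spec.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical, MLIR-free version of the operand-count checks a verifier
// could run. Returns an empty string when the counts are valid, otherwise
// a diagnostic message.
std::string checkTmaOperandCounts(std::size_t numCoords,
                                  std::size_t numIm2colOffsets) {
  if (numCoords < 1 || numCoords > 5)
    return "expects 1 to 5 box coordinates";
  // An empty variadic group is treated as tile mode, so no separate
  // "optional" marker is needed for the im2col offsets.
  if (numIm2colOffsets == 0)
    return "";
  // Assumed rule (PTX ISA): im2col mode only exists for 3d-5d tensors...
  if (numCoords < 3)
    return "im2col mode requires a 3d, 4d or 5d tensor";
  // ...and takes one offset per dimension except the first two.
  if (numIm2colOffsets != numCoords - 2)
    return "im2col mode expects (number of coordinates - 2) offsets";
  return "";
}
```

With this shape, `box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2]` passes, while two im2col offsets on a 3d box would be rejected.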

> 
> For example, the current op looks like this:
> 
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> box [%crd0,%crd1,%crd2,%crd3] 
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
> ```
> 
> with `multicast_mask`
> 
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3] 
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
> ```
> 
> with `multicast_mask` + `l2_cache_hint`
> 
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> multicast_mask = %multicastMask, l2_cache_hint = %cache, 
> box [%crd0,%crd1,%crd2,%crd3] 
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32
> ```
> 
> with `multicast_mask` + `l2_cache_hint` + `im2col`
> 
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier,
> multicast_mask = %multicastMask, l2_cache_hint = %cache, 
> box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2] 
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16
> ```
> 
> The same variants with `predicate`:
> 
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> box [%crd0,%crd1,%crd2,%crd3], 
> predicate = %p 
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
> 
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3], 
> predicate = %p  
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
> 
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> multicast_mask = %multicastMask, l2_cache_hint = %cache, 
> box [%crd0,%crd1,%crd2,%crd3], 
> predicate = %p  
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1
> 
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, 
> multicast_mask = %multicastMask, l2_cache_hint = %cache, 
> box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2], 
> predicate = %p  
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
> ```
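One way to keep `getPtx()` manageable as these operands are added is to number the inline-asm placeholders in a loop instead of spelling out one string per dimension. The sketch below assumes the operand order dstMem, tmaDescriptor, mbar, multicastMask, l2CacheHint, coordinates, im2col offsets (taken from the proposed assembly syntax, not from a committed op definition). The PTX operand placement follows my reading of the `cp.async.bulk.tensor` syntax in the PTX ISA, where `[mbar]` is followed by the im2col offsets, then the CTA mask, then the cache policy. `TmaLoadConfig` and `buildTmaLoadPtx` are hypothetical names, and predicate handling is left out.

```cpp
#include <cassert>
#include <string>

// Hypothetical summary of which optional operands an op instance carries.
struct TmaLoadConfig {
  int numCoords = 1;             // box coordinates: 1d..5d
  int numIm2colOffsets = 0;      // 0 selects tile mode
  bool hasMulticastMask = false; // adds .multicast::cluster and a ctaMask operand
  bool hasCacheHint = false;     // adds .L2::cache_hint and a cache-policy operand
};

// Returns the next inline-asm placeholder ("%N") and advances the counter.
static std::string operand(int &idx) { return "%" + std::to_string(idx++); }

// Returns `count` consecutive placeholders as a PTX vector, e.g. "{%4, %5}".
static std::string operandList(int &idx, int count) {
  std::string s = "{";
  for (int i = 0; i < count; ++i) {
    if (i > 0)
      s += ", ";
    s += operand(idx);
  }
  return s + "}";
}

// Builds the PTX string for the assumed operand order (no predicate).
std::string buildTmaLoadPtx(const TmaLoadConfig &c) {
  assert(c.numCoords >= 1 && c.numCoords <= 5 && "PTX supports 1d-5d tensors");
  assert((c.numIm2colOffsets == 0 ||
          (c.numCoords >= 3 && c.numIm2colOffsets == c.numCoords - 2)) &&
         "im2col expects (dim - 2) offsets on a 3d-5d box");

  std::string ptx = "cp.async.bulk.tensor." + std::to_string(c.numCoords) +
                    "d.shared::cluster.global";
  if (c.numIm2colOffsets > 0)
    ptx += ".im2col";
  ptx += ".mbarrier::complete_tx::bytes";
  if (c.hasMulticastMask)
    ptx += ".multicast::cluster";
  if (c.hasCacheHint)
    ptx += ".L2::cache_hint";

  // Hand out placeholders in the assumed MLIR operand order.
  int idx = 0;
  const std::string dst = operand(idx);
  const std::string desc = operand(idx);
  const std::string mbar = operand(idx);
  const std::string mask = c.hasMulticastMask ? operand(idx) : "";
  const std::string hint = c.hasCacheHint ? operand(idx) : "";
  const std::string coords = operandList(idx, c.numCoords);
  const std::string offsets =
      c.numIm2colOffsets > 0 ? operandList(idx, c.numIm2colOffsets) : "";

  // Emit them in PTX order: offsets, then mask, then cache policy.
  ptx += " [" + dst + "], [" + desc + ", " + coords + "], [" + mbar + "]";
  if (!offsets.empty())
    ptx += ", " + offsets;
  if (!mask.empty())
    ptx += ", " + mask;
  if (!hint.empty())
    ptx += ", " + hint;
  return ptx + ";";
}
```

With only a multicast mask and a 2d box, this reproduces the numbering in the diff above (`%3` for the mask, `%4`/`%5` for the coordinates), apart from whitespace. Decoupling the MLIR operand order from the PTX operand order is what lets the assembly format stay readable while the emitted PTX still matches the ISA.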



https://github.com/llvm/llvm-project/pull/72429

