[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)
llvmlistbot at llvm.org
Thu Nov 16 03:15:34 PST 2023
================
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp :
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
+ Arguments<(ins LLVM_PointerShared:$dstMem,
+ LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerShared:$mbar,
+ I16:$multicastMask,
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $dstMem `,`
+ $tmaDescriptor `,`
+ $mbar `,`
+ $multicastMask `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int dim = getCoordinates().size();
+ std::string ptx = "cp.async.bulk.tensor.";
+ ptx += std::to_string(dim) + "d.";
+ ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+ if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+ if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+ if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+ if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+ if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
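For reference, the `getPtx()` hook above is a plain string concatenation; for a 2-D copy (`dim == 2`) it yields the instruction below, where the `%n` placeholders map to the operands in declaration order (%0 = `$dstMem`, %1 = `$tmaDescriptor`, %2 = `$mbar`, %3 = `$multicastMask`, %4/%5 = the two coordinates):

```
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [%0], [%1, {%4, %5} ], [%2], %3;
```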
----------------
durga4github wrote:
> I have updated the PR to use the existing Op.
The updated version looks good to me.
>
> Yet, when we include other traits such as l2 cache hint and im2col, the Op will grow. Personally I find it consistent with PTX. Do you have any concerns? If not, I can put up a follow-up PR to support the remaining features.
I do not see any concerns. We can extend it the same way for cache-hint.
I believe im2col itself will be a variadic operand (since it can be of size 1, 2, or 3). So, as long as we can have an operand that is both variadic and optional, we are good with this direction.
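To make the "variadic + optional" combination concrete, here is a rough ODS sketch only (the def name, operand names, and types below are illustrative placeholders, not the PR's final spelling): the scalar hints become `Optional<...>` operands, the im2col offsets become a second `Variadic<...>` group that is simply left empty when unused, and `AttrSizedOperandSegments` keeps the group sizes unambiguous.

```
// Rough sketch, not the actual op definition: one possible argument list for
// the extended op. AttrSizedOperandSegments records the size of every operand
// group, so two variadic groups and optional scalars can coexist on one op.
def NVVM_CpAsyncBulkTensorSharedClusterGlobalSketchOp :
  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
     AttrSizedOperandSegments]>,
  Arguments<(ins LLVM_PointerShared:$dstMem,
                 LLVM_AnyPointer:$tmaDescriptor,
                 LLVM_PointerShared:$mbar,
                 Variadic<I32>:$coordinates,
                 Variadic<I16>:$im2colOffsets,   // empty unless im2col mode is requested
                 Optional<I16>:$multicastMask,
                 Optional<I64>:$l2CacheHint,
                 PtxPredicate:$predicate)>;
```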
>
> For example the current op is below:
>
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> box [%crd0,%crd1,%crd2,%crd3]
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
> ```
>
> with `multicast_mask`
>
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3]
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
> ```
>
> with `multicast_mask` + `l2_cache_hint`
>
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, l2_cache_hint = %cache,
> box [%crd0,%crd1,%crd2,%crd3]
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32
> ```
>
> with `multicast_mask` + `l2_cache_hint` + `im2col`
>
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, l2_cache_hint = %cache,
> box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2]
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16
> ```
>
> Same as above with `predicate`
>
> ```
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> box [%crd0,%crd1,%crd2,%crd3],
> predicate = %p
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
>
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3],
> predicate = %p
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
>
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, l2_cache_hint = %cache,
> box [%crd0,%crd1,%crd2,%crd3],
> predicate = %p
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1
>
> nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
> multicast_mask = %multicastMask, l2_cache_hint = %cache,
> box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2],
> predicate = %p
> : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
> ```
https://github.com/llvm/llvm-project/pull/72429