[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)
Guray Ozen
llvmlistbot at llvm.org
Thu Nov 16 02:18:14 PST 2023
================
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp :
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
+ Arguments<(ins LLVM_PointerShared:$dstMem,
+ LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerShared:$mbar,
+ I16:$multicastMask,
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $dstMem `,`
+ $tmaDescriptor `,`
+ $mbar `,`
+ $multicastMask `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int dim = getCoordinates().size();
+ std::string ptx = "cp.async.bulk.tensor.";
+ ptx += std::to_string(dim) + "d.";
+ ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+ if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+ if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+ if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+ if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+ if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
----------------
grypp wrote:
The op will grow quite a bit once we include its other traits, such as the l2 cache hint and im2col. It looks good to me as it is consistent with the PTX. Do you have any concerns?
For example, the current op is shown below:
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
box [%crd0,%crd1,%crd2,%crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
```
with `multicast_mask`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
```
with `multicast_mask` + `l2_cache_hint`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0,%crd1,%crd2,%crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32
```
with `multicast_mask` + `l2_cache_hint` + `im2col`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16
```
Same as above with `predicate`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
box [%crd0,%crd1,%crd2,%crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, box [%crd0,%crd1,%crd2,%crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0,%crd1,%crd2,%crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0,%crd1,%crd2,%crd3] im2col [%off1, %off2],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
```
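The `getPtx()` body in the diff builds the dimension-dependent PTX string with one hard-coded branch per rank. As a sanity check, here is a standalone sketch of the same logic (the helper name `buildMulticastTmaPtx` is hypothetical, not part of the PR); the register layout follows the op's operand order (`%0` dst, `%1` descriptor, `%2` mbarrier, `%3` multicast mask, `%4`.. coordinates):

```cpp
#include <cassert>
#include <string>

// Hypothetical standalone sketch of the getPtx() logic quoted above:
// assemble the PTX string for the multicast TMA copy for a given tensor rank.
std::string buildMulticastTmaPtx(int dim) {
  assert(dim >= 1 && dim <= 5 && "TMA supports 1d-5d tensors");
  std::string ptx = "cp.async.bulk.tensor." + std::to_string(dim) + "d." +
                    "shared::cluster.global.mbarrier::complete_tx::bytes."
                    "multicast::cluster [%0], [%1, {";
  // Coordinates start at register %4; emit one placeholder per dimension.
  for (int i = 0; i < dim; ++i) {
    if (i > 0)
      ptx += ", ";
    ptx += "%" + std::to_string(4 + i);
  }
  ptx += "} ], [%2], %3;";
  return ptx;
}
```

For `dim == 2` this yields the same string as the `if (dim == 2)` branch in the diff; a loop like this would also avoid growing the branch list if more ranks were ever added.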
https://github.com/llvm/llvm-project/pull/72429