[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)

Guray Ozen llvmlistbot at llvm.org
Thu Nov 16 02:18:14 PST 2023


================
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
+  Arguments<(ins  LLVM_PointerShared:$dstMem,
+                  LLVM_AnyPointer:$tmaDescriptor,
+                  LLVM_PointerShared:$mbar,
+                  I16:$multicastMask,
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    $multicastMask `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int dim = getCoordinates().size();
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+      if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+      if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+      if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+      if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+      if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
----------------
grypp wrote:

The op will grow a lot once we include its other modifiers, such as the l2 cache hint and im2col. I still think it looks good, since it stays consistent with the PTX. Do you have any concerns?
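
For comparison, the separate op added in this hunk would be written roughly as below (a sketch derived from its `assemblyFormat`; the SSA value names are illustrative):
```
nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier,
%multicastMask, box [%crd0, %crd1, %crd2, %crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
```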

For example, the current op is below:
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
box [%crd0, %crd1, %crd2, %crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
```

with `multicast_mask` 
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, box [%crd0, %crd1, %crd2, %crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
```
with `multicast_mask` + `l2_cache_hint`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0, %crd1, %crd2, %crd3]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32
```
with `multicast_mask` + `l2_cache_hint` + `im2col` 
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0, %crd1, %crd2, %crd3] im2col [%off1, %off2]
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16
```

Same as above with `predicate`
```
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
box [%crd0, %crd1, %crd2, %crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, box [%crd0, %crd1, %crd2, %crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0, %crd1, %crd2, %crd3],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i1

nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
multicast_mask = %multicastMask, l2_cache_hint = %cache,
box [%crd0, %crd1, %crd2, %crd3] im2col [%off1, %off2],
predicate = %p
: !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i64, i32, i32, i32, i32, i16, i16, i1
```
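
For reference, the `getPtx()` hook in the hunk derives the instruction name from the coordinate rank. For a 4-d copy it would return roughly the string below (reconstructed from the snippet above; following the operand order, %0 is $dstMem, %1 is $tmaDescriptor, %2 is $mbar, %3 is $multicastMask, and %4-%7 are the coordinates):
```
cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;
```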

https://github.com/llvm/llvm-project/pull/72429

