[Mlir-commits] [mlir] [mlir][nvvm] Add `cp.async.bulk.tensor.shared.cluster.global.multicast` (PR #72429)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 15 12:17:52 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
This PR introduces the `cp.async.bulk.tensor.shared.cluster.global.multicast` Op in the NVVM dialect. It loads data from global memory to the shared memory of multiple CTAs in the cluster using TMA.
It resolves #<!-- -->72368
---
Full diff: https://github.com/llvm/llvm-project/pull/72429.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+5)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+45)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ffe6f25fcd944b6..c4d61492083bfc9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1398,6 +1398,43 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+def NVVM_CpAsyncBulkTensorGlobalToSharedMulticastClusterOp :
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global.multicast",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
+ Arguments<(ins LLVM_PointerShared:$dstMem,
+ LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerShared:$mbar,
+ I16:$multicastMask,
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $dstMem `,`
+ $tmaDescriptor `,`
+ $mbar `,`
+ $multicastMask `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int dim = getCoordinates().size();
+ std::string ptx = "cp.async.bulk.tensor.";
+ ptx += std::to_string(dim) + "d.";
+ ptx += "shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster";
+ if(dim == 1) ptx += " [%0], [%1, {%4} ], [%2], %3;";
+ if(dim == 2) ptx += " [%0], [%1, {%4, %5} ], [%2], %3;";
+ if(dim == 3) ptx += " [%0], [%1, {%4, %5, %6} ], [%2], %3;";
+ if(dim == 4) ptx += " [%0], [%1, {%4, %5, %6, %7} ], [%2], %3;";
+ if(dim == 5) ptx += " [%0], [%1, {%4, %5, %6, %7, %8} ], [%2], %3;";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3736978505707e3..1c4e2dc98bda602 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -80,6 +80,11 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
return emitError("Maximum 5 coordinates and dimension is supported.");
return success();
}
+LogicalResult CpAsyncBulkTensorGlobalToSharedMulticastClusterOp::verify() {
+ if (getCoordinates().size() > 5)
+ return emitError("Maximum 5 coordinates and dimension is supported.");
+ return success();
+}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
if (getCoordinates().size() > 5)
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b907a86ebc48072..7160a612f25d14b 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -130,6 +130,51 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
return
}
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4} ], [$2], $3;", "r,l,r,h,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32,i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5} ], [$2], $3;", "r,l,r,h,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7} ], [$2], $3;", "r,l,r,h,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6} ], [$2], $3;", "r,l,r,h,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$4, $5, $6, $7, $8} ], [$2], $3;", "r,l,r,h,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global.multicast %dest, %tmaDescriptor, %barrier, %multicastMask, box [%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i16, i32, i32, i32, i32, i32, i1
+ return
+}
+
// CHECK-LABEL: @tma_store_1d
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
``````````
</details>
https://github.com/llvm/llvm-project/pull/72429
More information about the Mlir-commits
mailing list