[Mlir-commits] [mlir] [MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (PR #78900)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 21 05:38:49 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.
* Doc updated for the commit_group Op.
* Tests are added to verify the lowering to the intrinsics.
While we are here, also fix the FileCheck directive on the 'nvvm.setmaxregister' test (`CHECK-LLVM` is changed to `CHECK`).
---
Full diff: https://github.com/llvm/llvm-project/pull/78900.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+18-12)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+7-11)
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+22-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..3916896382163ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+ let description = [{
+ This Op commits all prior initiated but uncommitted cp.async.bulk
+ instructions into a cp.async.bulk-group.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+ }];
+
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
- OptionalAttr<UnitAttr>:$read)>
-{
+ OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- auto ptx = std::string("cp.async.bulk.wait_group");
- if(getRead()) ptx += ".read";
- ptx += " %0;"; return ptx; }
+ string llvmBuilder = [{
+ auto intId = op.getRead() ?
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+ createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}
-
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218a..40131af6826487a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
// -----
-func.func @cp_bulk_commit() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+ // CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}
// -----
-func.func @cp_bulk_wait_group() {
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
- // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
- // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+ // CHECK: nvvm.cp.async.bulk.wait_group 1
+ // CHECK: nvvm.cp.async.bulk.wait_group 0
+ // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+ // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..49f9426daabc21b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+ nvvm.cp.async.bulk.commit.group
+ llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+ nvvm.cp.async.bulk.wait_group 0
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+ nvvm.cp.async.bulk.wait_group 3
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+ nvvm.cp.async.bulk.wait_group 0 {read}
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+ nvvm.cp.async.bulk.wait_group 3 {read}
+ llvm.return
+}
+
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
``````````
</details>
https://github.com/llvm/llvm-project/pull/78900
More information about the Mlir-commits mailing list