[Mlir-commits] [mlir] 579c1ff - [mlir][nvvm] Add async copy ops to nvvm dialect
Thomas Raoux
llvmlistbot at llvm.org
Wed Dec 8 09:42:42 PST 2021
Author: Thomas Raoux
Date: 2021-12-08T09:42:20-08:00
New Revision: 579c1ff67dbd346cbcd04a6e000133d079a9387a
URL: https://github.com/llvm/llvm-project/commit/579c1ff67dbd346cbcd04a6e000133d079a9387a
DIFF: https://github.com/llvm/llvm-project/commit/579c1ff67dbd346cbcd04a6e000133d079a9387a.diff
LOG: [mlir][nvvm] Add async copy ops to nvvm dialect
Differential Revision: https://reviews.llvm.org/D115314
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index d285e7ae3ce8b..90bdf61408958 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -94,12 +94,12 @@ class LLVM_PointerTo<Type pointee> : Type<
"LLVM pointer to " # pointee.summary>;
// Type constraints accepting LLVM pointer type to integer of a specific width.
-class LLVM_IntPtrBase<int width> : Type<
-  LLVM_PointerTo<I<width>>.predicate,
-  "LLVM pointer to " # I<width>.summary>,
-  BuildableType<"::mlir::LLVM::LLVMPointerType::get("
-                "::mlir::IntegerType::get($_builder.getContext(), "
-                # width #"))">;
+// `addressSpace` (default 0) is both checked by the predicate and used by the
+// builder. The check matters: without it, e.g. LLVM_IntPtrBase<8, 3> would
+// accept an i8 pointer in any address space while building one in space 3.
+// The summary only mentions the address space when it is non-zero so that
+// diagnostics for the default (address space 0) constraint are unchanged.
+class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
+  And<[LLVM_PointerTo<I<width>>.predicate,
+       CPred<"$_self.cast<::mlir::LLVM::LLVMPointerType>().getAddressSpace()"
+             " == " # addressSpace>]>,
+  "LLVM pointer to " # I<width>.summary
+    # !if(!eq(addressSpace, 0), "", " in address space " # addressSpace)>,
+  BuildableType<"::mlir::LLVM::LLVMPointerType::get("
+                "::mlir::IntegerType::get($_builder.getContext(), "
+                # width #"), "# addressSpace #")">;
def LLVM_i8Ptr : LLVM_IntPtrBase<8>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index b1182eae3237f..1a364bbc364fe 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,9 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+// i8 pointer type constraints for the NVPTX global (address space 1) and
+// shared (address space 3) memory spaces; used as the source and destination
+// operand types of the cp.async copy op.
+def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
+def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
+
//===----------------------------------------------------------------------===//
// NVVM dialect definitions
//===----------------------------------------------------------------------===//
@@ -157,6 +160,56 @@ def NVVM_VoteBallotOp :
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
}
+
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
+  Arguments<(ins LLVM_i8Ptr_shared:$dst,
+                 LLVM_i8Ptr_global:$src,
+                 I32Attr:$size)> {
+  let summary = "asynchronous copy from global to shared memory";
+  let description = [{
+    Copies `size` bytes from `src` (expected to be an i8 pointer in address
+    space 1, global) to `dst` (expected to be an i8 pointer in address space 3,
+    shared). `size` must be 4, 8 or 16; the op is translated to the matching
+    `llvm.nvvm.cp.async.ca.shared.global.{4,8,16}` intrinsic.
+  }];
+  string llvmBuilder = [{
+    // Select the intrinsic variant for the copy size.
+    llvm::Intrinsic::ID id;
+    switch ($size) {
+    case 4:
+      id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
+      break;
+    case 8:
+      id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
+      break;
+    case 16:
+      id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+      break;
+    default:
+      // The verifier below rejects every other size.
+      llvm_unreachable("unsupported async copy size");
+    }
+    createIntrinsicCall(builder, id, {$dst, $src});
+  }];
+  let verifier = [{
+    if (size() != 4 && size() != 8 && size() != 16)
+      return emitError("expected byte size to be either 4, 8 or 16.");
+    return success();
+  }];
+  let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
+}
+
+def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
+  let summary = "NVVM cp.async.commit.group";
+  let description = [{
+    Translates to the operand-less `llvm.nvvm.cp.async.commit.group`
+    intrinsic. See the PTX ISA documentation of `cp.async.commit_group` for
+    its semantics.
+  }];
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
+def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
+  Arguments<(ins I32Attr:$n)> {
+  let summary = "NVVM cp.async.wait.group";
+  let description = [{
+    Translates to the `llvm.nvvm.cp.async.wait.group` intrinsic, passing the
+    `n` attribute as an i32 constant operand. See the PTX ISA documentation
+    of `cp.async.wait_group` for its semantics.
+  }];
+  string llvmBuilder = [{
+    // The intrinsic takes `n` as an i32 immediate operand.
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_wait_group,
+                        builder.getInt32($n));
+  }];
+  let assemblyFormat = "$n attr-dict";
+}
+
def NVVM_MmaOp :
NVVM_Op<"mma.sync">,
Results<(outs LLVM_Type:$res)>,
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9b9df36210a79..41919ceaaed03 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1226,3 +1226,11 @@ func @bitcast(%arg0: vector<2x3xf32>) {
llvm.bitcast %arg0 : vector<2x3xf32> to vector<2x3xi32>
return
}
+
+// -----
+
+func @cp_async(%dst: !llvm.ptr<i8, 3>, %src: !llvm.ptr<i8, 1>) {
+  // 32 bytes is not one of the supported copy sizes (4, 8 or 16).
+  // expected-error @below {{expected byte size to be either 4, 8 or 16.}}
+  nvvm.cp.async.shared.global %dst, %src, 32
+  return
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e2ca4d71bbe25..75d099cf0ac46 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -95,6 +95,15 @@ func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
+// Round-trip of the cp.async ops through the parser and printer.
+// CHECK-LABEL: @cp_async
+llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
+  nvvm.cp.async.shared.global %arg0, %arg1, 16
+  // CHECK: nvvm.cp.async.commit.group
+  nvvm.cp.async.commit.group
+  // CHECK: nvvm.cp.async.wait.group 0
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
+
// -----
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a9ec2259dfcff..ab52d2a95f685 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -162,6 +162,20 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
llvm.return
}
+// Each cp.async size maps to its own intrinsic; commit/wait map 1:1.
+// CHECK-LABEL: @cp_async
+llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 4
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 8
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 16
+  // CHECK: call void @llvm.nvvm.cp.async.commit.group()
+  nvvm.cp.async.commit.group
+  // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {
More information about the Mlir-commits
mailing list