[Mlir-commits] [mlir] 579c1ff - [mlir][nvvm] Add async copy ops to nvvm dialect

Thomas Raoux llvmlistbot at llvm.org
Wed Dec 8 09:42:42 PST 2021


Author: Thomas Raoux
Date: 2021-12-08T09:42:20-08:00
New Revision: 579c1ff67dbd346cbcd04a6e000133d079a9387a

URL: https://github.com/llvm/llvm-project/commit/579c1ff67dbd346cbcd04a6e000133d079a9387a
DIFF: https://github.com/llvm/llvm-project/commit/579c1ff67dbd346cbcd04a6e000133d079a9387a.diff

LOG: [mlir][nvvm] Add async copy ops to nvvm dialect

Differential Revision: https://reviews.llvm.org/D115314

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index d285e7ae3ce8b..90bdf61408958 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -94,12 +94,15 @@ class LLVM_PointerTo<Type pointee> : Type<
   "LLVM pointer to " # pointee.summary>;
 
-// Type constraints accepting LLVM pointer type to integer of a specific width.
-class LLVM_IntPtrBase<int width> : Type<
-  LLVM_PointerTo<I<width>>.predicate,
+// Type constraints accepting LLVM pointer type to integer of a specific width
+// in a specific address space (address space 0 unless specified).
+class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
+  And<[LLVM_PointerTo<I<width>>.predicate,
+       CPred<"$_self.cast<::mlir::LLVM::LLVMPointerType>().getAddressSpace()"
+             " == " # addressSpace>]>,
   "LLVM pointer to " # I<width>.summary>,
   BuildableType<"::mlir::LLVM::LLVMPointerType::get("
                  "::mlir::IntegerType::get($_builder.getContext(), "
-                # width #"))">;
+                # width #"), "# addressSpace #")">;
 
 def LLVM_i8Ptr : LLVM_IntPtrBase<8>;


diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index b1182eae3237f..1a364bbc364fe 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+// i8 pointer type constraints for the NVPTX global (address space 1) and
+// shared (address space 3) memory spaces, used by the async copy ops.
+def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
+def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
+
 //===----------------------------------------------------------------------===//
 // NVVM dialect definitions
 //===----------------------------------------------------------------------===//
@@ -157,6 +162,73 @@ def NVVM_VoteBallotOp :
   let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
 }
 
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
+  Arguments<(ins LLVM_i8Ptr_shared:$dst,
+                 LLVM_i8Ptr_global:$src,
+                 I32Attr:$size)> {
+  let summary = "Async copy of `size` bytes from global to shared memory";
+  let description = [{
+    Starts an asynchronous copy of `size` bytes from `src`, a pointer in
+    address space 1 (global), to `dst`, a pointer in address space 3
+    (shared). `size` must be 4, 8 or 16; it selects which
+    `llvm.nvvm.cp.async.ca.shared.global.*` intrinsic the op translates to.
+  }];
+  string llvmBuilder = [{
+      llvm::Intrinsic::ID id;
+      switch ($size) {
+        case 4:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
+          break;
+        case 8:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
+          break;
+        case 16:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+          break;
+        default:
+          // Sizes other than 4, 8 and 16 are rejected by the verifier.
+          llvm_unreachable("unsupported async copy size");
+      }
+      createIntrinsicCall(builder, id, {$dst, $src});
+  }];
+  let verifier = [{
+    if (size() != 4 && size() != 8 && size() != 16)
+      return emitError("expected byte size to be either 4, 8 or 16.");
+    return success();
+  }];
+  let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
+}
+
+def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
+  let summary = "Commit pending async copies into an async-group";
+  let description = [{
+    Translates to the `llvm.nvvm.cp.async.commit.group` intrinsic, i.e. PTX
+    `cp.async.commit_group`.
+  }];
+  string llvmBuilder = [{
+      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
+def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
+  Arguments<(ins I32Attr:$n)> {
+  let summary = "Wait until at most `n` recent async-groups are pending";
+  let description = [{
+    Translates to the `llvm.nvvm.cp.async.wait.group` intrinsic, i.e. PTX
+    `cp.async.wait_group`, passing `n` as its i32 operand.
+  }];
+  string llvmBuilder = [{
+      createIntrinsicCall(
+        builder,
+        llvm::Intrinsic::nvvm_cp_async_wait_group,
+        llvm::ConstantInt::get(
+          llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()),
+          $n));
+  }];
+  let assemblyFormat = "$n attr-dict";
+}
+
 def NVVM_MmaOp :
   NVVM_Op<"mma.sync">,
   Results<(outs LLVM_Type:$res)>,

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9b9df36210a79..41919ceaaed03 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1226,3 +1226,11 @@ func @bitcast(%arg0: vector<2x3xf32>) {
   llvm.bitcast %arg0 : vector<2x3xf32> to vector<2x3xi32>
   return
 }
+
+// -----
+
+func @cp_async(%dst: !llvm.ptr<i8, 3>, %src: !llvm.ptr<i8, 1>) {
+  // expected-error @below {{expected byte size to be either 4, 8 or 16.}}
+  nvvm.cp.async.shared.global %dst, %src, 32
+  return
+}

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e2ca4d71bbe25..75d099cf0ac46 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -95,6 +95,15 @@ func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
   llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+llvm.func @cp_async(%dst: !llvm.ptr<i8, 3>, %src: !llvm.ptr<i8, 1>) {
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
+  nvvm.cp.async.shared.global %dst, %src, 16
+  // CHECK: nvvm.cp.async.commit.group
+  nvvm.cp.async.commit.group
+  // CHECK: nvvm.cp.async.wait.group 0
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
 
 // -----
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a9ec2259dfcff..ab52d2a95f685 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -162,6 +162,20 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
   llvm.return
 }
 
+llvm.func @cp_async(%dst: !llvm.ptr<i8, 3>, %src: !llvm.ptr<i8, 1>) {
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 4
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 8
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 16
+  // CHECK: call void @llvm.nvvm.cp.async.commit.group()
+  nvvm.cp.async.commit.group
+  // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {


        


More information about the Mlir-commits mailing list