[Mlir-commits] [mlir] 15bcc36 - [mlir][gpu] Move async copy ops to NVGPU and add caching hints

Thomas Raoux llvmlistbot at llvm.org
Tue May 10 15:30:31 PDT 2022


Author: Thomas Raoux
Date: 2022-05-10T22:30:24Z
New Revision: 15bcc36eede13b24470e554dfa932f1fc40dd4ba

URL: https://github.com/llvm/llvm-project/commit/15bcc36eede13b24470e554dfa932f1fc40dd4ba
DIFF: https://github.com/llvm/llvm-project/commit/15bcc36eede13b24470e554dfa932f1fc40dd4ba.diff

LOG: [mlir][gpu] Move async copy ops to NVGPU and add caching hints

Move the async copy operations to NVGPU, as they only exist on NV targets and
are designed to match PTX semantics. This also allows us to add finer-grained
caching hint attributes to the op.
Add a hint to bypass L1 and hook it up to the NVVM op.

Differential Revision: https://reviews.llvm.org/D125244

Added: 
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/NVGPU/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUBase.td
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
    mlir/include/mlir/Dialect/NVGPU/NVGPU.td
    mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Dialect/NVGPU/roundtrip.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
index 89dbde5362b85..ecc457e4bb941 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -60,13 +60,6 @@ def GPU_AsyncToken : DialectType<
   GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
              BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
 
-/// Device-side synchronization token.
-def GPU_DeviceAsyncToken : DialectType<
-  GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::DeviceAsyncTokenType>()">,
-   "device async token type">,
-   BuildableType<
-     "mlir::gpu::DeviceAsyncTokenType::get($_builder.getContext())">;
-
 // Predicat to check if type is gpu::MMAMatrixType.
 def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index d3b9add25ec46..d8dc67a21e2c9 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -43,14 +43,6 @@ class AsyncTokenType
   using Base::Base;
 };
 
-/// Device-side token storage type. There is only one type of device-side token.
-class DeviceAsyncTokenType
-    : public Type::TypeBase<DeviceAsyncTokenType, Type, TypeStorage> {
-public:
-  // Used for generic hooks in TypeBase.
-  using Base::Base;
-};
-
 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
 /// and type.
 struct MMAMatrixStorageType : public TypeStorage {

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index dc87edb679d84..7ff75b19f72a0 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1280,105 +1280,4 @@ def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
   }];
 }
 
-def GPU_DeviceAsyncCopyOp : GPU_Op<"device_async_copy",
-  [AttrSizedOperandSegments]> {
-  let summary = "device-side asynchronous copy";
-  let description = [{
-    The `gpu.device_async_copy` op initiates an asynchronous copy operation of
-    `$size` elements from source to the destination without blocking the thread.
-    The destination has to be in shared memory.
-
-    This is memory access will be pending to be added to a group.
-
-    This op is meant to be used with `gpu.device_async_create_group` and
-    `gpu.device_async_wait` to synchronize copies as explained in those ops
-    descriptions.
-
-    In order to do a copy and wait for the result we need the following
-    combination:
-    ```
-    // copy 1.
-    %cp1 = gpu.device_async_copy %A[%c0], %B[%c0], 4 :memref<16xf32> to memref<16xf32, 3>
-    // copy 2.
-    %cp2 = gpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
-    // group 1 contains copy 1 and copy 2.
-    %token1 = gpu.device_async_create_group %cp1, %cp2
-    // copy 3.
-    %cp3 = gpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
-    // group 2 contains copy 3.
-    %token2 = gpu.device_async_create_group %cp3
-    // after the wait copy 1 and copy 2 are complete.
-    gpu.device_async_wait %token1
-    // after the wait copy 3 is complete.
-    gpu.device_async_wait %token2
-    ```
-
-    Example:
-
-    ```mlir
-    %0 = gpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
-      memref<4x5xf32> to memref<2x7x5xf32, 3>
-    ```
-  }];
-  let results = (outs GPU_DeviceAsyncToken:$asyncToken);
-  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
-                       Variadic<Index>:$dstIndices,
-                       Arg<AnyMemRef, "", [MemRead]>:$src,
-                       Variadic<Index>:$srcIndices,
-                       IndexAttr:$numElements);
-  let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
-      attr-dict `:` type($src) `to` type($dst)
-  }];
-  let hasVerifier = 1;
-}
-
-def GPU_DeviceAsyncCreateGroupOp : GPU_Op<"device_async_create_group", []> {
-  let summary = "device side asynchronous create group operation";
-  let description = [{
-  The `gpu.device_async_create_group` op creates a group of memory accesses
-  containing all the pending `device_async_copy` operations associated with
-  argument tokens. Each token can only be part of one group.
-
-  It returns a token that can be use to wait until the group fully completes.
-
-  This is meant to be used with `gpu.device_async_wait` to synchronize copies
-  as explained in those ops descriptions.
-
-  Groups are executed in the order they are created.
-
-  Example:
-
-  ```mlir
-  %0 = gpu.device_async_create_group
-  ```
-  }];
-  let results = (outs GPU_DeviceAsyncToken:$asyncToken);
-  let arguments = (ins Variadic<GPU_DeviceAsyncToken>:$inputTokens);
-  let assemblyFormat = [{
-    $inputTokens attr-dict
-  }];
-}
-
-def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
-  let summary = "Wait for async gpu ops to complete.";
-  let description = [{
-  The `gpu.device_async_wait` op will block the execution thread until the group
-  associated with the source token is fully completed.
-
-    The optional `$numGroup` attribute gives a lower bound of the number of
-    groups uncompleted when the wait can unblock the thread.
-  Example:
-
-  ```mlir
-  gpu.device_async_wait %0
-  ```
-  }];
-  let arguments = (ins GPU_DeviceAsyncToken:$asyncDependencies,
-                       OptionalAttr<I32Attr>:$numGroups);
-  let assemblyFormat = [{
-    $asyncDependencies attr-dict
-  }];
-}
-
 #endif // GPU_OPS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 7be5f49c0326a..dc33973430abc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -25,6 +25,15 @@
 
 namespace mlir {
 namespace NVVM {
+
+/// NVVM memory space identifiers.
+enum NVVMMemorySpace {
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 /// Return the element type and number of elements associated with a wmma matrix
 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
 /// WMMA_REGS structure.

diff  --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
index 28147e3fe4aae..f8ed7237b5587 100644
--- a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
@@ -32,8 +32,17 @@ def NVGPU_Dialect : Dialect {
     representing PTX specific operations while using MLIR high level concepts
     like memref and 2-D vector.
   }];
+  let useDefaultAttributePrinterParser = 1;
 }
 
+/// Device-side synchronization token.
+def NVGPU_DeviceAsyncToken : DialectType<
+  NVGPU_Dialect, CPred<"$_self.isa<::mlir::nvgpu::DeviceAsyncTokenType>()">,
+   "device async token type">,
+   BuildableType<
+     "mlir::nvgpu::DeviceAsyncTokenType::get($_builder.getContext())">;
+
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op definitions
 //===----------------------------------------------------------------------===//
@@ -73,24 +82,24 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
   let description = [{
   The `nvgpu.mma.sync` op represents the distributed form of a collective
   matrix-multiply-and-accumulate (mma) operation that is compatible with
-  `nvvm.mma.sync`. The operands and results are fragments of the full matrix 
+  `nvvm.mma.sync`. The operands and results are fragments of the full matrix
   operands. The full shape of the distributed mma operation is given by the
-  `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.  
+  `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
 
   This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
   is an intermediate point between lowering from `vector.contract` to
   `nvvm.mma.sync`.
-  
+
   This operation is meant to follow the semantic of described here:
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
-  
+
   Example:
-  
+
   ```mlir
   nvgpu.mma.sync (%a, %b, %c) :
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   ```
-  }];   
+  }];
   let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
                        AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
 
@@ -102,4 +111,110 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
   }];
 }
 
+
+def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy",
+  [AttrSizedOperandSegments]> {
+  let summary = "device-side asynchronous copy";
+  let description = [{
+    The `nvgpu.device_async_copy` op initiates an asynchronous copy operation of
+    `numElements` elements from source to the destination without blocking the
+    thread. The destination has to be in shared memory.
+
+    This memory access will be pending to be added to a group.
+
+    This op is meant to be used with `nvgpu.device_async_create_group` and
+    `nvgpu.device_async_wait` to synchronize copies as explained in those ops'
+    descriptions.
+    The `bypassL1` attribute is a hint to the backend and hardware that the copy
+    should bypass the L1 cache; this hint may be dropped by the backend or
+    hardware.
+
+    In order to do a copy and wait for the result we need the following
+    combination:
+    ```
+    // copy 1.
+    %cp1 = nvgpu.device_async_copy %A[%c0], %B[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // copy 2.
+    %cp2 = nvgpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // group 1 contains copy 1 and copy 2.
+    %token1 = nvgpu.device_async_create_group %cp1, %cp2
+    // copy 3.
+    %cp3 = nvgpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // group 2 contains copy 3.
+    %token2 = nvgpu.device_async_create_group %cp3
+    // after the wait copy 1 and copy 2 are complete.
+    nvgpu.device_async_wait %token1
+    // after the wait copy 3 is complete.
+    nvgpu.device_async_wait %token2
+    ```
+
+    Example:
+
+    ```mlir
+    %0 = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
+      memref<4x5xf32> to memref<2x7x5xf32, 3>
+    ```
+  }];
+  let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                       Variadic<Index>:$dstIndices,
+                       Arg<AnyMemRef, "", [MemRead]>:$src,
+                       Variadic<Index>:$srcIndices,
+                       IndexAttr:$numElements,
+                       OptionalAttr<UnitAttr>:$bypassL1);
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
+      attr-dict `:` type($src) `to` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
+def NVGPU_DeviceAsyncCreateGroupOp : NVGPU_Op<"device_async_create_group", []> {
+  let summary = "device side asynchronous create group operation";
+  let description = [{
+  The `nvgpu.device_async_create_group` op creates a group of memory accesses
+  containing all the pending `device_async_copy` operations associated with
+  argument tokens. Each token can only be part of one group.
+
+  It returns a token that can be used to wait until the group fully completes.
+
+  This is meant to be used with `nvgpu.device_async_wait` to synchronize copies
+  as explained in that op's description.
+
+  Groups are executed in the order they are created.
+
+  Example:
+
+  ```mlir
+  %0 = nvgpu.device_async_create_group
+  ```
+  }];
+  let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
+  let arguments = (ins Variadic<NVGPU_DeviceAsyncToken>:$inputTokens);
+  let assemblyFormat = [{
+    $inputTokens attr-dict
+  }];
+}
+
+def NVGPU_DeviceAsyncWaitOp : NVGPU_Op<"device_async_wait", []> {
+  let summary = "Wait for async gpu ops to complete.";
+  let description = [{
+  The `nvgpu.device_async_wait` op will block the execution thread until the
+  group associated with the source token is fully completed.
+
+  The optional `numGroups` attribute gives a lower bound of the number of
+  groups uncompleted when the wait can unblock the thread.
+
+  Example:
+  ```mlir
+  nvgpu.device_async_wait %0
+  ```
+  }];
+  let arguments = (ins NVGPU_DeviceAsyncToken:$asyncDependencies,
+                       OptionalAttr<I32Attr>:$numGroups);
+  let assemblyFormat = [{
+    $asyncDependencies attr-dict
+  }];
+}
+
 #endif // NVGPU

diff  --git a/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
index efa14433ccf24..431ba0ac0d31e 100644
--- a/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
@@ -18,6 +18,20 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+namespace mlir {
+namespace nvgpu {
+
+/// Device-side token storage type. There is only one type of device-side token.
+class DeviceAsyncTokenType
+    : public Type::TypeBase<DeviceAsyncTokenType, Type, TypeStorage> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+};
+
+} // namespace nvgpu
+} // namespace mlir
+
 #include "mlir/Dialect/NVGPU/NVGPUDialect.h.inc"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 6ccc8a064396d..56b3733c5e45c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -42,14 +42,6 @@ using namespace mlir;
 
 namespace {
 
-/// NVVM memory space identifiers.
-enum NVVMMemorySpace {
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3
-};
-
 /// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
 static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
   switch (mode) {
@@ -132,83 +124,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
-struct GPUAsyncCopyLowering
-    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp> {
-  using ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
-    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
-                                        adaptor.dstIndices(), rewriter);
-    auto i8Ty = IntegerType::get(op.getContext(), 8);
-    auto dstPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
-    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
-
-    auto srcMemrefType = op.src().getType().cast<MemRefType>();
-
-    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
-                                        adaptor.srcIndices(), rewriter);
-    auto srcPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
-    scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
-    // Intrinsics takes a global pointer so we need an address space cast.
-    auto srcPointerGlobalType =
-        LLVM::LLVMPointerType::get(i8Ty, NVVMMemorySpace::kGlobalMemorySpace);
-    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
-                                                    scrPtr);
-    int64_t numElements = adaptor.numElements().getZExtValue();
-    int64_t sizeInBytes =
-        (dstMemrefType.getElementTypeBitWidth() / 8) * numElements;
-    rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
-                                     rewriter.getI32IntegerAttr(sizeInBytes),
-                                     /*bypassL1=*/UnitAttr());
-
-    // Drop the result token.
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), IntegerType::get(op.getContext(), 32),
-        rewriter.getI32IntegerAttr(0));
-    rewriter.replaceOp(op, zero);
-    return success();
-  }
-};
-
-struct GPUAsyncCreateGroupLowering
-    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCreateGroupOp> {
-  using ConvertOpToLLVMPattern<
-      gpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
-    // Drop the result token.
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), IntegerType::get(op.getContext(), 32),
-        rewriter.getI32IntegerAttr(0));
-    rewriter.replaceOp(op, zero);
-    return success();
-  }
-};
-
-struct GPUAsyncWaitLowering
-    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncWaitOp> {
-  using ConvertOpToLLVMPattern<gpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // If numGroup is not present pick 0 as a conservative correct value.
-    int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0;
-    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
 
@@ -270,11 +185,6 @@ struct LowerGpuOpsToNVVMOpsPass
         return llvm::None;
       return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
     });
-    /// device-side async tokens cannot be materialized in nvvm. We just convert
-    /// them to a dummy i32 type in order to easily drop them during conversion.
-    converter.addConversion([&](gpu::DeviceAsyncTokenType type) -> Type {
-      return converter.convertType(IntegerType::get(type.getContext(), 32));
-    });
     // Lowering for MMAMatrixType.
     converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
       return convertMMAToLLVMType(type);
@@ -375,8 +285,6 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                    "__nv_sqrt");
   patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
                                                    "__nv_tanh");
-  patterns.add<GPUAsyncCopyLowering, GPUAsyncCreateGroupLowering,
-               GPUAsyncWaitLowering>(converter);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
index 50750e478095b..4d9e46907590e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRGPUOps
   MLIRLLVMCommonConversion
   MLIRLLVMIR
   MLIRNVVMIR

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 3ea4330c64082..5152beaa5f61e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -10,6 +10,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
 
@@ -327,6 +328,11 @@ struct ConvertNVGPUToNVVMPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
+    /// device-side async tokens cannot be materialized in nvvm. We just convert
+    /// them to a dummy i32 type in order to easily drop them during conversion.
+    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
+      return converter.convertType(IntegerType::get(type.getContext(), 32));
+    });
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
@@ -337,10 +343,94 @@ struct ConvertNVGPUToNVVMPass
   }
 };
 
+struct NVGPUAsyncCopyLowering
+    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
+    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
+                                        adaptor.dstIndices(), rewriter);
+    auto i8Ty = IntegerType::get(op.getContext(), 8);
+    auto dstPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
+    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+
+    auto srcMemrefType = op.src().getType().cast<MemRefType>();
+
+    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
+                                        adaptor.srcIndices(), rewriter);
+    auto srcPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
+    scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+    // Intrinsics takes a global pointer so we need an address space cast.
+    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
+        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
+                                                    scrPtr);
+    int64_t numElements = adaptor.numElements().getZExtValue();
+    int64_t sizeInBytes =
+        (dstMemrefType.getElementTypeBitWidth() / 8) * numElements;
+    // bypass L1 is only supported for byte sizes of 16, we drop the hint
+    // otherwise.
+    UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr();
+    rewriter.create<NVVM::CpAsyncOp>(
+        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
+
+    // Drop the result token.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        op->getLoc(), IntegerType::get(op.getContext(), 32),
+        rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, zero);
+    return success();
+  }
+};
+
+struct NVGPUAsyncCreateGroupLowering
+    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
+    // Drop the result token.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        op->getLoc(), IntegerType::get(op.getContext(), 32),
+        rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, zero);
+    return success();
+  }
+};
+
+struct NVGPUAsyncWaitLowering
+    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // If numGroup is not present pick 0 as a conservative correct value.
+    int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0;
+    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
+
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
-  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
+  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
+               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
+      converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9d297f1bc35f3..efcb1300178b7 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -117,7 +117,6 @@ struct GPUInlinerInterface : public DialectInlinerInterface {
 
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
-  addTypes<DeviceAsyncTokenType>();
   addTypes<MMAMatrixType>();
   addOperations<
 #define GET_OP_LIST
@@ -140,9 +139,6 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
   // Handle 'async token' types.
   if (keyword == "async.token")
     return AsyncTokenType::get(context);
-  // Handle 'device async token' types.
-  if (keyword == "device.async.token")
-    return DeviceAsyncTokenType::get(context);
 
   if (keyword == "mma_matrix") {
     SMLoc beginLoc = parser.getNameLoc();
@@ -183,7 +179,6 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
-      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();
@@ -1366,32 +1361,6 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyDimOfAllocOp>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// GPU_DeviceAsyncCopyOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult DeviceAsyncCopyOp::verify() {
-  auto srcMemref = src().getType().cast<MemRefType>();
-  auto dstMemref = dst().getType().cast<MemRefType>();
-  unsigned workgroupAddressSpace = GPUDialect::getWorkgroupAddressSpace();
-  if (!isLastMemrefDimUnitStride(srcMemref))
-    return emitError("source memref most minor dim must have unit stride");
-  if (!isLastMemrefDimUnitStride(dstMemref))
-    return emitError("destination memref most minor dim must have unit stride");
-  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
-    return emitError("destination memref must have memory space ")
-           << workgroupAddressSpace;
-  if (dstMemref.getElementType() != srcMemref.getElementType())
-    return emitError("source and destination must have the same element type");
-  if (size_t(srcMemref.getRank()) != srcIndices().size())
-    return emitOpError() << "expected " << srcMemref.getRank()
-                         << " source indices, got " << srcIndices().size();
-  if (size_t(dstMemref.getRank()) != dstIndices().size())
-    return emitOpError() << "expected " << dstMemref.getRank()
-                         << " destination indices, got " << dstIndices().size();
-  return success();
-}
-
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 #include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
 

diff  --git a/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
index d65c2dfd1fa49..64d232d6170f3 100644
--- a/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRNVGPU
   MLIRNVGPUIncGen
 
   LINK_LIBS PUBLIC
+  MLIRGPUOps
   MLIRIR
   MLIRSideEffectInterfaces
   )

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 6c4318f4d4967..ad5b6bd1db652 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -11,20 +11,81 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
+using namespace mlir::nvgpu;
 
 #include "mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc"
 
 void nvgpu::NVGPUDialect::initialize() {
+  addTypes<DeviceAsyncTokenType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
       >();
 }
 
+Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
+  // Parse the main keyword for the type.
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return Type();
+  MLIRContext *context = getContext();
+  // Handle 'device async token' types.
+  if (keyword == "device.async.token")
+    return DeviceAsyncTokenType::get(context);
+
+  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
+  return Type();
+}
+
+void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
+  TypeSwitch<Type>(type)
+      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
+      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
+}
+//===----------------------------------------------------------------------===//
+// NVGPU_DeviceAsyncCopyOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(type, strides, offset)))
+    return false;
+  // An empty stride list must not reach back(): that is undefined behaviour.
+  return strides.empty() || strides.back() == 1;
+}
+
+LogicalResult DeviceAsyncCopyOp::verify() {
+  auto srcMemref = src().getType().cast<MemRefType>();
+  auto dstMemref = dst().getType().cast<MemRefType>();
+  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
+  if (!isLastMemrefDimUnitStride(srcMemref))
+    return emitError("source memref most minor dim must have unit stride");
+  if (!isLastMemrefDimUnitStride(dstMemref))
+    return emitError("destination memref most minor dim must have unit stride");
+  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
+    return emitError("destination memref must have memory space ")
+           << workgroupAddressSpace;
+  if (dstMemref.getElementType() != srcMemref.getElementType())
+    return emitError("source and destination must have the same element type");
+  if (size_t(srcMemref.getRank()) != srcIndices().size())
+    return emitOpError() << "expected " << srcMemref.getRank()
+                         << " source indices, got " << srcIndices().size();
+  if (size_t(dstMemref.getRank()) != dstIndices().size())
+    return emitOpError() << "expected " << dstMemref.getRank()
+                         << " destination indices, got " << dstIndices().size();
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 771e2f10d3bdb..15f454bb348ce 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -488,35 +488,3 @@ gpu.module @test_module {
   }
 }
 
-// -----
-
-gpu.module @test_module {
-  // CHECK-LABEL: @async_cp(
-  // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: i64)
-  gpu.func @async_cp(
-    %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index) kernel {
-    // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
-    // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
-    // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX]], %[[S0]] : i64
-    // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
-    // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX]], %[[S1]] : i64
-    // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
-    // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX]] : i64
-    // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
-    // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
-    // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
-    // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX]], %[[S3]]  : i64
-    // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX]]  : i64
-    // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-    // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
-    // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
-    %0 = gpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
-    // CHECK: nvvm.cp.async.commit.group
-    %1 = gpu.device_async_create_group %0
-    // CHECK: nvvm.cp.async.wait.group 1
-    gpu.device_async_wait %1 { numGroups = 1 : i32 }
-    gpu.return
-  }
-}

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
similarity index 81%
rename from mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
rename to mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 345270ab5635d..249cb48944c88 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -182,4 +182,40 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
   // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
   return %d : vector<4x1xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @async_cp(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp(
+  %global: memref<128x128xf32>, %shared: memref<3x16x128xf32, 3>, %idx : index) {
+  // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
+  // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
+  // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX1]], %[[S1]] : i64
+  // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
+  // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64
+  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
+  // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
+  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX1]], %[[S3]]  : i64
+  // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX1]]  : i64
+  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+  %cp = nvgpu.device_async_copy %global[%idx, %idx], %shared[%idx, %idx, %idx], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %group = nvgpu.device_async_create_group %cp
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %group { numGroups = 1 : i32 }
+
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+  %cpBypass = nvgpu.device_async_copy %global[%idx, %idx], %shared[%idx, %idx, %idx], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  return
+}
+

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index ff9def1e5f191..1e10349f63864 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -559,62 +559,6 @@ func.func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp
 
 // -----
 
-func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
-  // expected-error @+1 {{destination memref must have memory space 3}}
-  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32>
-  return
-}
-
-// -----
-
-func.func @async_cp_memref_type(%dst : memref<16xi32, 3>, %src : memref<16xf32>, %i : index) -> () {
-  // expected-error @+1 {{source and destination must have the same element type}}
-  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xi32, 3>
-  return
-}
-
-// -----
-
-func.func @async_cp_num_src_indices(%dst : memref<16xf32, 3>, %src : memref<16x16xf32>, %i : index) -> () {
-  // expected-error @+1 {{expected 2 source indices, got 1}}
-  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16x16xf32> to memref<16xf32, 3>
-  return
-}
-
-// -----
-
-func.func @async_cp_num_dst_indices(%dst : memref<16x16xf32, 3>, %src : memref<16xf32>, %i : index) -> () {
-  // expected-error @+1 {{expected 2 destination indices, got 1}}
-  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16x16xf32, 3>
-  return
-}
-
-// -----
-
-func.func @async_cp_num_src_stride(
-  %dst : memref<200x100xf32, 3>,
-  %src : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
-  %i : index) -> () {
-  // expected-error @+1 {{source memref most minor dim must have unit stride}}
-  gpu.device_async_copy %src[%i, %i], %dst[%i, %i], 16 :
-    memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>> to memref<200x100xf32, 3>
-  return
-}
-
-// -----
-
-func.func @async_cp_num_dst_stride(
-  %dst : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>,
-  %src : memref<200x100xf32>,
-  %i : index) -> () {
-  // expected-error @+1 {{destination memref most minor dim must have unit stride}}
-  gpu.device_async_copy %src[%i, %i], %dst[%i, %i], 16 :
-    memref<200x100xf32> to memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>
-  return
-}
-
-// -----
-
 // Number of symbol operand count less than memref symbol count.
 func.func @alloc() {
    // expected-error@+1 {{symbol operand count does not equal memref symbol count}}

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 518ca28236b04..f882a8e5d630a 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -277,18 +277,6 @@ module attributes {gpu.container_module} {
     return
   }
 
-  func.func @async_cp(%dst : memref<2x7x5xf32, 3>, %src : memref<4x5xf32>){
-    // CHECK-LABEL: func @async_cp
-    %c0 = arith.constant 0 : index
-    // CHECK: gpu.device_async_copy %{{.*}}[{{.*}}, {{.*}}], %{{.*}}[{{.*}}, {{.*}}, {{.*}}], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
-    %0 = gpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
-    // CHECK: %{{.*}} = gpu.device_async_create_group
-    %token = gpu.device_async_create_group %0
-    // CHECK: gpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
-    gpu.device_async_wait %token {numGroups = 1 : i32}
-    return
-  }
-
   // CHECK-LABEL: func @set_default_device
   func.func @set_default_device(%arg0: i32) {
     // CHECK: gpu.set_default_device

diff  --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
new file mode 100644
index 0000000000000..7a9acb4727d45
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{destination memref must have memory space 3}}
+  nvgpu.device_async_copy %src[%idx], %dst[%idx], 16 : memref<16xf32> to memref<16xf32>
+  return
+}
+
+// -----
+
+func.func @async_cp_memref_type(%dst : memref<16xi32, 3>, %src : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{source and destination must have the same element type}}
+  nvgpu.device_async_copy %src[%idx], %dst[%idx], 16 : memref<16xf32> to memref<16xi32, 3>
+  return
+}
+
+// -----
+
+func.func @async_cp_num_src_indices(%dst : memref<16xf32, 3>, %src : memref<16x16xf32>, %idx : index) {
+  // expected-error @+1 {{expected 2 source indices, got 1}}
+  nvgpu.device_async_copy %src[%idx], %dst[%idx], 16 : memref<16x16xf32> to memref<16xf32, 3>
+  return
+}
+
+// -----
+
+func.func @async_cp_num_dst_indices(%dst : memref<16x16xf32, 3>, %src : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{expected 2 destination indices, got 1}}
+  nvgpu.device_async_copy %src[%idx], %dst[%idx], 16 : memref<16xf32> to memref<16x16xf32, 3>
+  return
+}
+
+// -----
+
+func.func @async_cp_num_src_stride(
+  %dst : memref<200x100xf32, 3>,
+  %src : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
+  %idx : index) {
+  // expected-error @+1 {{source memref most minor dim must have unit stride}}
+  nvgpu.device_async_copy %src[%idx, %idx], %dst[%idx, %idx], 16 :
+    memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>> to memref<200x100xf32, 3>
+  return
+}
+
+// -----
+
+func.func @async_cp_num_dst_stride(
+  %dst : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>,
+  %src : memref<200x100xf32>,
+  %idx : index) {
+  // expected-error @+1 {{destination memref most minor dim must have unit stride}}
+  nvgpu.device_async_copy %src[%idx, %idx], %dst[%idx, %idx], 16 :
+    memref<200x100xf32> to memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>
+  return
+}

diff  --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index 255694003eab0..524f1fd6907b7 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -13,8 +13,21 @@ func.func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
 func.func @mma_sync(%arg0: vector<4x2xf16>,
                %arg1: vector<2x2xf16>,
                %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-//       CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+//       CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
-    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
+
+
+func.func @async_cp(%shared : memref<2x7x5xf32, 3>, %global : memref<4x5xf32>){
+  // CHECK-LABEL: func @async_cp
+  %zero = arith.constant 0 : index
+  // CHECK: nvgpu.device_async_copy %{{.*}}[{{.*}}, {{.*}}], %{{.*}}[{{.*}}, {{.*}}, {{.*}}], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+  %cp = nvgpu.device_async_copy %global[%zero, %zero], %shared[%zero, %zero, %zero], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+  // CHECK: %{{.*}} = nvgpu.device_async_create_group
+  %group = nvgpu.device_async_create_group %cp
+  // CHECK: nvgpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
+  nvgpu.device_async_wait %group {numGroups = 1 : i32}
+  return
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e187d3b275338..2138b73100a77 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2128,6 +2128,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/NVGPU/NVGPUDialect.h"],
     includes = ["include"],
     deps = [
+        ":GPUDialect",
         ":IR",
         ":NVGPUIncGen",
         ":SideEffectInterfaces",
@@ -3715,6 +3716,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":ConversionPassIncGen",
+        ":GPUDialect",
         ":IR",
         ":LLVMCommonConversion",
         ":NVGPU",


        


More information about the Mlir-commits mailing list