[Mlir-commits] [mlir] 5ab04bc - [mlir][gpu] Add device side async copy operations

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 10 17:26:12 PST 2022


Author: Thomas Raoux
Date: 2022-02-10T17:25:59-08:00
New Revision: 5ab04bc068d546a12e75b83970a8f737d6ebd813

URL: https://github.com/llvm/llvm-project/commit/5ab04bc068d546a12e75b83970a8f737d6ebd813
DIFF: https://github.com/llvm/llvm-project/commit/5ab04bc068d546a12e75b83970a8f737d6ebd813.diff

LOG: [mlir][gpu] Add device side async copy operations

Add new operations to the gpu dialect to represent device-side
asynchronous copies. This also adds the lowering of those operations to
the nvvm dialect.
Those ops are meant to be low level and map directly to LLVM dialects
like nvvm or rocdl.

Higher levels of abstraction can be added by building on top of
those operations.
This has been discussed here:
https://discourse.llvm.org/t/modeling-gpu-async-copy-ampere-feature/4924

Differential Revision: https://reviews.llvm.org/D119191

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUBase.td
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
index ecc457e4bb941..89dbde5362b85 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -60,6 +60,13 @@ def GPU_AsyncToken : DialectType<
   GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
              BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
 
+/// Token produced by device-side asynchronous operations.
+def GPU_DeviceAsyncToken : DialectType<GPU_Dialect,
+    CPred<"$_self.isa<::mlir::gpu::DeviceAsyncTokenType>()">,
+    "device async token type">,
+  BuildableType<
+    "mlir::gpu::DeviceAsyncTokenType::get($_builder.getContext())">;
+
 // Predicat to check if type is gpu::MMAMatrixType.
 def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 0ec9a27b38bfd..7712c31894534 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -46,6 +46,14 @@ class AsyncTokenType
   using Base::Base;
 };
 
+/// Device-side async token type. There is only one type of device-side token.
+class DeviceAsyncTokenType
+    : public Type::TypeBase<DeviceAsyncTokenType, Type, TypeStorage> {
+public:
+  // Inherit the TypeBase constructors used by its generic hooks.
+  using Base::Base;
+};
+
 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
 /// and type.
 struct MMAMatrixStorageType : public TypeStorage {

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 877e20cd8931a..f6a5d4ca7eb8c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1226,4 +1226,105 @@ def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
   }];
 }
 
+def GPU_DeviceAsyncCopyOp : GPU_Op<"device_async_copy",
+  [AttrSizedOperandSegments]> {
+  let summary = "device-side asynchronous copy";
+  let description = [{
+    The `gpu.device_async_copy` op initiates an asynchronous copy operation of
+    `$numElements` elements from source to destination without blocking the
+    thread. The destination has to be in shared memory.
+
+    This memory access stays pending until it is added to a group.
+
+    This op is meant to be used with `gpu.device_async_create_group` and
+    `gpu.device_async_wait` to synchronize copies as explained in those ops'
+    descriptions.
+
+    To do a copy and wait for its result, the following combination is
+    needed:
+    ```mlir
+    // copy 1.
+    %cp1 = gpu.device_async_copy %A[%c0], %B[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // copy 2.
+    %cp2 = gpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // group 1 contains copy 1 and copy 2.
+    %token1 = gpu.device_async_create_group %cp1, %cp2
+    // copy 3.
+    %cp3 = gpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
+    // group 2 contains copy 3.
+    %token2 = gpu.device_async_create_group %cp3
+    // After this wait, copy 1 and copy 2 are complete.
+    gpu.device_async_wait %token1
+    // After this wait, copy 3 is complete.
+    gpu.device_async_wait %token2
+    ```
+
+    Example:
+
+    ```mlir
+    %0 = gpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
+      memref<4x5xf32> to memref<2x7x5xf32, 3>
+    ```
+  }];
+  let results = (outs GPU_DeviceAsyncToken:$asyncToken);
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                       Variadic<Index>:$dstIndices,
+                       Arg<AnyMemRef, "", [MemRead]>:$src,
+                       Variadic<Index>:$srcIndices,
+                       IndexAttr:$numElements);
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
+      attr-dict `:` type($src) `to` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
+def GPU_DeviceAsyncCreateGroupOp : GPU_Op<"device_async_create_group", []> {
+  let summary = "device side asynchronous create group operation";
+  let description = [{
+    The `gpu.device_async_create_group` op creates a group of memory accesses
+    containing all the pending `device_async_copy` operations associated with
+    the argument tokens. Each token can only be part of one group.
+
+    It returns a token that can be used to wait until the group completes.
+
+    This is meant to be used with `gpu.device_async_wait` to synchronize copies
+    as explained in that op's description.
+
+    Groups are executed in the order they are created.
+
+    Example:
+
+    ```mlir
+    %0 = gpu.device_async_create_group %cp1, %cp2
+    ```
+  }];
+  let results = (outs GPU_DeviceAsyncToken:$asyncToken);
+  let arguments = (ins Variadic<GPU_DeviceAsyncToken>:$inputTokens);
+  let assemblyFormat = [{
+    $inputTokens attr-dict
+  }];
+}
+
+def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
+  let summary = "Wait for async gpu ops to complete.";
+  let description = [{
+    The `gpu.device_async_wait` op will block the execution thread until the
+    group associated with the source token is fully completed.
+
+    The optional `$numGroups` attribute gives an upper bound on the number of
+    groups that may still be uncompleted when the wait unblocks the thread.
+    Example:
+
+    ```mlir
+    gpu.device_async_wait %0 {numGroups = 1 : i32}
+    ```
+  }];
+  let arguments = (ins GPU_DeviceAsyncToken:$asyncDependencies,
+                       OptionalAttr<I32Attr>:$numGroups);
+  let assemblyFormat = [{
+    $asyncDependencies attr-dict
+  }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 67cc76e8aab53..e10f9f770fa28 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -40,6 +40,14 @@ using namespace mlir;
 
 namespace {
 
+/// Address-space numbers NVVM uses for pointers into each memory space.
+enum NVVMMemorySpace {
+  /// Global memory lives in address space 1.
+  kGlobalMemorySpace = 1,
+  /// Shared memory lives in address space 3.
+  kSharedMemorySpace = 3,
+};
+
 /// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
 static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
   switch (mode) {
@@ -122,6 +130,82 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+/// Lowers `gpu.device_async_copy` to `nvvm.cp.async.shared.global`.
+struct GPUAsyncCopyLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp> {
+  using ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
+    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
+                                        adaptor.dstIndices(), rewriter);
+    auto i8Ty = IntegerType::get(op.getContext(), 8);
+    auto dstPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
+    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+
+    // The intrinsic takes a global source pointer: bitcast, then addrspacecast.
+    auto srcMemrefType = op.src().getType().cast<MemRefType>();
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
+                                        adaptor.srcIndices(), rewriter);
+    auto srcPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
+    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
+    auto srcPointerGlobalType =
+        LLVM::LLVMPointerType::get(i8Ty, NVVMMemorySpace::kGlobalMemorySpace);
+    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
+                                                    srcPtr);
+    // Multiply before dividing so sub-byte element types are not rounded to 0.
+    int64_t numElements = adaptor.numElements().getZExtValue();
+    int64_t sizeInBytes =
+        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
+    rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, srcPtr,
+                                     rewriter.getI32IntegerAttr(sizeInBytes));
+
+    // The result token has no NVVM equivalent; replace it with a dummy i32.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, zero);
+    return success();
+  }
+};
+
+struct GPUAsyncCreateGroupLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCreateGroupOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    rewriter.create<NVVM::CpAsyncCommitGroupOp>(loc);
+    // The group token has no NVVM counterpart: substitute a dummy i32 zero.
+    Value dummyToken = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, dummyToken);
+    return success();
+  }
+};
+
+struct GPUAsyncWaitLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncWaitOp> {
+  using ConvertOpToLLVMPattern<gpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // An absent `numGroups` falls back to 0, which is conservatively correct.
+    int32_t pendingGroups = !adaptor.numGroups() ? 0 : *adaptor.numGroups();
+    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), pendingGroups);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -159,7 +243,11 @@ struct LowerGpuOpsToNVVMOpsPass
         return llvm::None;
       return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
     });
-
+    /// device-side async tokens cannot be materialized in nvvm. We just convert
+    /// them to a dummy i32 type in order to easily drop them during conversion.
+    converter.addConversion([&](gpu::DeviceAsyncTokenType type) -> Type {
+      return converter.convertType(IntegerType::get(type.getContext(), 32));
+    });
     // Lowering for MMAMatrixType.
     converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
       return convertMMAToLLVMType(type);
@@ -259,6 +347,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                    "__nv_sqrt");
   patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
                                                    "__nv_tanh");
+  patterns.add<GPUAsyncCopyLowering, GPUAsyncCreateGroupLowering,
+               GPUAsyncWaitLowering>(converter);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a75724421d7a7..3825aec5d30e2 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -117,6 +117,7 @@ struct GPUInlinerInterface : public DialectInlinerInterface {
 
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
+  addTypes<DeviceAsyncTokenType>();
   addTypes<MMAMatrixType>();
   addOperations<
 #define GET_OP_LIST
@@ -139,6 +140,9 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
   // Handle 'async token' types.
   if (keyword == "async.token")
     return AsyncTokenType::get(context);
+  // Handle 'device async token' types.
+  if (keyword == "device.async.token")
+    return DeviceAsyncTokenType::get(context);
 
   if (keyword == "mma_matrix") {
     SMLoc beginLoc = parser.getNameLoc();
@@ -179,6 +183,7 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
+      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();
@@ -1203,6 +1208,45 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyDimOfAllocOp>(context);
 }

+//===----------------------------------------------------------------------===//
+// GPU_DeviceAsyncCopyOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
+/// Requires unit innermost strides, a workgroup-memory destination, equal
+/// element types, and exactly one index per memref dimension.
+LogicalResult DeviceAsyncCopyOp::verify() {
+  auto srcMemref = src().getType().cast<MemRefType>();
+  auto dstMemref = dst().getType().cast<MemRefType>();
+  unsigned workgroupAddressSpace = GPUDialect::getWorkgroupAddressSpace();
+  if (!isLastMemrefDimUnitStride(srcMemref))
+    return emitOpError("source memref most minor dim must have unit stride");
+  if (!isLastMemrefDimUnitStride(dstMemref))
+    return emitOpError(
+        "destination memref most minor dim must have unit stride");
+  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
+    return emitOpError("destination memref must have memory space ")
+           << workgroupAddressSpace;
+  if (dstMemref.getElementType() != srcMemref.getElementType())
+    return emitOpError(
+        "source and destination must have the same element type");
+  if (size_t(srcMemref.getRank()) != srcIndices().size())
+    return emitOpError() << "expected " << srcMemref.getRank()
+                         << " source indices, got " << srcIndices().size();
+  if (size_t(dstMemref.getRank()) != dstIndices().size())
+    return emitOpError() << "expected " << dstMemref.getRank()
+                         << " destination indices, got " << dstIndices().size();
+  return success();
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 #include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"


diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 8219d2204772f..3c75bb944f08f 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -479,3 +479,36 @@ gpu.module @test_module {
     gpu.return
   }
 }
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: @async_cp(
+  // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: i64)
+  gpu.func @async_cp(
+    %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %idx : index) kernel {
+    // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
+    // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
+    // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX]], %[[S0]] : i64
+    // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
+    // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX]], %[[S1]] : i64
+    // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
+    // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX]] : i64
+    // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
+    // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
+    // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
+    // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX]], %[[S3]]  : i64
+    // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX]]  : i64
+    // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+    // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
+    // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+    %copy = gpu.device_async_copy %src[%idx, %idx], %dst[%idx, %idx, %idx], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
+    // CHECK: nvvm.cp.async.commit.group
+    %group = gpu.device_async_create_group %copy
+    // CHECK: nvvm.cp.async.wait.group 1
+    gpu.device_async_wait %group {numGroups = 1 : i32}
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 9ddda4e1dcf37..3b2b94b203962 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -555,3 +555,59 @@ func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %
     %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }
+
+// -----
+
+func @async_cp_memory_space(%out : memref<16xf32>, %in : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{destination memref must have memory space 3}}
+  gpu.device_async_copy %in[%idx], %out[%idx], 16 : memref<16xf32> to memref<16xf32>
+  return
+}
+
+// -----
+
+func @async_cp_memref_type(%out : memref<16xi32, 3>, %in : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{source and destination must have the same element type}}
+  gpu.device_async_copy %in[%idx], %out[%idx], 16 : memref<16xf32> to memref<16xi32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_src_indices(%out : memref<16xf32, 3>, %in : memref<16x16xf32>, %idx : index) {
+  // expected-error @+1 {{expected 2 source indices, got 1}}
+  gpu.device_async_copy %in[%idx], %out[%idx], 16 : memref<16x16xf32> to memref<16xf32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_dst_indices(%out : memref<16x16xf32, 3>, %in : memref<16xf32>, %idx : index) {
+  // expected-error @+1 {{expected 2 destination indices, got 1}}
+  gpu.device_async_copy %in[%idx], %out[%idx], 16 : memref<16xf32> to memref<16x16xf32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_src_stride(
+  %out : memref<200x100xf32, 3>,
+  %in : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
+  %idx : index) {
+  // expected-error @+1 {{source memref most minor dim must have unit stride}}
+  gpu.device_async_copy %in[%idx, %idx], %out[%idx, %idx], 16 :
+    memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>> to memref<200x100xf32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_dst_stride(
+  %out : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>,
+  %in : memref<200x100xf32>,
+  %idx : index) {
+  // expected-error @+1 {{destination memref most minor dim must have unit stride}}
+  gpu.device_async_copy %in[%idx, %idx], %out[%idx, %idx], 16 :
+    memref<200x100xf32> to memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3>
+  return
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 6ab6e9eb0a2d6..c1c5ff5570832 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -240,4 +240,16 @@ module attributes {gpu.container_module} {
     %3 = gpu.subgroup_mma_elementwise maxf %2, %1 : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     return
   }
+
+  func @async_cp(%out : memref<2x7x5xf32, 3>, %in : memref<4x5xf32>) {
+    // CHECK-LABEL: func @async_cp
+    %zero = arith.constant 0 : index
+    // CHECK: gpu.device_async_copy %{{.*}}[{{.*}}, {{.*}}], %{{.*}}[{{.*}}, {{.*}}, {{.*}}], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+    %copy = gpu.device_async_copy %in[%zero, %zero], %out[%zero, %zero, %zero], 4 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+    // CHECK: %{{.*}} = gpu.device_async_create_group
+    %group = gpu.device_async_create_group %copy
+    // CHECK: gpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
+    gpu.device_async_wait %group {numGroups = 1 : i32}
+    return
+  }
 }


        


More information about the Mlir-commits mailing list