[Mlir-commits] [mlir] 95cb986 - [mlir][NVGPU] Support cache all (.ca) in nvgpu.device_async_copy

Nicolas Vasilache llvmlistbot at llvm.org
Tue Apr 18 05:01:37 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-18T05:00:53-07:00
New Revision: 95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983

URL: https://github.com/llvm/llvm-project/commit/95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983
DIFF: https://github.com/llvm/llvm-project/commit/95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983.diff

LOG: [mlir][NVGPU] Support cache all (.ca) in nvgpu.device_async_copy

This patch adds support for the cache-all (`.ca`) qualifier in the nvgpu-to-nvvm conversion, which emits `cp.async` as inline asm.

For copy sizes other than 16 bytes, the `cp.async` cache-global (`.cg`) qualifier is not allowed, so cache-all (`.ca`) is required to generate valid PTX.

Differential revision: https://reviews.llvm.org/D148604

Authored-by: Manish Gupta <manigupta at google.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index e88bb05d0b0b9..4a923fac76c88 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -367,7 +367,9 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
                                   ConversionPatternRewriter &rewriter) {
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
-  const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+
+  const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+  const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
   const char *asmConstraints = "r,l,n,r";
 
   Value c3I32 = rewriter.create<LLVM::ConstantOp>(
@@ -382,6 +384,19 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
 
   SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
 
+  // Pick the right asm string based on the dstBytes which is a compile-time
+  // constant.
+  auto dstByteConstOp =
+      dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
+  auto dstByteAttr = dstByteConstOp.getValue().dyn_cast<mlir::IntegerAttr>();
+  int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
+
+  assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
+          "cp.async byte copy size must be 4, 8 or 16");
+  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
+  // 16 dst bytes.
+  const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
+
   rewriter.create<LLVM::InlineAsmOp>(
       loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
       /*operands=*/asmVals,

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dfbbd15393b04..54b71389d8ee5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -295,12 +295,12 @@ func.func @async_cp_i4(
 
 // -----
 
-// CHECK-LABEL: @async_cp_zfill(
+// CHECK-LABEL: @async_cp_zfill_f32_align4(
 // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
-func.func @async_cp_zfill(
+func.func @async_cp_zfill_f32_align4(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-
-  // CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+  // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
   // CHECK: nvvm.cp.async.commit.group
   %1 = nvgpu.device_async_create_group %0
@@ -312,6 +312,24 @@ func.func @async_cp_zfill(
 
 // -----
 
+// CHECK-LABEL: @async_cp_zfill_f32_align1(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill_f32_align1(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+  // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %1 = nvgpu.device_async_create_group %0
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+  return
+}
+
+// -----
+
+
 // CHECK-LABEL: func @mma_sp_sync_f16_16832(
 func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
                                  %arg1: vector<4x2xf16>,


        


More information about the Mlir-commits mailing list