[Mlir-commits] [mlir] 95cb986 - [mlir][NVGPU] Support cache all (.ca) in nvgpu.device_async_copy
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Apr 18 05:01:37 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-18T05:00:53-07:00
New Revision: 95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983
URL: https://github.com/llvm/llvm-project/commit/95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983
DIFF: https://github.com/llvm/llvm-project/commit/95cb9862a8dcd3b8e9cdf0a27b5eafb910c9e983.diff
LOG: [mlir][NVGPU] Support cache all (.ca) in nvgpu.device_async_copy
This patch adds support for cache all (.ca) in the nvgpu-to-nvvm conversion of the inline asm `cp.async`.
For copy sizes other than 16 bytes, the cache global (.cg) variant of cp.async is not allowed, so cache all (.ca) is required to generate valid PTX.
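For example, a 16-byte copy (4 f32 elements) keeps lowering to the .cg variant, while a 4-byte copy (1 f32 element) now lowers to .ca (a sketch based on the tests below; the surrounding IR is illustrative):

    // 4 x f32 = 16 bytes: lowers to "cp.async.cg.shared.global [$0], [$1], $2, $3;"
    %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}
        : memref<128x128xf32> to memref<3x16x128xf32, 3>
    // 1 x f32 = 4 bytes: .cg is not valid here, so this lowers to
    // "cp.async.ca.shared.global [$0], [$1], $2, $3;"
    %1 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}
        : memref<128x128xf32> to memref<3x16x128xf32, 3>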
Differential Revision: https://reviews.llvm.org/D148604
Authored-by: Manish Gupta <manigupta at google.com>
Added:
Modified:
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index e88bb05d0b0b9..4a923fac76c88 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -367,7 +367,9 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
ConversionPatternRewriter &rewriter) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
- const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+
+ const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+ const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
const char *asmConstraints = "r,l,n,r";
Value c3I32 = rewriter.create<LLVM::ConstantOp>(
@@ -382,6 +384,19 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
+ // Pick the right asm string based on the dstBytes which is a compile-time
+ // constant.
+ auto dstByteConstOp =
+ dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
+ auto dstByteAttr = dstByteConstOp.getValue().dyn_cast<mlir::IntegerAttr>();
+ int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
+
+ assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
+ "cp.async byte copy size must be 4, 8 or 16");
+ // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
+ // 16 dst bytes.
+ const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
+
rewriter.create<LLVM::InlineAsmOp>(
loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
/*operands=*/asmVals,
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dfbbd15393b04..54b71389d8ee5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -295,12 +295,12 @@ func.func @async_cp_i4(
// -----
-// CHECK-LABEL: @async_cp_zfill(
+// CHECK-LABEL: @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
-func.func @async_cp_zfill(
+func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-
- // CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+ // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
@@ -312,6 +312,24 @@ func.func @async_cp_zfill(
// -----
+// CHECK-LABEL: @async_cp_zfill_f32_align1(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill_f32_align1(
+ %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+ // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+ %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+ // CHECK: nvvm.cp.async.commit.group
+ %1 = nvgpu.device_async_create_group %0
+ // CHECK: nvvm.cp.async.wait.group 1
+ nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+ return
+}
+
+// -----
+
+
// CHECK-LABEL: func @mma_sp_sync_f16_16832(
func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
%arg1: vector<4x2xf16>,