[Mlir-commits] [mlir] 2c57396 - [mlir][nvgpu] Implement `nvgpu.device_async_copy` by NVVMToLLVM Pass
Guray Ozen
llvmlistbot at llvm.org
Tue Jul 11 03:18:33 PDT 2023
Author: Guray Ozen
Date: 2023-07-11T12:18:28+02:00
New Revision: 2c5739675cf8fc9191b8735be7c012e846fa49de
URL: https://github.com/llvm/llvm-project/commit/2c5739675cf8fc9191b8735be7c012e846fa49de
DIFF: https://github.com/llvm/llvm-project/commit/2c5739675cf8fc9191b8735be7c012e846fa49de.diff
LOG: [mlir][nvgpu] Implement `nvgpu.device_async_copy` by NVVMToLLVM Pass
`nvgpu.device_async_copy` is lowered into the `cp.async` PTX instruction. However, the NVPTX backend does not support all of its modes, in particular the mode that zero-fills the destination when fewer source bytes are copied. Therefore, the current MLIR implementation generates inline assembly for that case.
This work simplifies PTX generation for `nvgpu.device_async_copy` and implements it in the `NVVMToLLVM` pass.
Depends on D154060
Reviewed By: nicolasvasilache, manishucsd
Differential Revision: https://reviews.llvm.org/D154345
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Dialect/NVGPU/invalid.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 01294225a64d29..5dd37306990698 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -506,12 +506,30 @@ def NVVM_SyncWarpOp :
let assemblyFormat = "$mask attr-dict `:` type($mask)";
}
+// Load cache modifier kinds (.ca, .cg, .cs, .lu, .cv); see the PTX ISA:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#id62
+def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
+def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
+def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
+def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">;
+def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">;
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
+/// Enum attribute of the different load cache modifier kinds.
+def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
+ "NVVM load cache modifier kind",
+ [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS,
+ LoadCacheModifierLU, LoadCacheModifierCV]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
+
+/// Global-to-shared cp.async of `$size` bytes. `$modifier` selects the .ca or
+/// .cg variant. Without `$cpSize` the op maps to an nvvm_cp_async_* intrinsic;
+/// with `$cpSize`, hasIntrinsic() returns false and the op is meant to be
+/// lowered through BasicPtxBuilderOpInterface using the PTX from getPtx().
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_i8Ptr_shared:$dst,
LLVM_i8Ptr_global:$src,
I32Attr:$size,
- OptionalAttr<UnitAttr>:$bypass_l1)> {
+ LoadCacheModifierAttr:$modifier,
+ Optional<LLVM_Type>:$cpSize)> {
string llvmBuilder = [{
llvm::Intrinsic::ID id;
switch ($size) {
@@ -522,18 +540,40 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
break;
case 16:
- if(static_cast<bool>($bypass_l1))
+ if($modifier == NVVM::LoadCacheModifierKind::CG)
id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
- else
+ else if($modifier == NVVM::LoadCacheModifierKind::CA)
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+ else
+ llvm_unreachable("unsupported cache modifier");
break;
default:
llvm_unreachable("unsupported async copy size");
}
createIntrinsicCall(builder, id, {$dst, $src});
}];
- let assemblyFormat = "$dst `,` $src `,` $size attr-dict `:` type(operands)";
+ let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
let hasVerifier = 1;
+ let extraClassDeclaration = [{
+ // Only the form without $cpSize has a matching nvvm_cp_async_* intrinsic.
+ bool hasIntrinsic() { if(getCpSize()) return false; return true; }
+
+ // Values in the order getPtx() references them: %0 = dst, %1 = src,
+ // %2 = size (materialized as an i32 constant), %3 = cpSize.
+ void getAsmValues(RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
+ asmValues.push_back({getDst(), PTXRegisterMod::Read});
+ asmValues.push_back({getSrc(), PTXRegisterMod::Read});
+ asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
+ asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
+ }
+ }];
+ let extraClassDefinition = [{
+ // PTX template for the $cpSize form; only the CG and CA modifiers are handled.
+ // NOTE(review): LLVM inline asm references operands as $N (the removed
+ // NVGPUToNVVM code used $0..$3); confirm the PTX builder rewrites %N to $N,
+ // otherwise %N would reach ptxas as a literal register name.
+ const char* $cppClass::getPtx() {
+ if(getModifier() == NVVM::LoadCacheModifierKind::CG)
+ return "cp.async.cg.shared.global [%0], [%1], %2, %3;\n";
+ if(getModifier() == NVVM::LoadCacheModifierKind::CA)
+ return "cp.async.ca.shared.global [%0], [%1], %2, %3;\n";
+ llvm_unreachable("unsupported cache modifier");
+ }
+ }];
}
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 3d898e5af19c18..5694e7c28de67f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -361,51 +361,6 @@ struct ConvertNVGPUToNVVMPass
}
};
-static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
- Value dstBytes, Value srcElements,
- mlir::MemRefType elementType,
- ConversionPatternRewriter &rewriter) {
- auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
- LLVM::AsmDialect::AD_ATT);
-
- const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
- const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
- const char *asmConstraints = "r,l,n,r";
-
- Value c3I32 = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
- Value bitwidth = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
- Value srcElementsI32 =
- rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
- Value srcBytes = rewriter.create<LLVM::LShrOp>(
- loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
-
- SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
-
- // Pick the right asm string based on the dstBytes which is a compile-time
- // constant.
- auto dstByteConstOp =
- dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
- auto dstByteAttr = dyn_cast<mlir::IntegerAttr>(dstByteConstOp.getValue());
- int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
-
- assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
- "cp.async byte copy size must be 4, 8 or 16");
- // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
- // 16 dst bytes.
- const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
-
- rewriter.create<LLVM::InlineAsmOp>(
- loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/asmConstraints, /*has_side_effects=*/true,
- /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
-}
-
/// Returns the constraints for the sparse MMA inline assembly instruction.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
unsigned matBSize,
@@ -620,30 +575,38 @@ struct NVGPUAsyncCopyLowering
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
- // bypass L1 is only supported for byte sizes of 16, we drop the hint
- // otherwise.
- UnitAttr bypassL1 =
- sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
-
- // When the optional SrcElements argument is present, the source (global
- // memory) of CpAsyncOp is read only for SrcElements number of elements. The
- // rest of the DstElements in the destination (shared memory) are filled
- // with zeros.
- if (op.getSrcElements())
- emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
- rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(sizeInBytes)),
- adaptor.getSrcElements(), srcMemrefType, rewriter);
-
// When the optional SrcElements argument is *not* present, the regular
// CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
- // memory) to fill DstElements number of elements in the destination (shared
- // memory).
- else
- rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
- rewriter.getI32IntegerAttr(sizeInBytes),
- bypassL1);
+ // memory) to fill DstElements number of elements in the destination
+ // (shared memory).
+ // srcBytes starts out as the source *element* count (or null when there is
+ // no zero-fill) and is converted to a byte count below; a null value makes
+ // CpAsyncOp omit its optional cpSize operand.
+ Value srcBytes = adaptor.getSrcElements();
+ if (srcBytes) {
+ // When the optional SrcElements argument is present, the source (global
+ // memory) of CpAsyncOp is read only for SrcElements number of elements.
+ // The rest of the DstElements in the destination (shared memory) are
+ // filled with zeros.
+ Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+ Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+ Value srcElementsI32 =
+ rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
+ srcBytes = rewriter.create<LLVM::LShrOp>(
+ loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
+ c3I32);
+ }
+ // Use cache-global (.cg, which bypasses L1) only when bypassL1 is requested
+ // and the copy is exactly 16 bytes; otherwise use cache-all (.ca).
+ NVVM::LoadCacheModifierKind cacheModifier =
+ (op.getBypassL1().value_or(false) && sizeInBytes == 16)
+ ? NVVM::LoadCacheModifierKind::CG
+ : NVVM::LoadCacheModifierKind::CA;
+
+ rewriter.create<NVVM::CpAsyncOp>(
+ loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
+ srcBytes);
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9bec96c4d4bc5e..397fca52a93455 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -68,10 +68,13 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
LogicalResult CpAsyncOp::verify() {
+ // The intrinsic selection and getPtx() only handle the .cg and .ca
+ // variants, so reject the other load cache modifiers up front.
+ if (getModifier() != LoadCacheModifierKind::CG &&
+ getModifier() != LoadCacheModifierKind::CA)
+ return emitError("Only CG and CA cache modifiers are supported.");
if (getSize() != 4 && getSize() != 8 && getSize() != 16)
return emitError("expected byte size to be either 4, 8 or 16.");
- if (getBypassL1() && getSize() != 16)
- return emitError("bypass l1 is only support for 16 bytes copy.");
+ // Only the 16-byte copy has a .cg intrinsic (nvvm_cp_async_cg_shared_global_16).
+ if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
+ return emitError("CG cache modifier is only support for 16 bytes copy.");
return success();
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 0472d27906ead9..08384debaaf7c0 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -83,6 +83,20 @@ LogicalResult DeviceAsyncCopyOp::verify() {
return emitOpError() << "expected " << dstMemref.getRank()
<< " destination indices, got "
<< getDstIndices().size();
+ // bypassL1 lowers to the CG cache modifier, which NVVM only accepts for
+ // 16-byte copies; reject it here instead of silently dropping the hint.
+ if (getBypassL1().has_value()) {
+ int64_t dstElements = getDstElements().getZExtValue();
+ int64_t sizeInBytes =
+ (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
+ // Number of destination elements that would make the copy 16 bytes.
+ int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
+ if (getBypassL1().value() && sizeInBytes != 16) {
+ // NOTE(review): "satify" is a typo for "satisfy"; fix it together with
+ // the expected-error text in mlir/test/Dialect/NVGPU/invalid.mlir.
+ return emitOpError() << "bypassL1 does not satify alignment for "
+ << dstMemref << " with destination element "
+ << dstElements
+ << ". Unset bypassL1, or set "
+ "destination element to "
+ << req;
+ }
+ }
return success();
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 54b71389d8ee56..41369da08b8dcf 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -258,14 +258,14 @@ func.func @async_cp(
// CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX1]] : i64
// CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr, i64) -> !llvm.ptr
// CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
- // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
+ // Without bypassL1 the lowering selects the ca cache modifier.
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = ca
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
// CHECK: nvvm.cp.async.wait.group 1
nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
- // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+ // bypassL1 on a 16-byte copy selects the cg cache modifier.
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
%2 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
return
}
@@ -288,7 +288,7 @@ func.func @async_cp_i4(
// CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
// CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr
// CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
- // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = ca
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3>
return %0 : !nvgpu.device.async.token
}
@@ -296,11 +296,31 @@ func.func @async_cp_i4(
// -----
// CHECK-LABEL: @async_cp_zfill_f32_align4(
-// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index
func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+ // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
+ // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
+ // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64
+ // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
+ // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI1]], %[[LI]] : i64
+ // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64
+ // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
+ // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(128 : index) : i64
+ // NOTE(review): FI2 is re-bound below for the source offset; a distinct name would be easier to follow.
+ // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64
+ // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
+ // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr
+ // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
+ // srcElements is converted to a byte count: (32-bit f32 width * srcElements) >> 3.
+ // CHECK-DAG: %[[c1:.*]] = llvm.mlir.constant(3 : i32) : i32
+ // CHECK-DAG: %[[c2:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK-DAG: %[[c3:.*]] = llvm.trunc %[[SRC1]] : i64 to i32
+ // CHECK-DAG: %[[c4:.*]] = llvm.mul %[[c2]], %[[c3]] : i32
+ // CHECK-DAG: %[[c5:.*]] = llvm.lshr %[[c4]], %[[c1]] : i32
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = cg, %[[c5]]
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
@@ -316,9 +336,29 @@ func.func @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill_f32_align1(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
- %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+ // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
+ // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
+ // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
+ // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64
+ // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
+ // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI1]], %[[LI]] : i64
+ // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64
+ // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
+ // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(128 : index) : i64
+ // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64
+ // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
+ // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr
+ // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
+ // srcElements is converted to a byte count: (32-bit f32 width * srcElements) >> 3.
+ // CHECK-DAG: %[[c1:.*]] = llvm.mlir.constant(3 : i32) : i32
+ // CHECK-DAG: %[[c2:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK-DAG: %[[c3:.*]] = llvm.trunc %[[SRC1]] : i64 to i32
+ // CHECK-DAG: %[[c4:.*]] = llvm.mul %[[c2]], %[[c3]] : i32
+ // CHECK-DAG: %[[c5:.*]] = llvm.lshr %[[c4]], %[[c1]] : i32
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 4, cache = ca, %[[c5]]
+ // No bypassL1 here: one f32 element is a 4-byte copy, and bypassL1 is only accepted for 16-byte copies.
+ %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements : memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
// CHECK: nvvm.cp.async.wait.group 1
diff --git a/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir b/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
index 2dd329d4956ce6..1a37f1c046cf66 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
@@ -21,14 +21,14 @@ func.func @async_cp(
// CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
- // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+ // Without bypassL1 the lowering selects the ca cache modifier.
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16, cache = ca
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
// CHECK: nvvm.cp.async.wait.group 1
nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
- // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+ // bypassL1 on a 16-byte copy selects the cg cache modifier.
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
%2 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
return
}
@@ -53,7 +53,7 @@ func.func @async_cp_i4(
// CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr<i4>, i64) -> !llvm.ptr<i4>
// CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<i4> to !llvm.ptr<i8>
// CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
- // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+ // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16, cache = ca
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3>
return %0 : !nvgpu.device.async.token
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 3d863efd44e769..81210c1ba37cec 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -2,28 +2,46 @@
-// CHECK-LABEL : @init_mbarrier_arrive_expect_tx
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i32{
- //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
%res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i32
llvm.return %res : i32
}
-// CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i32 {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
%res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32
llvm.return %res : i32
}
-// CHECK-LABEL : @init_mbarrier_try_wait.parity.shared
+// CHECK-LABEL: @init_mbarrier_try_wait.parity.shared
llvm.func @init_mbarrier_try_wait.parity.shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
- // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
%res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
llvm.return %res : i32
}
-// CHECK-LABEL : @init_mbarrier_try_wait.parity
+// CHECK-LABEL: @init_mbarrier_try_wait.parity
llvm.func @init_mbarrier_try_wait.parity(%barrier : !llvm.ptr, %token : i32) -> i32{
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
%res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
llvm.return %res : i32
}
+
+// CHECK-LABEL: @async_cp
+func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
+ return
+}
+
+// CHECK-LABEL: @async_cp_zfill
+func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+ // Only the PTX template is matched: the operand list also carries the size constant materialized by getAsmValues().
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [%0], [%1], %2, %3;\0A"
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [%0], [%1], %2, %3;\0A"
+ nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+ return
+}
diff --git a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
index c57a37da145b9e..bb177eb1500ad6 100644
--- a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
@@ -278,15 +278,15 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+ // The ca modifier passes the modifier check, so the 32-byte size is what gets rejected.
// expected-error @below {{expected byte size to be either 4, 8 or 16.}}
- nvvm.cp.async.shared.global %arg0, %arg1, 32 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 32, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
return
}
// -----
func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
- // expected-error @below {{bypass l1 is only support for 16 bytes copy.}}
- nvvm.cp.async.shared.global %arg0, %arg1, 8 {bypass_l1} : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ // The cg modifier is rejected for any copy size other than 16 bytes.
+ // expected-error @below {{CG cache modifier is only support for 16 bytes copy.}}
+ nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = cg : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 57e02427dc4ba2..65b1d2fa2511dc 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1281,15 +1281,15 @@ func.func @bitcast(%arg0: vector<2x3xf32>) {
func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
+ // The size check runs before the cg/16-byte check, so the size error is reported.
// expected-error @below {{expected byte size to be either 4, 8 or 16.}}
- nvvm.cp.async.shared.global %arg0, %arg1, 32 : !llvm.ptr<3>, !llvm.ptr<1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 32, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
return
}
// -----
func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
- // expected-error @below {{bypass l1 is only support for 16 bytes copy.}}
- nvvm.cp.async.shared.global %arg0, %arg1, 8 {bypass_l1} : !llvm.ptr<3>, !llvm.ptr<1>
+ // The cg modifier is rejected for any copy size other than 16 bytes.
+ // expected-error @below {{CG cache modifier is only support for 16 bytes copy.}}
+ nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
index 5fbadd1dc414e3..68eb1ecca00b88 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
@@ -11,10 +11,10 @@ func.func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.stru
// CHECK-LABEL: @cp_async
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
- nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
- nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
// CHECK: nvvm.cp.async.commit.group
nvvm.cp.async.commit.group
// CHECK: nvvm.cp.async.wait.group 0
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 1e1cd660a48e58..bbc7676b45eafc 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -289,10 +289,10 @@ func.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
// CHECK-LABEL: @cp_async
llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
- nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
- nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<3>, !llvm.ptr<1>
+// Round-trip both cache modifiers accepted by the verifier.
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
// CHECK: nvvm.cp.async.commit.group
nvvm.cp.async.commit.group
// CHECK: nvvm.cp.async.wait.group 0
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 02ee7885ed9276..a0a8a115a4f424 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -185,3 +185,12 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
+
+// -----
+
+// bypassL1 requires a 16-byte copy; a single f32 element is only 4 bytes.
+func.func @async_cp_zfill_f32_align1(
+ %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+ // expected-error @+1 {{'nvgpu.device_async_copy' op bypassL1 does not satify alignment for 'memref<3x16x128xf32, 3>' with destination element 1. Unset bypassL1, or set destination element to 4}}
+ %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1} : memref<128x128xf32> to memref<3x16x128xf32, 3>
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 522cce57a88a6e..39af5895387d32 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -309,13 +309,13 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
// CHECK-LABEL: @cp_async
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+ // cache = ca maps to the .ca intrinsics for 4, 8 and 16 bytes; cache = cg maps to the 16-byte .cg intrinsic.
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
- nvvm.cp.async.shared.global %arg0, %arg1, 4 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 4, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
- nvvm.cp.async.shared.global %arg0, %arg1, 8 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
- nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
- nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+ nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
nvvm.cp.async.commit.group
// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
More information about the Mlir-commits
mailing list