[Mlir-commits] [mlir] fbf69f9 - [mlir][NVGPU] Adding Support for cp_async_zfill via Inline Asm
Thomas Raoux
llvmlistbot at llvm.org
Fri Sep 2 14:30:16 PDT 2022
Author: Manish Gupta
Date: 2022-09-02T21:29:26Z
New Revision: fbf69f95b614e3428a5322005070986eae1dfb5a
URL: https://github.com/llvm/llvm-project/commit/fbf69f95b614e3428a5322005070986eae1dfb5a
DIFF: https://github.com/llvm/llvm-project/commit/fbf69f95b614e3428a5322005070986eae1dfb5a.diff
LOG: [mlir][NVGPU] Adding Support for cp_async_zfill via Inline Asm
`cp_async_zfill` is currently not present in the nvvm backend, this patch adds `cp_async_zfill` support by adding inline asm when lowering from `nvgpu` to `nvvm`.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D132269
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index d0dd5a63021dc..3b92f5461eca0 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -151,6 +151,14 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
`bypassL1` attribute is a hint to the backend and hardware that
the copy should bypass the L1 cache; this may be dropped by the backend or
hardware.
+ `dstElements` attribute is the total number of elements written to
+ destination (shared memory).
+ `srcElements` argument is the total number of elements read from
+ source (global memory).
+
+ `srcElements` is an optional argument and when present it only reads
+ `srcElements` number of elements from the source global memory and zero fills
+ the rest of the elements in the destination shared memory.
In order to do a copy and wait for the result we need the following
combination:
@@ -183,10 +191,11 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
Variadic<Index>:$dstIndices,
Arg<AnyMemRef, "", [MemRead]>:$src,
Variadic<Index>:$srcIndices,
- IndexAttr:$numElements,
+ IndexAttr:$dstElements,
+ Optional<Index>:$srcElements,
OptionalAttr<UnitAttr>:$bypassL1);
let assemblyFormat = [{
- $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
attr-dict `:` type($src) `to` type($dst)
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9f1d19ddd6dcd..c4c49f2edd5ff 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -354,6 +354,35 @@ struct ConvertNVGPUToNVVMPass
}
};
+static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
+ Value dstBytes, Value srcElements,
+ mlir::MemRefType elementType,
+ ConversionPatternRewriter &rewriter) {
+ auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+ LLVM::AsmDialect::AD_ATT);
+ const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
+ const char *asmConstraints = "r,l,n,r";
+
+ Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+ Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
+ Value srcElementsI32 =
+ rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
+ Value srcBytes = rewriter.create<LLVM::LShrOp>(
+ loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
+
+ SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
+
+ rewriter.create<LLVM::InlineAsmOp>(
+ loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/asmConstraints, /*has_side_effects=*/true,
+ /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
+}
+
struct NVGPUAsyncCopyLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
using ConvertOpToLLVMPattern<
@@ -383,15 +412,33 @@ struct NVGPUAsyncCopyLowering
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
scrPtr);
- int64_t numElements = adaptor.getNumElements().getZExtValue();
+ int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
- (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
+ (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
// bypass L1 is only supported for byte sizes of 16, we drop the hint
// otherwise.
UnitAttr bypassL1 =
sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
- rewriter.create<NVVM::CpAsyncOp>(
- loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
+
+ // When the optional SrcElements argument is present, the source (global
+ // memory) of CpAsyncOp is read only for SrcElements number of elements. The
+ // rest of the DstElements in the destination (shared memory) are filled
+ // with zeros.
+ if (op.getSrcElements())
+ emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
+ rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(sizeInBytes)),
+ adaptor.getSrcElements(), srcMemrefType, rewriter);
+
+ // When the optional SrcElements argument is *not* present, the regular
+ // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
+ // memory) to fill DstElements number of elements in the destination (shared
+ // memory).
+ else
+ rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
+ rewriter.getI32IntegerAttr(sizeInBytes),
+ bypassL1);
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index aa71a26f81069..0a9f8d5611903 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -297,3 +297,19 @@ func.func @async_cp_i4(
return %0 : !nvgpu.device.async.token
}
+// -----
+
+// CHECK-LABEL: @async_cp_zfill(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill(
+ %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+
+ // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
+ %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+ // CHECK: nvvm.cp.async.commit.group
+ %1 = nvgpu.device_async_create_group %0
+ // CHECK: nvvm.cp.async.wait.group 1
+ nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+ return
+}
More information about the Mlir-commits
mailing list