[Mlir-commits] [mlir] 4319e19 - [mlir][nvgpu] Introduce Multicast Capability to `nvgpu.tma.async.load` (#76935)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 5 01:48:59 PST 2024
Author: Guray Ozen
Date: 2024-01-05T10:48:55+01:00
New Revision: 4319e1916dd13f5f7f56536acf01899320b82c6a
URL: https://github.com/llvm/llvm-project/commit/4319e1916dd13f5f7f56536acf01899320b82c6a
DIFF: https://github.com/llvm/llvm-project/commit/4319e1916dd13f5f7f56536acf01899320b82c6a.diff
LOG: [mlir][nvgpu] Introduce Multicast Capability to `nvgpu.tma.async.load` (#76935)
This PR extends the `nvgpu.tma.async.load` Op with multicast support.
The lower-level `nvvm.cp.async.bulk.tensor.shared.cluster.global` NVVM
Op already had this capability; this PR adds an optional multicast mask
operand to the NVGPU Op and lowers it to the NVVM operation.
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 440f7d0380eb17..7e139663d74b47 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -642,16 +642,18 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
The Op uses `$barrier` mbarrier based completion mechanism.
}];
- let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
- NVGPU_MBarrierGroup:$barriers,
- NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
- Variadic<Index>:$coordinates,
- Index:$mbarId,
- Optional<I1>:$predicate);
+ let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
+ NVGPU_MBarrierGroup:$barriers,
+ NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+ Variadic<Index>:$coordinates,
+ Index:$mbarId,
+ Optional<I16>:$multicastMask,
+ Optional<I1>:$predicate);
let assemblyFormat = [{
$tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]`
`to` $dst
- (`,` `predicate` `=` $predicate^)?
+ (`multicast_mask` `=` $multicastMask^ )?
+ (`,` `predicate` `=` $predicate^)?
attr-dict `:` type($tensorMapDescriptor) `,` type($barriers)
`->` type($dst)
}];
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9cd3a5ce65ce5f..db84e5cf62a5e9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -990,7 +990,8 @@ struct NVGPUTmaAsyncLoadOpLowering
}
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
- ValueRange{}, Value{}, Value{}, adaptor.getPredicate());
+ ValueRange{}, adaptor.getMulticastMask(), Value{},
+ adaptor.getPredicate());
return success();
}
};
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index b68bed3aa53cf9..aebdd0a4ee4147 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -980,7 +980,7 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
- Value());
+ Value(), Value());
loadOps.push_back(loadOp);
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
SmallVector<AffineExpr> symbols(mixedSizes.size());
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index e11449e6f7c457..b8a0f75d1cc8b9 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -704,6 +704,29 @@ func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensor
func.return
}
+func.func @async_tma_load_multicast(
+ %tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d,
+ %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d,
+ %tensorMap5d: !tensorMap5d, %buffer1d: memref<128xf32,3>,
+ %buffer2d: memref<32x32xf32,3>, %buffer3d: memref<2x32x32xf32,3>,
+ %buffer4d: memref<2x2x32x32xf32,3>, %buffer5d: memref<2x2x2x32x32xf32,3>,
+ %mbarrier: !mbarrier,
+ %multicastMask: i16) {
+ %c0 = arith.constant 0 : index
+ %crd0 = arith.constant 0 : index
+ %crd1 = arith.constant 0 : index
+ // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
+ nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d multicast_mask = %multicastMask : !tensorMap1d, !mbarrier -> memref<128xf32,3>
+ // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}]
+ nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d multicast_mask = %multicastMask : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+ // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d multicast_mask = %multicastMask : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+ // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d multicast_mask = %multicastMask : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
+ // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d multicast_mask = %multicastMask : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
+ func.return
+}
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
%crd0 = arith.constant 64 : index
More information about the Mlir-commits mailing list
