[Mlir-commits] [mlir] a3cd2ee - [mlir][nvgpu] Add a nvgpu.rewrite_copy_as_tma transform operation.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Aug 8 05:08:22 PDT 2023
Author: Nicolas Vasilache
Date: 2023-08-08T12:07:59Z
New Revision: a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b
URL: https://github.com/llvm/llvm-project/commit/a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b
DIFF: https://github.com/llvm/llvm-project/commit/a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b.diff
LOG: [mlir][nvgpu] Add a nvgpu.rewrite_copy_as_tma transform operation.
This revision adds support for directly lowering a linalg.copy on buffers between global and shared memory to a TMA async load plus the required synchronization operations.
This uses the recently introduced Hopper NVVM and NVGPU abstractions to connect things end to end.
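The rewrite is driven from the transform dialect. A minimal schedule, reproduced from the tmaload-transform.mlir test added below, looks like:

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
    // Match every linalg.copy in the payload and rewrite it as a TMA load
    // synchronized through an mbarrier.
    %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> ()
  }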
Differential Revision: https://reviews.llvm.org/D157087
Added:
mlir/test/Dialect/NVGPU/tmaload-transform.mlir
mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
Modified:
mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/lib/Dialect/Utils/IndexingUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
index 114914dce3ed7d..b1d0e96897ee28 100644
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -164,4 +164,33 @@ def RewriteMatmulAsMmaSyncOp :
}];
}
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+def RewriteCopyAsTmaOp :
+ Op<Transform_Dialect, "nvgpu.rewrite_copy_as_tma",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformEachOpTrait,
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Rewrite a copy operation on memrefs into TMA operations that transit through
+ shared memory.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // NVGPU_TRANSFORM_OPS
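As a sketch, and mirroring the FileCheck expectations of the tmaload-transform.mlir test added below, a single global-to-workgroup linalg.copy is rewritten into roughly the following shape (types elided; SSA names such as %num_threads, %num_bytes and %tid_is_zero are illustrative only):

  // Host side, hoisted above the gpu.launch: view the global memref as
  // unranked and build the TMA descriptor from its sizes.
  %cast = memref.cast %global : memref<64x8xf32> to memref<*xf32>
  %desc = nvgpu.tma.create.descriptor %cast box[%c64, %c8] : memref<*xf32> -> ...

  // Device side, inside the gpu.launch: create and init a shared-memory
  // mbarrier, then predicate the TMA request on thread 0.
  %bar = nvgpu.mbarrier.create ...
  nvgpu.mbarrier.init %bar, %num_threads ...
  gpu.barrier
  scf.if %tid_is_zero {
    // Thread 0 issues the async TMA load and declares the expected bytes.
    nvgpu.tma.async.load %desc[%c0, %c0], %bar to %shmem ...
    nvgpu.mbarrier.arrive.expect_tx %bar, %num_bytes ...
  } else {
    // All other threads arrive expecting 0 bytes.
    nvgpu.mbarrier.arrive.expect_tx %bar, %c0 ...
  }
  // Every thread spins until the transfer completes.
  nvgpu.mbarrier.try_wait.parity %bar, %c0, %c10000000 ...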
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index ca0af5fd542242..e99cfb7fddf306 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -20,20 +21,17 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
+using namespace mlir::NVVM;
using namespace mlir::transform;
#define DEBUG_TYPE "nvgpu-transforms"
@@ -517,7 +515,7 @@ struct MmaSyncBuilder {
/// Build a list of memref.load operations indexed at `(row, col)` indices
/// that make sense for a particular MMA instruction and specified via the
/// IndexCalculator callback.
- SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+ SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
IndexCalculator indexFn);
@@ -527,7 +525,7 @@ struct MmaSyncBuilder {
/// data that makes sense for the particular MMA operation.
/// The `vectorShape` matches existing NVGPU dialect op specification but
/// could also be flattened in the future if needed for simplification.
- Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+ Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
IndexCalculator indexFn,
ArrayRef<int64_t> vectorShape);
@@ -535,7 +533,7 @@ struct MmaSyncBuilder {
/// Build a list of memref.store operations indexed at `(row, col)` indices
/// that make sense for a particular MMA instruction and specified via the
/// IndexCalculator callback.
- SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+ SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
ValueRange toStore,
OpFoldResult laneId, Value memref,
IndexCalculator indexFn);
@@ -546,7 +544,7 @@ struct MmaSyncBuilder {
/// data that makes sense for the particular MMA operation.
/// The `vectorShape` matches existing NVGPU dialect op specification but
/// could also be flattened in the future if needed for simplification.
- SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+ SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
@@ -573,7 +571,7 @@ static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
}
}
-SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
OpFoldResult laneId,
Value memref,
IndexCalculator indexFn) {
@@ -591,10 +589,10 @@ SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
return res;
}
-Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
- auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+ auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
Type elementType = getElementTypeOrSelf(memref.getType());
auto vt = VectorType::get(vectorShape, elementType);
@@ -614,7 +612,7 @@ Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
}
SmallVector<Operation *>
-MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
ValueRange toStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn) {
auto aff = [&](AffineExpr e) {
@@ -632,7 +630,7 @@ MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
return res;
}
-SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
SmallVector<Value> toStore;
@@ -647,7 +645,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
toStore.push_back(v);
});
- return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+ return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
@@ -690,22 +688,22 @@ MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
}
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
- Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
- Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
- Value resMemref = linalgOp.getDpsInitOperand(0)->get();
- assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+ Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
+ Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
+ Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
+ assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
"expected lhs to be a 2D memref");
- assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+ assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
"expected rhs to be a 2D memref");
- assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
+ assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
"expected res to be a 2D memref");
- int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
- int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
- int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
- Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
- Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
- Type resType = getElementTypeOrSelf(resMemref.getType());
+ int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
+ int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
+ int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
+ Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
+ Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
+ Type resType = getElementTypeOrSelf(resMemRef.getType());
FailureOr<MmaSyncInfo> maybeInfo =
getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
@@ -715,15 +713,15 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
MmaSyncInfo info = *maybeInfo;
auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
- Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+ Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
lhsIndexFn, lhsShape);
- Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+ Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
rhsIndexFn, rhsShape);
- Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+ Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
resIndexFn, resShape);
res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
info.tf32Enabled);
- buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+ buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
resShape);
return res.getDefiningOp();
}
@@ -754,6 +752,284 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// Hopper builders.
+//===----------------------------------------------------------------------===//
+
+/// Helper to create the base Hopper-specific operations that are reused in
+/// various other places.
+struct HopperBuilder {
+ HopperBuilder(RewriterBase &rewriter, Location loc)
+ : rewriter(rewriter), loc(loc) {}
+
+ TypedValue<nvgpu::MBarrierType>
+ buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
+
+ /// Create tma descriptor op to initiate transfer from global to shared
+ /// memory. This must be done before the launch op, on the host.
+ TypedValue<nvgpu::TensorMapDescriptorType>
+ buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
+ gpu::LaunchOp launchOp);
+
+ /// Build a tma load from global memory to shared memory using `barrier` to
+ /// synchronize. Return the number of bytes that will be transferred.
+ OpFoldResult
+ buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<nvgpu::MBarrierType> barrier,
+ SmallVectorImpl<Operation *> &loadOps);
+ void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierType> barrier,
+ ArrayRef<OpFoldResult> sizes);
+
+ /// If threadIdx.x == 0, do the TMA request + wait; else just wait.
+ /// Return the operation that performs the transfer on thread0.
+ // TODO: In the future, don't hardcode to thread 0 but elect a leader.
+ SmallVector<Operation *> buildPredicateLoadsOnThread0(
+ ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
+ TypedValue<nvgpu::MBarrierType> barrier);
+
+ void buildTryWaitParity(TypedValue<nvgpu::MBarrierType> barrier);
+
+ RewriterBase &rewriter;
+ Location loc;
+};
+
+SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
+ ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
+ TypedValue<nvgpu::MBarrierType> barrier) {
+ SmallVector<Operation *> loadOps;
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+ Value cond =
+ rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
+ // clang-format off
+ rewriter.create<scf::IfOp>(
+ /*location=*/loc,
+ /*conditional=*/cond,
+ /*thenBuilder=*/
+ [&](OpBuilder &lb, Location loc) {
+ SmallVector<OpFoldResult> sizes;
+ sizes.reserve(globalDescriptors.size());
+ for (auto [desc, shmem] : llvm::zip_equal(
+ globalDescriptors, sharedMemBuffers)) {
+ OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
+ sizes.push_back(sz);
+ }
+ // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
+ // This may or may not have perf implications.
+ buildBarrierArriveTx(barrier, sizes);
+ rewriter.create<scf::YieldOp>(loc);
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &lb, Location loc) {
+ // TODO: is this to avoid thread divergence?
+ // Should we just yield the size and hoist?
+ buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
+ rewriter.create<scf::YieldOp>(loc);
+ });
+ // clang-format on
+ return loadOps;
+}
+
+static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
+ return gpu::AddressSpaceAttr::get(
+ b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+ // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
+}
+
+TypedValue<nvgpu::MBarrierType>
+HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
+ auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
+ Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
+ loc, nvgpu::MBarrierType::get(rewriter.getContext(), sharedMemorySpace));
+ rewriter.create<nvgpu::MBarrierInitOp>(
+ loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads));
+ rewriter.create<gpu::BarrierOp>(loc);
+ return cast<TypedValue<nvgpu::MBarrierType>>(barrier);
+}
+
+TypedValue<nvgpu::TensorMapDescriptorType>
+HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
+ gpu::LaunchOp launchOp) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(launchOp);
+ Value unrankedMemRef = rewriter.create<memref::CastOp>(
+ loc,
+ UnrankedMemRefType::get(memref.getType().getElementType(),
+ memref.getType().getMemorySpace()),
+ memref);
+ SmallVector<OpFoldResult> mixedSizes =
+ memref::getMixedSizes(rewriter, loc, memref);
+ SmallVector<Value> sizes =
+ getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
+
+ auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
+ Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
+ loc,
+ nvgpu::TensorMapDescriptorType::get(
+ rewriter.getContext(),
+ MemRefType::Builder(memref.getType())
+ .setMemorySpace(sharedMemorySpace),
+ TensorMapSwizzleKind::SWIZZLE_NONE,
+ TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
+ TensorMapInterleaveKind::INTERLEAVE_NONE),
+ unrankedMemRef, sizes);
+ return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
+}
+
+OpFoldResult HopperBuilder::buildTmaAsyncLoad(
+ TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<nvgpu::MBarrierType> barrier,
+ SmallVectorImpl<Operation *> &loadOps) {
+ MLIRContext *ctx = rewriter.getContext();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
+ loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero});
+ loadOps.push_back(loadOp);
+ auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
+ SmallVector<AffineExpr> symbols(mixedSizes.size());
+ bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+ AffineExpr prodExprInBytes =
+ computeProduct(ctx, symbols) *
+ (sharedMemref.getType().getElementTypeBitWidth() / 8);
+ auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
+ prodExprInBytes, mixedSizes);
+ return res;
+}
+
+void HopperBuilder::buildBarrierArriveTx(
+ TypedValue<nvgpu::MBarrierType> barrier,
+ ArrayRef<OpFoldResult> mixedSizes) {
+ assert(!mixedSizes.empty() && "expected non-empty sizes");
+ MLIRContext *ctx = rewriter.getContext();
+ SmallVector<AffineExpr> symbols(mixedSizes.size());
+ bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+ AffineExpr sumExpr = computeSum(ctx, symbols);
+ OpFoldResult size =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
+ Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
+ rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal);
+}
+
+void HopperBuilder::buildTryWaitParity(
+ TypedValue<nvgpu::MBarrierType> barrier) {
+ Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ // 10M is an arbitrary number of ticks (neither too small nor too big) to
+ // wait before retrying.
+ // TODO: hoist this in a default dialect constant.
+ Value ticksBeforeRetry =
+ rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
+ rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+ ticksBeforeRetry);
+}
+
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
+struct CopyBuilder : public HopperBuilder {
+ CopyBuilder(RewriterBase &rewriter, Location loc)
+ : HopperBuilder(rewriter, loc) {}
+
+ SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
+};
+
+SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
+ MLIRContext *ctx = rewriter.getContext();
+ if (copyOps.empty())
+ return SmallVector<Operation *>();
+
+ auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
+ assert(launchOp && "expected launch op");
+
+ // 1. Init a barrier object in shared memory.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(copyOps.front());
+ AffineExpr bx, by, bz;
+ bindSymbols(ctx, bx, by, bz);
+ AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
+ OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, prod,
+ ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
+ launchOp.getBlockSizeZ()});
+
+ TypedValue<nvgpu::MBarrierType> barrier =
+ buildAndInitBarrierInSharedMemory(numThreads);
+
+ SmallVector<TypedValue<MemRefType>> shmems;
+ SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
+ for (Operation *op : copyOps) {
+ auto copyOp = cast<linalg::CopyOp>(op);
+ auto inMemRef =
+ cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
+ MemRefType inMemRefType = inMemRef.getType();
+ assert(inMemRefType.getRank() == 2 && "expected in to be a 2D memref");
+
+ // 2. Build global memory descriptor.
+ TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
+ buildGlobalMemRefDescriptor(inMemRef, launchOp);
+ globalDescs.push_back(globalDesc);
+
+ // 3. Shared memory and descriptor for the tmp array.
+ auto shmem =
+ cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
+ shmems.push_back(shmem);
+ }
+
+ // 4. Load in from global memory to shared memory using tma.
+ OpBuilder::InsertionGuard g2(rewriter);
+ rewriter.setInsertionPoint(copyOps.front());
+ SmallVector<Operation *> results =
+ buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
+
+ // 5. Spin-loop until data is ready.
+ buildTryWaitParity(barrier);
+
+ // 6. Erase the ops that have now been rewritten.
+ for (Operation *op : copyOps)
+ rewriter.eraseOp(op);
+
+ return results;
+}
+
+DiagnosedSilenceableFailure
+transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payloadOps = state.getPayloadOps(getTarget());
+ gpu::LaunchOp commonLaunchOp;
+ Operation *firstOp, *failingOp;
+ if (llvm::any_of(payloadOps, [&](Operation *op) {
+ if (!commonLaunchOp) {
+ commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
+ firstOp = op;
+ }
+ auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
+ commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
+ !isa<linalg::CopyOp>(op);
+ if (fail)
+ failingOp = op;
+ return fail;
+ })) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "target ops must be linalg::CopyOp nested under a common "
+ "gpu.LaunchOp to be rewritten because the tma descriptors need to "
+ "be created on the host.\nBut got: "
+ << *firstOp << "\nand " << *failingOp;
+ return diag;
+ }
+
+ // TODO: more robust detection of copy, with transposes etc.
+ CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
@@ -767,6 +1043,7 @@ class NVGPUTransformDialectExtension
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
+ declareGeneratedDialect<NVVM::NVVMDialect>();
declareGeneratedDialect<vector::VectorDialect>();
registerTransformOps<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 5821876139d064..2a774b599a8b68 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -161,13 +161,13 @@ AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
if (basis.empty())
return getAffineConstantExpr(0, ctx);
return std::accumulate(basis.begin(), basis.end(),
- getAffineConstantExpr(1, ctx),
+ getAffineConstantExpr(0, ctx),
std::plus<AffineExpr>());
}
AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
if (basis.empty())
- return getAffineConstantExpr(0, ctx);
+ return getAffineConstantExpr(1, ctx);
return std::accumulate(basis.begin(), basis.end(),
getAffineConstantExpr(1, ctx),
std::multiplies<AffineExpr>());
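The IndexingUtils.cpp change swaps the seeds of the two std::accumulate calls so that each fold uses the correct identity element:

  \sum_{e \in \varnothing} e = 0, \qquad \prod_{e \in \varnothing} e = 1

Previously, computeSum seeded the accumulation with 1 (adding a spurious constant to every non-empty sum) and computeProduct returned 0 for an empty basis; the byte-count and thread-count affine expressions built above rely on the corrected behavior.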
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
new file mode 100644
index 00000000000000..646008b64f794f
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: | FileCheck %s
+
+memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func.func @main()
+func.func @main() {
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+
+ %0 = gpu.wait async
+ %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+ %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+
+ // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+ // CHECK: %[[c64:.*]] = arith.constant 64 : index
+ // CHECK: %[[c8:.*]] = arith.constant 8 : index
+ // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+ // CHECK: %[[c8_2:.*]] = arith.constant 8 : index
+ // CHECK: %[[c128_2:.*]] = arith.constant 128 : index
+ // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK: gpu.launch
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
+ // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+ // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+ %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+ %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+
+ // CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
+ // CHECK: nvgpu.mbarrier.init %[[B]], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
+ // CHECK: gpu.barrier
+ //
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[TIDX:.*]] = gpu.thread_id x
+ // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TIDX]], %[[c0]] : index
+ //
+ // CHECK: scf.if %[[CMP]] {
+ //
+ // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
+ // CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]] to %[[G1]]
+ // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
+ // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
+ //
+ // CHECK: %[[c0_8:.*]] = arith.constant 0 : index
+ // CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]] to %[[G2]]
+ // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
+ // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
+ //
+ // CHECK: %[[c6144:.*]] = arith.constant 6144 : index
+ // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+ // CHECK: } else {
+ // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
+ // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
+ // CHECK: }
+ //
+ // CHECK: %[[c0_6:.*]] = arith.constant 0 : index
+ // CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
+ // CHECK: nvgpu.mbarrier.try_wait.parity %[[B]], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
+
+ /// Both copies are matched and end up in the same async group.
+ linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
+ linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+
+ gpu.terminator
+ }
+
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
new file mode 100644
index 00000000000000..3a6bbe7f0d7721
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -convert-nvgpu-to-nvvm -gpu-kernel-outlining \
+// RUN: -convert-scf-to-cf -convert-nvvm-to-llvm \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-math-to-llvm \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -convert-arith-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX
+
+// CHECK-PTX: mbarrier.init.shared {{.*}} !llvm.ptr<3>, i32
+/// If branch
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+/// Else branch
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+// CHECK-PTX: mbarrier.try_wait.parity.shared
+
+// TODO: GPU layering does not currently work end-to-end. Activate the following
+// when fixed.
+// R-UN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// R-UN: -gpu-to-llvm \
+// R-UN: -convert-func-to-llvm \
+// R-UN: -cse \
+// R-UN: -canonicalize \
+// R-UN: -reconcile-unrealized-casts \
+// R-UN: | mlir-cpu-runner \
+// R-UN: --shared-libs=%mlir_cuda_runtime \
+// R-UN: --shared-libs=%mlir_runner_utils \
+// R-UN: --entry-point-result=void \
+// R-UN: | FileCheck %s
+
+// C-HECK: [GPU] TMA BEFORE lhs[45][7] 0.000000
+// C-HECK: [GPU] TMA BEFORE rhs[7][0] 0.000000
+// C-HECK: [GPU] TMA LOADED lhs[45][7] 7.000000
+// C-HECK: [GPU] TMA LOADED rhs[7][0] 3.000000
+
+
+module @mymod {
+ memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
+ memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
+ func.func @main() {
+ %c10000000 = arith.constant 10000000 : index
+ %c6144 = arith.constant 6144 : index
+ %c45 = arith.constant 45 : index
+ %c7 = arith.constant 7 : index
+ %c64 = arith.constant 64 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 3.000000e+00 : f32
+ %alloc = memref.alloc() : memref<64x8xf32>
+ %alloc_0 = memref.alloc() : memref<8x128xf32>
+ scf.for %arg0 = %c0 to %c8 step %c1 {
+ scf.for %arg1 = %c0 to %c128 step %c1 {
+ memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
+ }
+ }
+ scf.for %arg0 = %c0 to %c64 step %c1 {
+ scf.for %arg1 = %c0 to %c8 step %c1 {
+ %5 = arith.index_cast %arg1 : index to i64
+ %6 = arith.uitofp %5 : i64 to f32
+ memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
+ }
+ }
+ %0 = gpu.wait async
+ %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+ %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+ %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
+ %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>
+
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
+ %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
+ %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
+ linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, 3>)
+ linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, 3>)
+
+ %6 = gpu.thread_id x
+ %10 = arith.cmpi eq, %6, %c0 : index
+ scf.if %10 {
+ %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3>
+ %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3>
+ gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
+ gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+ }
+ gpu.terminator
+ }
+
+ return
+ }
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_copy_as_tma %copy
+ : (!transform.any_op) -> ()
+}