[Mlir-commits] [mlir] a3cd2ee - [mlir][nvgpu] Add a nvgpu.rewrite_copy_as_tma transform operation.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Aug 8 05:08:22 PDT 2023


Author: Nicolas Vasilache
Date: 2023-08-08T12:07:59Z
New Revision: a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b

URL: https://github.com/llvm/llvm-project/commit/a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b
DIFF: https://github.com/llvm/llvm-project/commit/a3cd2eeb2d2a69b2c4d1fe9b634123b9f9943d0b.diff

LOG: [mlir][nvgpu] Add a nvgpu.rewrite_copy_as_tma transform operation.

This revision adds support for directly lowering a linalg.copy on buffers between global and shared memory to a TMA async load plus the required synchronization operations.
This uses the recently introduced Hopper NVVM and NVGPU abstractions to connect things end to end.
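
For illustration, here is a rough before/after sketch of the rewrite on a single
copy (types are elided with "..." and names such as %num_threads or %tidx_is_zero
are placeholders; the lit tests added below show the exact expected IR):

  // Before: a linalg.copy from global to shared memory, inside a gpu.launch.
  gpu.launch ... {
    %out = memref.get_global @bufferLhsGlobal
      : memref<64x8xf32, #gpu.address_space<workgroup>>
    linalg.copy ins(%memref : memref<64x8xf32>)
                outs(%out : memref<64x8xf32, #gpu.address_space<workgroup>>)
    gpu.terminator
  }

  // After: a tma descriptor is created on the host, and the copy becomes a
  // thread-0-predicated tma async load synchronized through an mbarrier.
  %cast = memref.cast %memref : memref<64x8xf32> to memref<*xf32>
  %desc = nvgpu.tma.create.descriptor %cast box[%c64, %c8] : memref<*xf32> -> ...
  gpu.launch ... {
    %out = memref.get_global @bufferLhsGlobal
      : memref<64x8xf32, #gpu.address_space<workgroup>>
    %barrier = nvgpu.mbarrier.create -> ...
    nvgpu.mbarrier.init %barrier, %num_threads : ...
    gpu.barrier
    scf.if %tidx_is_zero {
      nvgpu.tma.async.load %desc[%c0, %c0], %barrier to %out : ...
      nvgpu.mbarrier.arrive.expect_tx %barrier, %c2048 : ...  // 64 * 8 * 4 bytes
    } else {
      nvgpu.mbarrier.arrive.expect_tx %barrier, %c0 : ...
    }
    nvgpu.mbarrier.try_wait.parity %barrier, %c0, %c10000000 : ...
    gpu.terminator
  }

(2048 = 64 * 8 * 4 bytes for a single f32 copy; the test below checks 6144
because it rewrites two copies sharing one barrier.)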

Differential Revision: https://reviews.llvm.org/D157087

Added: 
    mlir/test/Dialect/NVGPU/tmaload-transform.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir

Modified: 
    mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
    mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
    mlir/lib/Dialect/Utils/IndexingUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
index 114914dce3ed7d..b1d0e96897ee28 100644
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -164,4 +164,33 @@ def RewriteMatmulAsMmaSyncOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+def RewriteCopyAsTmaOp :
+  Op<Transform_Dialect, "nvgpu.rewrite_copy_as_tma",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformEachOpTrait,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Rewrite a copy operation on memref to tma operations that transit through
+    shared memory.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // NVGPU_TRANSFORM_OPS
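
For reference, the new op is driven from a transform dialect schedule; both
tests added in this revision invoke it as follows:

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
    %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> ()
  }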

diff  --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index ca0af5fd542242..e99cfb7fddf306 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -20,20 +21,17 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/IR/Value.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 using namespace mlir::nvgpu;
+using namespace mlir::NVVM;
 using namespace mlir::transform;
 
 #define DEBUG_TYPE "nvgpu-transforms"
@@ -517,7 +515,7 @@ struct MmaSyncBuilder {
   /// Build a list of memref.load operations indexed at `(row, col)` indices
   /// that make sense for a particular MMA instruction and specified via the
   /// IndexCalculator callback.
-  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                       OpFoldResult laneId, Value memref,
                                       IndexCalculator indexFn);
 
@@ -527,7 +525,7 @@ struct MmaSyncBuilder {
   /// data that makes sense for the particular MMA operation.
   /// The `vectorShape` matches existing NVGPU dialect op specification but
   /// could also be flattened in the future if needed for simplification.
-  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                       OpFoldResult laneId, Value memref,
                                       IndexCalculator indexFn,
                                       ArrayRef<int64_t> vectorShape);
@@ -535,7 +533,7 @@ struct MmaSyncBuilder {
   /// Build a list of memref.store operations indexed at `(row, col)` indices
   /// that make sense for a particular MMA instruction and specified via the
   /// IndexCalculator callback.
-  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                              ValueRange toStore,
                                              OpFoldResult laneId, Value memref,
                                              IndexCalculator indexFn);
@@ -546,7 +544,7 @@ struct MmaSyncBuilder {
   /// data that makes sense for the particular MMA operation.
   /// The `vectorShape` matches existing NVGPU dialect op specification but
   /// could also be flattened in the future if needed for simplification.
-  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
       OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
       Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
 
@@ -573,7 +571,7 @@ static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
   }
 }
 
-SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                                     OpFoldResult laneId,
                                                     Value memref,
                                                     IndexCalculator indexFn) {
@@ -591,10 +589,10 @@ SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
   return res;
 }
 
-Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
     OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
     IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
-  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
@@ -614,7 +612,7 @@ Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
 }
 
 SmallVector<Operation *>
-MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
                                   ValueRange toStore, OpFoldResult laneId,
                                   Value memref, IndexCalculator indexFn) {
   auto aff = [&](AffineExpr e) {
@@ -632,7 +630,7 @@ MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
   return res;
 }
 
-SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
     OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
     Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
   SmallVector<Value> toStore;
@@ -647,7 +645,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
         toStore.push_back(v);
       });
-  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
 }
 
 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
@@ -690,22 +688,22 @@ MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
 }
 
 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
-  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
-  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
-  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
-  assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
+  assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected lhs to be a 2D memref");
-  assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected rhs to be a 2D memref");
-  assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected res to be a 2D memref");
 
-  int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
-  int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
-  int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
-  Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
-  Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
-  Type resType = getElementTypeOrSelf(resMemref.getType());
+  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
+  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
+  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
+  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
+  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
+  Type resType = getElementTypeOrSelf(resMemRef.getType());
 
   FailureOr<MmaSyncInfo> maybeInfo =
       getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
@@ -715,15 +713,15 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
   MmaSyncInfo info = *maybeInfo;
   auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
   auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
-  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                             lhsIndexFn, lhsShape);
-  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                             rhsIndexFn, rhsShape);
-  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                             resIndexFn, resShape);
   res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                    info.tf32Enabled);
-  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                  resShape);
   return res.getDefiningOp();
 }
@@ -754,6 +752,284 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// Hopper builders.
+//===----------------------------------------------------------------------===//
+
+/// Helper to create the base Hopper-specific operations that are reused in
+/// various other places.
+struct HopperBuilder {
+  HopperBuilder(RewriterBase &rewriter, Location loc)
+      : rewriter(rewriter), loc(loc) {}
+
+  TypedValue<nvgpu::MBarrierType>
+  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
+
+  /// Create a tma descriptor op to initiate the transfer from global to
+  /// shared memory. This must be done before the launch op, on the host.
+  TypedValue<nvgpu::TensorMapDescriptorType>
+  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
+                              gpu::LaunchOp launchOp);
+
+  /// Build a tma load from global memory to shared memory using `barrier` to
+  /// synchronize. Return the number of bytes that will be transferred.
+  OpFoldResult
+  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
+                    TypedValue<MemRefType> sharedMemref,
+                    TypedValue<nvgpu::MBarrierType> barrier,
+                    SmallVectorImpl<Operation *> &loadOps);
+  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierType> barrier,
+                            ArrayRef<OpFoldResult> sizes);
+
+  /// If threadIdx.x == 0, issue the TMA request and wait; otherwise just wait.
+  /// Return the operations that perform the transfers on thread 0.
+  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
+  SmallVector<Operation *> buildPredicateLoadsOnThread0(
+      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
+      TypedValue<nvgpu::MBarrierType> barrier);
+
+  void buildTryWaitParity(TypedValue<nvgpu::MBarrierType> barrier);
+
+  RewriterBase &rewriter;
+  Location loc;
+};
+
+SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
+    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
+    TypedValue<nvgpu::MBarrierType> barrier) {
+  SmallVector<Operation *> loadOps;
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+  Value cond =
+      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
+  // clang-format off
+  rewriter.create<scf::IfOp>(
+    /*location=*/loc,
+    /*conditional=*/cond,
+    /*thenBuilder=*/
+    [&](OpBuilder &lb, Location loc) {
+      SmallVector<OpFoldResult> sizes;
+      sizes.reserve(globalDescriptors.size());
+      for (auto [desc, shmem] : llvm::zip_equal(
+              globalDescriptors, sharedMemBuffers)) {
+        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
+        sizes.push_back(sz);
+      }
+      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
+      // This may or may not have perf implications.
+      buildBarrierArriveTx(barrier, sizes);
+      rewriter.create<scf::YieldOp>(loc);
+    },
+    /*elseBuilder=*/
+    [&](OpBuilder &lb, Location loc) {
+      // TODO: is this needed to avoid thread divergence?
+      // Should we just yield the sizes and hoist?
+      buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
+      rewriter.create<scf::YieldOp>(loc);
+    });
+  // clang-format on
+  return loadOps;
+}
+
+static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
+  return gpu::AddressSpaceAttr::get(
+      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
+}
+
+TypedValue<nvgpu::MBarrierType>
+HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
+  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
+  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
+      loc, nvgpu::MBarrierType::get(rewriter.getContext(), sharedMemorySpace));
+  rewriter.create<nvgpu::MBarrierInitOp>(
+      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads));
+  rewriter.create<gpu::BarrierOp>(loc);
+  return cast<TypedValue<nvgpu::MBarrierType>>(barrier);
+}
+
+TypedValue<nvgpu::TensorMapDescriptorType>
+HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
+                                           gpu::LaunchOp launchOp) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(launchOp);
+  Value unrankedMemRef = rewriter.create<memref::CastOp>(
+      loc,
+      UnrankedMemRefType::get(memref.getType().getElementType(),
+                              memref.getType().getMemorySpace()),
+      memref);
+  SmallVector<OpFoldResult> mixedSizes =
+      memref::getMixedSizes(rewriter, loc, memref);
+  SmallVector<Value> sizes =
+      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
+
+  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
+  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
+      loc,
+      nvgpu::TensorMapDescriptorType::get(
+          rewriter.getContext(),
+          MemRefType::Builder(memref.getType())
+              .setMemorySpace(sharedMemorySpace),
+          TensorMapSwizzleKind::SWIZZLE_NONE,
+          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
+          TensorMapInterleaveKind::INTERLEAVE_NONE),
+      unrankedMemRef, sizes);
+  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
+}
+
+OpFoldResult HopperBuilder::buildTmaAsyncLoad(
+    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
+    TypedValue<MemRefType> sharedMemref,
+    TypedValue<nvgpu::MBarrierType> barrier,
+    SmallVectorImpl<Operation *> &loadOps) {
+  MLIRContext *ctx = rewriter.getContext();
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
+      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero});
+  loadOps.push_back(loadOp);
+  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
+  SmallVector<AffineExpr> symbols(mixedSizes.size());
+  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+  AffineExpr prodExprInBytes =
+      computeProduct(ctx, symbols) *
+      (sharedMemref.getType().getElementTypeBitWidth() / 8);
+  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
+                                                   prodExprInBytes, mixedSizes);
+  return res;
+}
+
+void HopperBuilder::buildBarrierArriveTx(
+    TypedValue<nvgpu::MBarrierType> barrier,
+    ArrayRef<OpFoldResult> mixedSizes) {
+  assert(!mixedSizes.empty() && "expected non-empty sizes");
+  MLIRContext *ctx = rewriter.getContext();
+  SmallVector<AffineExpr> symbols(mixedSizes.size());
+  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+  AffineExpr sumExpr = computeSum(ctx, symbols);
+  OpFoldResult size =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
+  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
+  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal);
+}
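
To make the byte counts concrete: buildTmaAsyncLoad computes, per load, the
product of the shared memref sizes times the element byte width, and
buildBarrierArriveTx sums these contributions for the expect_tx value. For the
two copies rewritten in the test below:

  64 * 8   * 4 bytes = 2048   // memref<64x8xf32>
  8  * 128 * 4 bytes = 4096   // memref<8x128xf32>
  2048 + 4096        = 6144   // the %c6144 checked in tmaload-transform.mlir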
+
+void HopperBuilder::buildTryWaitParity(
+    TypedValue<nvgpu::MBarrierType> barrier) {
+  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  // 10M is an arbitrary number of ticks before retry: neither too small nor
+  // too big.
+  // TODO: hoist this into a default dialect constant.
+  Value ticksBeforeRetry =
+      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
+  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+                                                  ticksBeforeRetry);
+}
+
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
+struct CopyBuilder : public HopperBuilder {
+  CopyBuilder(RewriterBase &rewriter, Location loc)
+      : HopperBuilder(rewriter, loc) {}
+
+  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
+};
+
+SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
+  MLIRContext *ctx = rewriter.getContext();
+  if (copyOps.empty())
+    return SmallVector<Operation *>();
+
+  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
+  assert(launchOp && "expected launch op");
+
+  // 1. Init a barrier object in shared memory.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(copyOps.front());
+  AffineExpr bx, by, bz;
+  bindSymbols(ctx, bx, by, bz);
+  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
+  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, prod,
+      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
+                             launchOp.getBlockSizeZ()});
+
+  TypedValue<nvgpu::MBarrierType> barrier =
+      buildAndInitBarrierInSharedMemory(numThreads);
+
+  SmallVector<TypedValue<MemRefType>> shmems;
+  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
+  for (Operation *op : copyOps) {
+    auto copyOp = cast<linalg::CopyOp>(op);
+    auto inMemRef =
+        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
+    MemRefType inMemRefType = inMemRef.getType();
+    assert(inMemRefType.getRank() == 2 && "expected in to be a 2D memref");
+
+    // 2. Build global memory descriptor.
+    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
+        buildGlobalMemRefDescriptor(inMemRef, launchOp);
+    globalDescs.push_back(globalDesc);
+
+    // 3. Shared memory and descriptor for the tmp array.
+    auto shmem =
+        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
+    shmems.push_back(shmem);
+  }
+
+  // 4. Load in from global memory to shared memory using tma.
+  OpBuilder::InsertionGuard g2(rewriter);
+  rewriter.setInsertionPoint(copyOps.front());
+  SmallVector<Operation *> results =
+      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
+
+  // 5. Spin-loop until data is ready.
+  buildTryWaitParity(barrier);
+
+  // 6. Erase the ops that have now been rewritten.
+  for (Operation *op : copyOps)
+    rewriter.eraseOp(op);
+
+  return results;
+}
+
+DiagnosedSilenceableFailure
+transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getTarget());
+  gpu::LaunchOp commonLaunchOp;
+  Operation *firstOp, *failingOp;
+  if (llvm::any_of(payloadOps, [&](Operation *op) {
+        if (!commonLaunchOp) {
+          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
+          firstOp = op;
+        }
+        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
+                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
+                    !isa<linalg::CopyOp>(op);
+        if (fail)
+          failingOp = op;
+        return fail;
+      })) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "target ops must be linalg::CopyOp nested under a common "
+           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
+           "be created on the host.\nBut got: "
+        << *firstOp << "\nand " << *failingOp;
+    return diag;
+  }
+
+  // TODO: more robust detection of copy, with transposes etc.
+  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
@@ -767,6 +1043,7 @@ class NVGPUTransformDialectExtension
     declareGeneratedDialect<arith::ArithDialect>();
     declareGeneratedDialect<affine::AffineDialect>();
     declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<NVVM::NVVMDialect>();
     declareGeneratedDialect<vector::VectorDialect>();
     registerTransformOps<
 #define GET_OP_LIST

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 5821876139d064..2a774b599a8b68 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -161,13 +161,13 @@ AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
   if (basis.empty())
     return getAffineConstantExpr(0, ctx);
   return std::accumulate(basis.begin(), basis.end(),
-                         getAffineConstantExpr(1, ctx),
+                         getAffineConstantExpr(0, ctx),
                          std::plus<AffineExpr>());
 }
 
 AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
   if (basis.empty())
-    return getAffineConstantExpr(0, ctx);
+    return getAffineConstantExpr(1, ctx);
   return std::accumulate(basis.begin(), basis.end(),
                          getAffineConstantExpr(1, ctx),
                          std::multiplies<AffineExpr>());
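
Note the effect of this change on the builders above: computeSum previously
seeded std::accumulate with 1 instead of the additive identity 0, and
computeProduct returned 0 instead of the multiplicative identity 1 for an empty
basis. Concretely, for the expect_tx sums built by buildBarrierArriveTx:

  old: 1 + 2048 + 4096 = 6145      new: 0 + 2048 + 4096 = 6144
  old (else branch): 1 + 0 = 1     new: 0 + 0 = 0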

diff  --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
new file mode 100644
index 00000000000000..646008b64f794f
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s \
+// RUN:     -test-transform-dialect-interpreter \
+// RUN:     -test-transform-dialect-erase-schedule \
+// RUN: | FileCheck %s
+
+memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func.func @main()
+func.func @main() {
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+
+  %0 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+
+  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  //      CHECK: %[[c64:.*]] = arith.constant 64 : index
+  //      CHECK: %[[c8:.*]] = arith.constant 8 : index
+  //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]] 
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  //      CHECK: %[[c8_2:.*]] = arith.constant 8 : index
+  //      CHECK: %[[c128_2:.*]] = arith.constant 128 : index
+  //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]] 
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK: gpu.launch
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+            threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
+    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    
+    //      CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
+    //      CHECK: nvgpu.mbarrier.init %[[B]], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
+    //      CHECK: gpu.barrier
+    //
+    //      CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //      CHECK: %[[TIDX:.*]] = gpu.thread_id  x
+    //      CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TIDX]], %[[c0]] : index
+    //
+    //      CHECK: scf.if %[[CMP]] {
+    //
+    //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
+    //      CHECK:   nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]] to %[[G1]] 
+    // CHECK-SAME:     : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, 
+    // CHECK-SAME:        swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
+    // CHECK-SAME:     -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    //
+    //      CHECK:   %[[c0_8:.*]] = arith.constant 0 : index
+    //      CHECK:   nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]] to %[[G2]] 
+    // CHECK-SAME:     : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:         swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup> 
+    // CHECK-SAME:    -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    //
+    //      CHECK:   %[[c6144:.*]] = arith.constant 6144 : index
+    //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+    //      CHECK: } else {
+    //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
+    //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
+    //      CHECK: }
+    //
+    //      CHECK: %[[c0_6:.*]] = arith.constant 0 : index
+    //      CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
+    //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
+
+    /// Both copies are matched and end up in the same async group.    
+    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+
+    gpu.terminator
+  }
+  
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_copy_as_tma %copy  : (!transform.any_op) -> ()
+}

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
new file mode 100644
index 00000000000000..3a6bbe7f0d7721
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN:     -test-transform-dialect-interpreter \
+// RUN:     -test-transform-dialect-erase-schedule \
+// RUN:     -convert-nvgpu-to-nvvm -gpu-kernel-outlining \
+// RUN:     -convert-scf-to-cf -convert-nvvm-to-llvm \
+// RUN:     -convert-vector-to-llvm \
+// RUN:     -convert-math-to-llvm \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:     -convert-arith-to-llvm \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
+// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX
+
+// CHECK-PTX: mbarrier.init.shared {{.*}} !llvm.ptr<3>, i32
+/// If branch
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+/// Else branch
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+// CHECK-PTX: mbarrier.try_wait.parity.shared
+
+// TODO: GPU layering does not currently work end-to-end. Activate the following
+// when fixed.
+// R-UN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// R-UN:     -gpu-to-llvm \
+// R-UN:     -convert-func-to-llvm \
+// R-UN:     -cse \
+// R-UN:     -canonicalize \
+// R-UN:     -reconcile-unrealized-casts \
+// R-UN: | mlir-cpu-runner \
+// R-UN:   --shared-libs=%mlir_cuda_runtime \
+// R-UN:   --shared-libs=%mlir_runner_utils \
+// R-UN:   --entry-point-result=void \
+// R-UN: | FileCheck %s
+
+// C-HECK: [GPU] TMA BEFORE lhs[45][7] 0.000000
+// C-HECK: [GPU] TMA BEFORE rhs[7][0] 0.000000
+// C-HECK: [GPU] TMA LOADED lhs[45][7] 7.000000
+// C-HECK: [GPU] TMA LOADED rhs[7][0] 3.000000
+
+
+module @mymod {
+  memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
+  memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
+  func.func @main() {
+    %c10000000 = arith.constant 10000000 : index
+    %c6144 = arith.constant 6144 : index
+    %c45 = arith.constant 45 : index
+    %c7 = arith.constant 7 : index
+    %c64 = arith.constant 64 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 3.000000e+00 : f32
+    %alloc = memref.alloc() : memref<64x8xf32>
+    %alloc_0 = memref.alloc() : memref<8x128xf32>
+    scf.for %arg0 = %c0 to %c8 step %c1 {
+      scf.for %arg1 = %c0 to %c128 step %c1 {
+        memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
+      }
+    }
+    scf.for %arg0 = %c0 to %c64 step %c1 {
+      scf.for %arg1 = %c0 to %c8 step %c1 {
+        %5 = arith.index_cast %arg1 : index to i64
+        %6 = arith.uitofp %5 : i64 to f32
+        memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
+      }
+    }
+    %0 = gpu.wait async
+    %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+    %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+    %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
+    %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>
+    
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+              threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
+      %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
+      %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
+      linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, 3>)
+      linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, 3>)
+
+      %6 = gpu.thread_id  x
+      %10 = arith.cmpi eq, %6, %c0 : index
+      scf.if %10 {
+        %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3>
+        %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3>
+        gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
+        gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+      }
+      gpu.terminator
+    }
+    
+    return
+  }
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_copy_as_tma %copy 
+    : (!transform.any_op) -> ()
+}


        

