[Mlir-commits] [mlir] 171a5a7 - [mlir][Linalg] Add a greedy transform to map copies to threads efficiently.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 10 09:16:10 PDT 2023


Author: Nicolas Vasilache
Date: 2023-07-10T16:11:04Z
New Revision: 171a5a761d24664786f65c861d083df306532a35

URL: https://github.com/llvm/llvm-project/commit/171a5a761d24664786f65c861d083df306532a35
DIFF: https://github.com/llvm/llvm-project/commit/171a5a761d24664786f65c861d083df306532a35.diff

LOG: [mlir][Linalg] Add a greedy transform to map copies to threads efficiently.

This revision adds a new transformation to map a copy operation to a gpu grid of threads.
It implements a first heuristic that allows trading off coalesced accesses vs predication and occupancy.

Differential Revision: https://reviews.llvm.org/D154836

Added: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h
    mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
    mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h
new file mode 100644
index 00000000000000..f9fd32c2031afb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h
@@ -0,0 +1,127 @@
+//===- GPUHeuristics.h - GPU heuristics for Linalg transforms ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_GPUHEURISTICS_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_GPUHEURISTICS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace transform {
+namespace gpu {
+
+/// Base struct to hold GPU mapping information for a given operation.
+struct MappingInfo {
+  /// Number of threads to use for the mapping, one entry per mapped dimension.
+  /// Note: When the number of threads used is smaller than the total number of
+  /// available threads, predication ensues. It is often useful to use more
+  /// threads and saturate memory bandwidth for some operations, even if others
+  /// end up being predicated.
+  SmallVector<int64_t> numThreads;
+
+  /// Thread mapping attributes, one per entry of `numThreads`.
+  SmallVector<Attribute> threadMapping;
+};
+
+struct CopyMappingInfo : public MappingInfo {
+  /// Status of the mapping computation, invalid usually means too many threads
+  /// are required and we fail to map. This usually happens when the copy is too
+  /// large compared to the number of threads.
+  enum class Status { Success = 0, RequiresPredication, Invalid };
+
+  /// Greedily compute the MappingInfo to use to perform a copy of `sizes`
+  /// elements of bitwidth `elementalBitwidth`.
+  /// The `desiredBitAlignment` is the number of bits by which the most
+  /// minor dimension of the copy is expected to be aligned.
+  /// This is an approximation of the final alignment, for each row of the copy.
+  /// This is used to restrict the size of copied vectors so that they match
+  /// potential subsequent cp.async.
+  /// If the alignment does not match the required alignment for a cp.async down
+  /// the line, the conversion to cp.async will be eventually skipped, possibly
+  /// degrading performance.
+  /// When `favorPredication` is false, the mapping is computed to fill all
+  /// threads with an equal amount of data to copy, so as to avoid predication.
+  /// Predication ends up requiring a split epilogue in current pipelining
+  /// implementations and is better avoided when possible.
+  CopyMappingInfo(MLIRContext *ctx, int totalNumThreads,
+                  int64_t desiredBitAlignment, ArrayRef<int64_t> sizes,
+                  bool favorPredication = false,
+                  int64_t elementalBitwidth = 32);
+
+private:
+  /// Determine the maximal vector size to use to copy a contiguous array of
+  /// `numContiguousElements`, each of bitwidth `elementalBitwidth`.
+  /// The `alignment` is the number of bits by which the most minor dimension
+  /// of the copy is aligned. This is an approximation of actual memory
+  /// alignment after bufferization, for each row of the copy. It is used to
+  /// restrict the size of the copied vector so that it is properly aligned
+  /// with the requirements of cp.async. If the copy alignment does not match
+  /// the required alignment for a cp.async, the conversion will be skipped.
+  /// Asserts that `elementalBitwidth` divides both `alignment` and
+  /// `kMaxVectorLoadBitWidth`.
+  static int64_t
+  maxContiguousElementsToTransfer(int64_t alignment,
+                                  int64_t numContiguousElements,
+                                  int64_t elementalBitwidth = 32);
+
+  /// Compute the number of threads to use to perform a copy of `sizes`
+  /// elements, moving `desiredVectorSize` contiguous elements per thread along
+  /// the most minor dimension. Returns Invalid when the copy cannot be mapped
+  /// to at most `totalNumThreads` threads, Success when exactly
+  /// `totalNumThreads` threads are used and RequiresPredication when fewer
+  /// are used; in the last two cases `vectorSize` and `numThreads` are set.
+  /// When `favorPredication` is false, the implementation avoids predication
+  /// in the copy, even if it means reducing the granularity of the transfer:
+  /// the vector size is halved until the mapping is either Success or Invalid.
+  /// If no such vector size exists, or when `favorPredication` is true, the
+  /// mapping is computed once with `desiredVectorSize`, greedily assigning as
+  /// many of the available threads as possible.
+  /// `inferNumThreadsImpl` computes the mapping for one fixed vector size.
+  Status inferNumThreads(int64_t totalNumThreads, ArrayRef<int64_t> sizes,
+                         int64_t desiredVectorSize, bool favorPredication);
+  Status inferNumThreadsImpl(int64_t totalNumThreads, ArrayRef<int64_t> sizes,
+                             int64_t desiredVectorSize);
+
+public:
+  // Pretty-printing and diagnostic methods.
+  void print(llvm::raw_ostream &os) const;
+  LLVM_DUMP_METHOD void dump() const;
+
+  /// Static quantity determining the number of bits to target in an individual
+  /// copy. Assumes that smaller increments of 64, 32, 16, 8 are also valid
+  /// transfer sizes. In the future we should have more hardware pluggability
+  /// here, especially when we want sub-byte granularity.
+  static constexpr int64_t kMaxVectorLoadBitWidth = 128;
+
+  /// Most minor vector size (i.e. 1-D), in number of elements, used in a copy.
+  int64_t vectorSize = 0;
+
+  /// Number of threads to use for the copy mapping, from most major to most
+  /// minor dims (i.e. numThreads.back() should be mapped to contiguous threads
+  /// for best coalescing).
+  using MappingInfo::numThreads;
+
+  /// Explicit computation / injection of the smallest bounding tile sizes after
+  /// mapping to `numThreads`. This is useful in masked scenarios.
+  SmallVector<int64_t> smallestBoundingTileSizes;
+
+  /// Thread mapping attributes, one per entry of `numThreads`.
+  using MappingInfo::threadMapping;
+
+  /// The status of a particular copy mapping. Must be checked before applying
+  /// transformations. Invalid until a mapping has been computed.
+  Status status = Status::Invalid;
+};
+
+} // namespace gpu
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_GPUHEURISTICS_H

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 797c6945c3cbff..d71f254d1dcbc7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -22,6 +22,7 @@ class TilingInterface;
 class RewriterBase;
 
 namespace linalg {
+class CopyOp;
 struct ForallTilingResult;
 class GenericOp;
 class LinalgOp;

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 086a37452debf0..82a1fcc5c8f4a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2261,4 +2261,60 @@ def InsertSliceToCopyOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MapCopyToThreadsOp
+//===----------------------------------------------------------------------===//
+
+def MapCopyToThreadsOp :
+  Op<Transform_Dialect, "structured.gpu.map_copy_to_threads",
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformEachOpTrait, 
+     TransformOpInterface]> {
+  let description = [{
+    Targeted mapping of a copy operation on tensors to a GPU thread mapping.
+
+    This operation implements a greedy heuristic that determines a good 
+    distribution of threads to break down the copy operation into.
+    The heuristic is driven by considerations related to the underlying 
+    architecture for which good high-level decisions are needed assuming certain
+    hardware features. Relevant features are exposed via first-class attributes
+    to control the behavior of the transformation at a high level.
+
+    For now, a single heuristic is implemented and can be extended on a per-need
+    basis.
+
+    #### Return modes:
+
+    Fails silenceably if the target is not a statically sized linalg.copy or if
+    the mapping is invalid; otherwise returns a handle to the tiled linalg.copy.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$total_num_threads,
+                       I64Attr:$desired_bit_alignment);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = [{
+    $target
+    `total_num_threads` `=` $total_num_threads
+    `desired_bit_alignment` `=` $desired_bit_alignment
+    attr-dict 
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::CopyOp copyOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedNumThreads();
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index 129863608d01c5..bef72c3283ca7f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRLinalgTransformOps
   DialectExtension.cpp
+  GPUHeuristics.cpp
   LinalgMatchOps.cpp
   LinalgTransformOps.cpp
   Syntax.cpp

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
new file mode 100644
index 00000000000000..e4fa8d6bc74b74
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -0,0 +1,267 @@
+//===- GPUHeuristics.cpp - Heuristics Implementation for Transforms -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cmath>
+#include <numeric>
+
+using namespace mlir;
+
+#define DEBUG_TYPE "linalg-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+static Attribute linearIdX(MLIRContext *ctx) { // Builds #gpu.linear<x>.
+  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimX);
+}
+static Attribute linearIdY(MLIRContext *ctx) { // Same, for LinearId::DimY.
+  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimY);
+}
+static Attribute linearIdZ(MLIRContext *ctx) { // Same, for LinearId::DimZ.
+  return gpu::GPULinearIdMappingAttr::get(ctx, gpu::LinearId::DimZ);
+}
+
+transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
+                                                 int totalNumThreads,
+                                                 int64_t desiredBitAlignment,
+                                                 ArrayRef<int64_t> copySizes,
+                                                 bool favorPredication,
+                                                 int64_t elementalBitwidth) {
+  assert(!copySizes.empty() && copySizes.size() <= 3 &&
+         "only 1,2,3-D copies are supported for now");
+
+  LDBG("START CopyMappingInfo, favorPredication: " << favorPredication);
+  LLVM_DEBUG(llvm::interleaveComma(copySizes, DBGS() << "--copy shape: ");
+             llvm::dbgs() << "\n";);
+
+  // Greedily find the largest vector size that can be used to copy the most
+  // minor dimension: we are in the business of filling kMaxVectorLoadBitWidth
+  // contiguous memory transactions with as few threads as possible.
+  int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
+      desiredBitAlignment, copySizes.back(), elementalBitwidth);
+
+  LDBG("--greedily determined vectorSize: "
+       << desiredVectorSize << " elements of " << elementalBitwidth
+       << "b each -> " << (desiredVectorSize * elementalBitwidth)
+       << "b total out of a max of " << kMaxVectorLoadBitWidth << "b");
+
+  status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
+                           favorPredication);
+  if (status == Status::Invalid)
+    return; // Tile sizes and thread mapping are left empty.
+
+  LLVM_DEBUG(llvm::interleaveComma(copySizes, DBGS() << "--copy: ");
+             llvm::dbgs() << "\n"; llvm::interleaveComma(
+                 this->numThreads, DBGS() << "--numThreads: ");
+             llvm::dbgs() << "\n";);
+  LDBG("--vectorSize: " << this->vectorSize);
+  assert(this->numThreads.size() == copySizes.size() &&
+         "compute copy mapping expected same number of threads and copy sizes");
+
+  // Compute the smallest bounding box: ceil(size / numThreads) per dimension.
+  this->smallestBoundingTileSizes = llvm::to_vector(
+      llvm::map_range(llvm::zip(copySizes, this->numThreads), [](auto &&pair) {
+        int64_t size, numThreads;
+        std::tie(size, numThreads) = pair;
+        return mlir::ceilDiv(size, numThreads);
+      }));
+  SmallVector<Attribute> allThreadMappings{linearIdZ(ctx), linearIdY(ctx),
+                                           linearIdX(ctx)};
+
+  // Set the thread mapping: the last N of {z, y, x} for an N-D copy.
+  this->threadMapping =
+      llvm::to_vector(ArrayRef(allThreadMappings)
+                          .take_back(this->smallestBoundingTileSizes.size()));
+  LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
+}
+
+int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
+    int64_t desiredBitAlignment, int64_t numContiguousElements,
+    int64_t elementalBitwidth) {
+  assert(kMaxVectorLoadBitWidth % elementalBitwidth == 0 &&
+         "elemental bitwidth does not divide kMaxVectorLoadBitWidth");
+  assert(desiredBitAlignment % elementalBitwidth == 0 &&
+         "elemental bitwidth does not divide desired bit alignment");
+  const int64_t alignedElts = desiredBitAlignment / elementalBitwidth;
+  const int64_t maxLoadElts = kMaxVectorLoadBitWidth / elementalBitwidth;
+  return std::gcd(std::gcd(alignedElts, numContiguousElements), maxLoadElts);
+}
+
+/// Return all divisors of `val` (not just prime ones) in increasing order.
+static SmallVector<int64_t> getFactors(int64_t val) {
+  SmallVector<int64_t> factors;
+  factors.reserve(val);
+  for (int64_t factor = 1; factor < val; ++factor) {
+    if (val % factor != 0)
+      continue;
+    factors.push_back(factor);
+  }
+  factors.push_back(val);
+  return factors;
+}
+
+static int64_t product(ArrayRef<int64_t> vals) {
+  // Fold with multiplication, starting from 1 (the empty product).
+  return std::accumulate(
+      vals.begin(), vals.end(), int64_t{1},
+      [](int64_t acc, int64_t val) { return acc * val; });
+}
+
+/// Compute `threadsPerDim`, the number of threads to use along each of
+/// `sizes[currentIndex .. sizes.size())`, under the following constraints:
+///   1. sizes[i] % threadsPerDim[i] == 0 for all i,
+///   2. the product of `threadsPerDim` is at most `maxNumThreads`,
+///   3. the most minor entry of `threadsPerDim` is exactly `sizes.back()`, to
+///      ensure coalesced accesses along the most minor dimension.
+/// Return the first assignment found that maximizes the number of threads, or
+/// an empty vector if none exists. For the most minor dimension, the mandated
+/// value is returned without checking 2., which is left to the caller.
+// The implementation recursively enumerates the factors of each size.
+// TODO: Implementation details can be improved but putting effort there is a
+// tradeoff: `sizes` is expected to be of small rank and contain small values.
+static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes,
+                                               int64_t currentIndex,
+                                               int64_t maxNumThreads) {
+  assert(static_cast<size_t>(currentIndex) < sizes.size() &&
+         "currentIndex out of bounds");
+  std::string indent(2 * currentIndex, '-');
+  if (static_cast<size_t>(currentIndex) == sizes.size() - 1) {
+    LDBG(indent << "mandated globalBest: " << sizes[currentIndex]);
+    return SmallVector<int64_t>{sizes[currentIndex]};
+  }
+
+  int64_t best = 0;
+  int64_t s = sizes[currentIndex];
+  SmallVector<int64_t> factors = getFactors(s);
+  SmallVector<int64_t> localThreadsPerDim;
+  localThreadsPerDim.reserve(sizes.size());
+  LDBG(indent << "maximizeNumThreads in " << s
+              << " with limit: " << maxNumThreads);
+  for (auto factor : factors) {
+    auto nestedThreadsPerDim =
+        maximizeNumThreads(sizes, currentIndex + 1, maxNumThreads / factor);
+    int64_t localBest = factor * product(nestedThreadsPerDim);
+    // Skip empty nested results: the remaining dims do not fit in the limit.
+    if (!nestedThreadsPerDim.empty() && localBest > best &&
+        localBest <= maxNumThreads) {
+      LDBG(indent << "new localBest: " << localBest);
+      LLVM_DEBUG(
+          llvm::interleaveComma(nestedThreadsPerDim,
+                                DBGS() << indent << "nestedThreadsPerDim: ");
+          llvm::dbgs() << "\n";);
+      localThreadsPerDim.clear();
+      localThreadsPerDim.push_back(factor);
+      llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
+      best = localBest;
+    }
+  }
+
+  LDBG(indent << "found globalBest: " << best);
+  LLVM_DEBUG(llvm::interleaveComma(localThreadsPerDim,
+                                   DBGS() << indent << "numThreads: ");
+             llvm::dbgs() << "\n";);
+
+  return localThreadsPerDim;
+}
+
+transform::gpu::CopyMappingInfo::Status
+transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
+                                                 ArrayRef<int64_t> sizes,
+                                                 int64_t desiredVectorSize,
+                                                 bool favorPredication) {
+
+  if (!favorPredication) {
+    int64_t localVectorSize = desiredVectorSize;
+    for (; localVectorSize >= 1; localVectorSize /= 2) {
+      // Attempt to map the copy with the current fixed vector size:
+      //   1. if the status is Success, we are done.
+      //   2. if the status is Invalid, we fail immediately, no amount of
+      //   vector size reduction can offset the bad tile size selection from the
+      //   higher-level.
+      //   3. if the status is RequiresPredication, we try again with a smaller
+      //   vector size.
+      Status status =
+          inferNumThreadsImpl(totalNumThreads, sizes, localVectorSize);
+      if (status == Status::Success || status == Status::Invalid)
+        return status;
+
+      LDBG("requires predication, try reducing vector size to "
+           << (localVectorSize / 2));
+    }
+  }
+
+  // If we have not yet returned, either predication is favored or we have tried
+  // all vector sizes and still require predication. Use the original vector
+  // size and accept a mapping that requires predication.
+  return inferNumThreadsImpl(totalNumThreads, sizes, desiredVectorSize);
+}
+
+transform::gpu::CopyMappingInfo::Status
+transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
+    int64_t totalNumThreads, ArrayRef<int64_t> sizes,
+    int64_t desiredVectorSize) {
+  assert(sizes.back() % desiredVectorSize == 0 &&
+         "most-minor size not divisible by actualVectorSize");
+
+  LDBG("inferNumThreadsImpl with totalNumThreads: "
+       << totalNumThreads << " and vectorSize: " << desiredVectorSize);
+
+  // Scale the most minor size to account for the chosen vector size and
+  // maximize the number of threads without exceeding the total number of
+  // threads.
+  SmallVector<int64_t> scaledSizes{sizes};
+  scaledSizes.back() /= desiredVectorSize;
+  if (scaledSizes.back() > totalNumThreads) {
+    LDBG("--Too few threads given the required vector size -> FAIL");
+    return Status::Invalid;
+  }
+  SmallVector<int64_t> inferredNumThreads =
+      maximizeNumThreads(scaledSizes, 0, totalNumThreads);
+
+  LLVM_DEBUG(llvm::interleaveComma(inferredNumThreads,
+                                   DBGS() << "inferred numThreads: ");
+             llvm::dbgs() << "\n";
+             LDBG("computed actualVectorSize: " << desiredVectorSize););
+
+  // Corner case: we cannot use more threads than available. If the dimension of
+  // the copy is so bad it is because higher-level tiling did not do its job, we
+  // do not try to recover from it here.
+  int64_t totalNumThreadsUsed = product(inferredNumThreads);
+  LDBG("--totalNumThreadsUsed: " << totalNumThreadsUsed);
+  if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
+    LDBG("--Too few threads given the required vector size -> FAIL");
+    return Status::Invalid;
+  }
+
+  this->vectorSize = desiredVectorSize;
+  this->numThreads = inferredNumThreads;
+  if (totalNumThreadsUsed == totalNumThreads)
+    return Status::Success; // All available threads are used.
+
+  return Status::RequiresPredication; // Fewer threads than available are used.
+}
+
+void transform::gpu::CopyMappingInfo::print(llvm::raw_ostream &os) const {
+  // Fields after `vectorSize` emit their own leading ", " separator.
+  os << "MappingInfo{CopyMappingInfo: ";
+  os << "valid: " << (status != Status::Invalid) << ", ";
+  os << "vectorSize: " << vectorSize;
+  llvm::interleaveComma(numThreads, os << ", numThreads: {");
+  llvm::interleaveComma(smallestBoundingTileSizes,
+                        os << "}, smallestBoundingTileSizes: {");
+  llvm::interleaveComma(threadMapping, os << "}, threadMapping: {");
+  os << "}}";
+}

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index eed58bbbf1e18c..cf12e78145fd77 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -31,6 +32,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -1700,10 +1702,10 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
     if (mapping.size() > 1)
       return emitDefaultDefiniteFailure(target);
 
-    auto addressSpace = cast<gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
+    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
 
     if (addressSpace.getAddressSpace() ==
-        gpu::GPUDialect::getWorkgroupAddressSpace()) {
+        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
       promotionOptions =
           promotionOptions
               .setAllocationDeallocationFns(allocateWorkgroupMemory,
@@ -1711,7 +1713,7 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
               .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
               .setUseFullTileBuffers({false, false});
     } else if (addressSpace.getAddressSpace() ==
-               gpu::GPUDialect::getPrivateAddressSpace()) {
+               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
       promotionOptions =
           promotionOptions
               .setAllocationDeallocationFns(allocateGPUPrivateMemory,
@@ -3211,6 +3213,72 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
   return diag;
 }
 
+//===----------------------------------------------------------------------===//
+// MapCopyToThreadsOp
+//===----------------------------------------------------------------------===//
+DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::CopyOp copyOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+  // Only ranked, statically shaped copies of rank 1 to 3 can be mapped.
+  ShapedType resultShapedType;
+  if (copyOp)
+    resultShapedType =
+        cast<ShapedType>(copyOp.getDpsInitOperand(0)->get().getType());
+  if (!copyOp || !resultShapedType.hasStaticShape() ||
+      resultShapedType.getRank() < 1 || resultShapedType.getRank() > 3) {
+    DiagnosedSilenceableFailure diag =
+        transformOp.emitSilenceableError()
+        << "only statically sized linalg.copy ops of rank <= 3 are supported";
+    if (copyOp) // Never dereference a null target when attaching the note.
+      diag.attachNote(copyOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Conservatively set the minimum viable desired bitwidth alignment.
+  int64_t desiredBitAlignment = getDesiredBitAlignment();
+  int64_t eltBitwidth =
+      resultShapedType.getElementType().getIntOrFloatBitWidth();
+  if (desiredBitAlignment % eltBitwidth != 0)
+    desiredBitAlignment = eltBitwidth;
+
+  gpu::CopyMappingInfo mapping(
+      /*ctx=*/getContext(),
+      /*totalNumThreads=*/getTotalNumThreads(),
+      /*alignment=*/desiredBitAlignment,
+      /*sizes=*/resultShapedType.getShape(),
+      /*favorPredication=*/false,
+      /*elementalBitwidth=*/eltBitwidth);
+  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
+    DiagnosedSilenceableFailure diag =
+        transformOp.emitSilenceableError()
+        << "too few threads to map copy op to threads on the most minor "
+           "dimension, given alignment and vector size constraints, try "
+           "smaller tile size of mapping to more threads";
+    diag.attachNote(copyOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // OpBuilder only used to compute attributes.
+  OpBuilder b(getContext());
+  linalg::ForallTilingResult tilingResult;
+  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
+      /*rewriter=*/rewriter,
+      /*state=*/state,
+      /*transformOp=*/transformOp,
+      /*target=*/copyOp,
+      /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
+      /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
+      /*mapping=*/b.getArrayAttr(mapping.threadMapping),
+      /*tilingResult=*/tilingResult);
+  if (!diag.succeeded())
+    return diag;
+
+  results.push_back(tilingResult.tiledOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir
new file mode 100644
index 00000000000000..d822e2fac8051d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir
@@ -0,0 +1,407 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
+
+
+!tt = tensor<8xf16>
+
+// CHECK-LABEL: func @copy_1d_8xf16
+func.func @copy_1d_8xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too little data for all threads, needs predication, while keeping most
+  /// minor transfer size -> 1 thread.
+  // CHECK: scf.forall {{.*}} in (1) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16xf16>
+
+// CHECK-LABEL: func @copy_1d_16xf16
+func.func @copy_1d_16xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too little data for all threads, needs predication, while keeping most
+  /// minor transfer size -> 2 threads.
+  // CHECK: scf.forall {{.*}} in (2) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<20xf16>
+
+// CHECK-LABEL: func @copy_1d_20xf16
+func.func @copy_1d_20xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too little data for all threads, needs predication, and the most minor
+  /// transfer size is reduced to 4xf16 -> 5 threads.
+  // CHECK: scf.forall {{.*}} in (5) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+
+// -----
+
+!tt = tensor<20xf16>
+
+// CHECK-LABEL: func @copy_1d_20xf16
+func.func @copy_1d_20xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too little data for all threads, needs predication, and the most minor
+  /// transfer size is reduced to 4xf16 -> 5 threads.
+  // CHECK: scf.forall {{.*}} in (5) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<128xf16>
+
+// CHECK-LABEL: func @copy_1d_128xf16
+func.func @copy_1d_128xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Enough data for all threads and no need for predication but we must reduce
+  /// the transfer size to 4xf16.
+  // CHECK: scf.forall {{.*}} in (32) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<4xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<256xf16>
+
+// CHECK-LABEL: func @copy_1d_256xf16
+func.func @copy_1d_256xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Enough data for all threads and no need for predication.
+  // CHECK: scf.forall {{.*}} in (32) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<8xf16>
+  // CHECK: {mapping = [#gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16x32x64xi8>
+
+// CHECK-LABEL: func @copy_3d_16x32x64xi8
+func.func @copy_3d_16x32x64xi8(%t0: !tt, %out: !tt) -> !tt {
+  // CHECK: scf.forall {{.*}} in (1, 8, 4) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<16x4x16xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16x32x64xi8>
+
+// CHECK-LABEL: func @copy_3d_16x32x64xi8
+func.func @copy_3d_16x32x64xi8(%t0: !tt, %out: !tt) -> !tt {
+  // CHECK: scf.forall {{.*}} in (1, 4, 8) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<16x8x8xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 64
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<4x8x16xi8>
+
+// CHECK-LABEL: func @copy_3d_4x8x16xi8
+func.func @copy_3d_4x8x16xi8(%t0: !tt, %out: !tt) -> !tt {
+  // CHECK: scf.forall {{.*}} in (4, 8, 1) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<1x1x16xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<4x8x16xi8>
+
+// CHECK-LABEL: func @copy_3d_4x8x16xi8
+func.func @copy_3d_4x8x16xi8(%t0: !tt, %out: !tt) -> !tt {
+  // CHECK: scf.forall {{.*}} in (1, 2, 16) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<4x4x1xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 8
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<3x5x7xi8>
+
+// CHECK-LABEL: func @copy_3d_3x5x7xi8
+func.func @copy_3d_3x5x7xi8(%t0: !tt, %out: !tt) -> !tt {
+  // Best effort greedy mapping: first 7, then skip 5 (as 7*5 overflows 32), then
+  // take 3.
+  // DP mapping: 7 mandated most minor, then skip 5  (as 7*5 overflows 32), then
+  // take 3.
+  // CHECK: scf.forall {{.*}} in (3, 1, 7) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<1x5x1xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 8
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16x15x5xi8>
+
+// CHECK-LABEL: func @copy_3d_16x15x5xi8
+func.func @copy_3d_16x15x5xi8(%t0: !tt, %out: !tt) -> !tt {
+  // DP mapping: 5 mandated most minor, then 3 to allow 8 on the outermost.
+  // CHECK: scf.forall {{.*}} in (8, 3, 5) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<2x5x1xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 128 desired_bit_alignment = 8
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16x15x40xi8>
+
+// CHECK-LABEL: func @copy_3d_16x15x40xi8
+func.func @copy_3d_16x15x40xi8(%t0: !tt, %out: !tt) -> !tt {
+  // DP mapping: 5 (40 / 8 i8 per 64b) mandated most minor, then 3 to allow 8 on the outermost.
+  // CHECK: scf.forall {{.*}} in (8, 3, 5) {{.*}}
+  // CHECK:   linalg.copy {{.*}} -> tensor<2x5x8xi8>
+  // CHECK: {mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>]}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 128 desired_bit_alignment = 64
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+
+////////////////////////////////////////////////////////////////////////////////
+// Tests below are expected to fail.
+////////////////////////////////////////////////////////////////////////////////
+
+// -----
+
+!tt = tensor<1024xf16>
+
+// NO-CHECK-LABEL-ON-EXPECTED-ERROR
+func.func @copy_1d_1024xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too much data for all threads. We do not try to recover here: it is the
+  /// job of higher-level transformations to select better tile sizes and number
+  /// of threads.
+
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}}
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<257xf16>
+
+// NO-CHECK-LABEL-ON-EXPECTED-ERROR
+func.func @copy_1d_257xf16(%t0: !tt, %out: !tt) -> !tt {
+  /// Too much data for all threads. We do not try to recover here: it is the
+  /// job of higher-level transformations to select better tile sizes and number
+  /// of threads.
+  
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}}
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 128
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<512xi8>
+
+// NO-CHECK-LABEL-ON-EXPECTED-ERROR
+func.func @copy_1d_512xi8(%t0: !tt, %out: !tt) -> !tt {
+  /// Too much data for all threads given the forced alignment to 8b.
+  /// We do not try to recover here: it is the job of higher-level
+  /// transformations to select better tile sizes and number of threads.
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}}
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 8
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}
+
+// -----
+
+!tt = tensor<16x32x64xi8>
+
+// NO-CHECK-LABEL-ON-EXPECTED-ERROR
+func.func @copy_3d_16x32x64xi8(%t0: !tt, %out: !tt) -> !tt {
+  /// Too much data for all threads given the forced alignment to 8b.
+  /// We do not try to recover here: it is the job of higher-level
+  /// transformations to select better tile sizes and number of threads.
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%t0: !tt) outs(%out: !tt) -> !tt 
+  return %0 : !tt
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}}
+  transform.structured.gpu.map_copy_to_threads %0 
+    total_num_threads = 32 desired_bit_alignment = 8
+      : (!transform.any_op) -> (!transform.op<"linalg.copy">)
+}


        


More information about the Mlir-commits mailing list