[Mlir-commits] [mlir] aafb52d - [mlir][GPUTransforms] NFC - Refactor GPUTransforms.cpp in preparation for improvements.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 14 05:00:10 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-14T05:00:01-07:00
New Revision: aafb52d7c9226cd9925bd5135309bd02b6e3b59d

URL: https://github.com/llvm/llvm-project/commit/aafb52d7c9226cd9925bd5135309bd02b6e3b59d
DIFF: https://github.com/llvm/llvm-project/commit/aafb52d7c9226cd9925bd5135309bd02b6e3b59d.diff

LOG: [mlir][GPUTransforms] NFC - Refactor GPUTransforms.cpp in preparation for improvements.

Depends on: D145977

Differential Revision: https://reviews.llvm.org/D145980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/Dialect/GPU/transform-gpu-failing.mlir
    mlir/test/Dialect/GPU/transform-gpu.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 463736722e229..7c6aa7e069879 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -33,6 +33,18 @@ class DialectRegistry;
 namespace transform {
 namespace gpu {
 
+/// Map the top level `scf.forall` op to GPU Thread Blocks.
+/// Mapping is one-to-one and the induction variables of `scf.forall` are
+/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+/// Dynamic `scf.forall` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
+DiagnosedSilenceableFailure mapForallToBlocksImpl(
+    RewriterBase &rewriter, scf::ForallOp forallOp,
+    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
+        blockIdGenerator,
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
+
 /// Search `scf.forall` ops nested under `target` and map each such op to GPU
 /// threads. Mapping is one-to-one and the induction variables of `scf.forall`
 /// are rewritten to gpu.thread_id according to the thread_dim_mapping
@@ -43,24 +55,12 @@ namespace gpu {
 /// Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim,
+    const SmallVectorImpl<int64_t> &kernelBlockDims,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
         threadIdGenerator,
     bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
-/// Map the top level `scf.forall` op to GPU Thread Blocks.
-/// Mapping is one-to-one and the induction variables of `scf.forall` are
-/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
-/// Dynamic `scf.forall` trip counts are currently not supported.
-/// Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure mapForallToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
-
 /// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure
 findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,

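As a hedged illustration of the comments above (not part of this patch), the sketch below shows roughly what the thread mapping does to a small, bufferized and normalized `scf.forall`. The values %x and %y, the memref shapes, and the kernel block dims [32, 4, 1] are hypothetical:

  // Hypothetical input nested inside a gpu.launch body.
  %c32 = arith.constant 32 : index
  %c4 = arith.constant 4 : index
  scf.forall (%i, %j) in (%c32, %c4) {
    %v = memref.load %x[%i, %j] : memref<32x4xf32>
    memref.store %v, %y[%i, %j] : memref<32x4xf32>
  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}

  // Expected result, assuming kernel block dims [32, 4, 1]: the induction
  // variables are replaced by thread ids and the loop construct goes away.
  %tidx = gpu.thread_id x
  %tidy = gpu.thread_id y
  %v = memref.load %x[%tidx, %tidy] : memref<32x4xf32>
  memref.store %v, %y[%tidx, %tidy] : memref<32x4xf32>
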
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 7d129caa3084a..9396aa16f1392 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -652,15 +652,6 @@ def ForallOp : SCF_Op<"forall", [
     /// Checks if the lbs are zeros and steps are ones.
     bool isNormalized();
 
-    /// Helper to sort `values` according to matching `keys`.
-    /// Take a custom `compare` binary comparator which returns true if the first
-    /// element is smaller than the second (i.e. compatible with std::sort).
-    /// This is a helper typically used to sort numThreads values before they are
-    /// mapped to concrete physical dimensions of hardware.
-    static SmallVector<Value> getValuesSortedByKey(
-      ArrayRef<Attribute> keys, ValueRange values,
-      llvm::function_ref<bool(Attribute, Attribute)> compare);
-
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
     // well-formed. We override it here to ensure that we do the right thing.

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2be3e74c62dd2..930bf46fca215 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -117,6 +117,14 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+/// Helper to sort `values` according to matching `keys`.
+SmallVector<Value>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
+SmallVector<OpFoldResult>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index bd003b844a0b5..93f00fe9cca01 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -17,30 +17,108 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
 
+namespace {
+
+/// Helper type for functions that generate ids for the mapping of a scf.forall.
+using IdGeneratorFnType = llvm::function_ref<void(RewriterBase &, scf::ForallOp,
+                                                  SmallVectorImpl<Value> &)>;
+
+struct MappingToGpuHelper {
+  MappingToGpuHelper(SmallVector<DeviceMappingAttrInterface> mappingAttributes,
+                     IdGeneratorFnType idGenerator)
+      : mappingAttributes(mappingAttributes), idGenerator(idGenerator) {}
+
+  SmallVector<DeviceMappingAttrInterface> mappingAttributes;
+  IdGeneratorFnType idGenerator;
+};
+
+struct MappingToGpuBlocksHelper : public MappingToGpuHelper {
+
+  MappingToGpuBlocksHelper(MLIRContext *ctx)
+      : MappingToGpuHelper(
+            SmallVector<DeviceMappingAttrInterface>{
+                GPUBlockMappingAttr::get(ctx, Blocks::DimX),
+                GPUBlockMappingAttr::get(ctx, Blocks::DimY),
+                GPUBlockMappingAttr::get(ctx, Blocks::DimZ)},
+            IdGeneratorFnType{[](RewriterBase &rewriter, scf::ForallOp forallOp,
+                                 SmallVectorImpl<Value> &ids) {
+              OpBuilder::InsertionGuard guard(rewriter);
+              rewriter.setInsertionPoint(forallOp);
+              IndexType indexType = rewriter.getIndexType();
+              auto loc = forallOp->getLoc();
+              ids.assign(
+                  {rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)});
+            }}) {}
+};
+
+struct MappingToGpuThreadsHelper : public MappingToGpuHelper {
+  MappingToGpuThreadsHelper(MLIRContext *ctx)
+      : MappingToGpuHelper(
+            SmallVector<DeviceMappingAttrInterface>{
+                GPUThreadMappingAttr::get(ctx, Threads::DimX),
+                GPUThreadMappingAttr::get(ctx, Threads::DimY),
+                GPUThreadMappingAttr::get(ctx, Threads::DimZ)},
+            IdGeneratorFnType{[](RewriterBase &rewriter, scf::ForallOp forallOp,
+                                 SmallVectorImpl<Value> &ids) {
+              OpBuilder::InsertionGuard guard(rewriter);
+              rewriter.setInsertionPoint(forallOp);
+              IndexType indexType = rewriter.getIndexType();
+              auto loc = forallOp->getLoc();
+              ids.assign(
+                  {rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)});
+            }}) {}
+};
+
+} // namespace
+
+static DiagnosedSilenceableFailure
+failureHelper(std::optional<TransformOpInterface> transformOp,
+              scf::ForallOp forallOp, const Twine &message) {
+  if (transformOp.has_value())
+    return transformOp->emitSilenceableError() << message;
+  return emitDefiniteFailure(forallOp, message);
+}
+
 /// Check if given mapping attributes are one of the desired attributes
 static DiagnosedSilenceableFailure
-checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
-                   const std::optional<ArrayAttr> &forallMapping,
-                   std::optional<TransformOpInterface> transformOp) {
-  if (!forallMapping.has_value())
-    return transformOp->emitSilenceableError() << "mapping must be present";
+checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
+                           scf::ForallOp forallOp) {
+  if (!forallOp.getMapping().has_value())
+    return failureHelper(transformOp, forallOp, "mapping must be present");
+
+  bool hasBlockMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUBlockMappingAttr>();
+      });
+  bool hasThreadMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUThreadMappingAttr>();
+      });
+  int64_t countMappingTypes = 0;
+  countMappingTypes += hasBlockMapping ? 1 : 0;
+  countMappingTypes += hasThreadMapping ? 1 : 0;
+  if (countMappingTypes > 1) {
+    return failureHelper(transformOp, forallOp,
+                         "cannot mix different mapping types, use nesting");
+  }
 
   DenseSet<Attribute> seen;
-  for (Attribute map : forallMapping->getValue()) {
-    if (!llvm::is_contained(threadMappingAttributes, map)) {
-      return transformOp->emitDefiniteFailure()
-             << "mapping must be one of " << threadMappingAttributes;
-    }
+  for (Attribute map : forallOp.getMapping()->getValue()) {
     if (llvm::is_contained(seen, map)) {
-      return transformOp->emitDefiniteFailure()
-             << map
-             << " is duplicated, cannot map different "
-                "loops to the same processor";
+      return failureHelper(transformOp, forallOp,
+                           "duplicated attribute, cannot map different loops "
+                           "to the same processor");
     }
     seen.insert(map);
   }
@@ -48,6 +126,34 @@ checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
   return DiagnosedSilenceableFailure::success();
 }
 
+static DiagnosedSilenceableFailure
+verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
+                 scf::ForallOp forallOp) {
+  // Check the types of the mapping attributes match.
+  DiagnosedSilenceableFailure typeRes =
+      checkMappingAttributeTypes(transformOp, forallOp);
+  if (!typeRes.succeeded())
+    return typeRes;
+
+  // Perform other non-types verifications.
+  if (!forallOp.isNormalized())
+    return failureHelper(transformOp, forallOp,
+                         "unsupported non-normalized loops");
+  if (forallOp.getNumResults() > 0)
+    return failureHelper(transformOp, forallOp,
+                         "only bufferized scf.forall can be mapped");
+  if (forallOp.getRank() > 3)
+    return failureHelper(transformOp, forallOp,
+                         "scf.forall with rank > 3 does not lower");
+  if (llvm::any_of(forallOp.getMixedUpperBound(), [&](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      })) {
+    return failureHelper(transformOp, forallOp,
+                         "unsupported dynamic sizes in forall op");
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -169,44 +275,27 @@ alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
     RewriterBase &rewriter, scf::ForallOp forallOp,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    IdGeneratorFnType blockIdGenerator, SmallVectorImpl<int64_t> &gridDims,
+    TransformOpInterface transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
-  // Step 0. Target-specific verifications. There is no good place to anchor
-  // those right now: the ForallOp is target-independent and the
-  // transform op does not apply to individual ForallOp.
-  Location loc = forallOp->getLoc();
 
-  if (!forallOp.isNormalized())
-    return transformOp.emitSilenceableError()
-           << "unsupported non-normalized loops";
-  if (forallOp.getNumResults() > 0)
-    return transformOp.emitSilenceableError()
-           << "only bufferized scf.forall lowers to "
-              "gpu.block_id";
-  if (forallOp.getRank() > 3)
-    return transformOp.emitSilenceableError()
-           << "scf.forall with rank > 3 does not lower to "
-              "gpu.block_id";
-  if (llvm::any_of(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return transformOp.emitSilenceableError()
-           << "unsupported dynamic griddim size";
-  }
+  // Step 0. GPU-specific verifications. There is no better place to anchor
+  // those right now: the ForallOp is target-independent and the transform op
+  // does not apply to individual ForallOp.
+  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
+  if (!diag.succeeded())
+    return diag;
+
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(forallOp.getMapping()->getValue());
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
-  SmallVector<Value> numBlocks = forallOp.getUpperBound(rewriter);
+  SmallVector<OpFoldResult> numBlocks = forallOp.getMixedUpperBound();
   // Ensure we have 3 block sizes, one for each id.
-  Value one;
   for (auto attr : mappingAttributes) {
     if (!llvm::is_contained(blockMapping, attr)) {
       blockMapping.push_back(attr);
-      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
-      numBlocks.push_back(one);
+      numBlocks.push_back(rewriter.getIndexAttr(1));
     }
   }
 
@@ -215,12 +304,14 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<Value> gridDimValues =
-      scf::ForallOp::getValuesSortedByKey(blockMapping, numBlocks, comparator);
-  for (Value v : gridDimValues)
-    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
+  SmallVector<OpFoldResult> gridDimValues =
+      getValuesSortedByKey(blockMapping, numBlocks, comparator);
+  gridDims =
+      llvm::to_vector(llvm::map_range(gridDimValues, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
+      }));
 
-  // Step 3. Generate the blockIds using the provided generator and map the
+  // Step 3. Generate the blockids using the provided generator and map the
   // induction variables to the newly created ops.
   SmallVector<Value> blockOps;
   blockIdGenerator(rewriter, forallOp, blockOps);
@@ -273,20 +364,6 @@ mlir::transform::gpu::findTopLevelForallOp(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// This is a helper that is only used in rewriteTopLevelForallToGpuBlocks.
-/// It generates GPU dialect block_id.
-static void createGpuBlockIds(RewriterBase &rewriter, scf::ForallOp forallOp,
-                              SmallVectorImpl<Value> &blockOps) {
-  Location loc = forallOp->getLoc();
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(forallOp);
-  IndexType indexType = rewriter.getIndexType();
-  blockOps = SmallVector<Value>{
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
-}
-
 DiagnosedSilenceableFailure
 transform::MapForallToBlocks::applyToOne(Operation *target,
                                          ApplyToEachResultList &results,
@@ -312,6 +389,10 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
     return diag;
   }
 
+  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
+  if (!getGenerateGpuLaunch() && gridDim.size() != 3)
+    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");
+
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(topLevelForallOp);
 
@@ -328,23 +409,20 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
     topLevelForallOp = cast<scf::ForallOp>(newForallOp);
   }
 
-  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
-
-  diag = checkAttributeType(blockMappingAttributes,
-                            topLevelForallOp.getMapping(), transformOp);
-  if (diag.succeeded())
-    diag = mlir::transform::gpu::mapForallToBlocksImpl(
-        rewriter, topLevelForallOp, createGpuBlockIds, gridDim, transformOp,
-        blockMappingAttributes);
-  if (diag.succeeded()) {
-    diag = alterGpuLaunch(rewriter, gpuLaunch,
-                          cast<TransformOpInterface>(getOperation()),
-                          gridDim[0], gridDim[1], gridDim[2]);
-  }
+  diag = verifyGpuMapping(transformOp, topLevelForallOp);
+  if (!diag.succeeded())
+    return diag;
+
+  MappingToGpuBlocksHelper helper(getContext());
+  diag = mlir::transform::gpu::mapForallToBlocksImpl(
+      rewriter, topLevelForallOp, helper.idGenerator, gridDim, transformOp,
+      helper.mappingAttributes);
+  if (!diag.succeeded())
+    return diag;
+
+  diag = alterGpuLaunch(rewriter, gpuLaunch,
+                        cast<TransformOpInterface>(getOperation()), gridDim[0],
+                        gridDim[1], gridDim[2]);
 
   results.push_back(gpuLaunch);
   return diag;
@@ -354,56 +432,32 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
 // MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
-/// Searches `scf.forall` ops nested under `target` and maps each such
-/// op to GPU threads. Mapping is one-to-one and the induction variables of
-/// `scf.forall` are rewritten to gpu.thread_id according to the
-/// thread_dim_mapping attribute. Sibling `scf.forall` are supported in
-/// which case, the union of the number of threads is computed and may result
-/// in predication. Dynamic, `scf.forall` trip counts are currently
-/// not supported. Dynamic block dim sizes are currently not supported.
 static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
     RewriterBase &rewriter, scf::ForallOp forallOp,
-    const SmallVectorImpl<int64_t> &globalBlockDims,
+    const SmallVectorImpl<int64_t> &kernelBlockDims,
     const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
     std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
-  // Step 0. Target-specific verifications. There is no good place to anchor
-  // those right now: the ForallOp is target-independent and the
-  // transform op does not apply to individual ForallOp.
-  auto failureHelper =
-      [&](const Twine &message) -> DiagnosedSilenceableFailure {
-    if (transformOp.has_value()) {
-      return transformOp->emitSilenceableError() << message;
-    }
-    return emitDefiniteFailure(forallOp, message);
-  };
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+
+  // Step 0. GPU-specific verifications. There is no better place to anchor
+  // those right now: the ForallOp is target-independent and the transform op
+  // does not apply to individual ForallOp.
+  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
+  if (!diag.succeeded())
+    return diag;
+
   Location loc = forallOp->getLoc();
-  if (!forallOp.isNormalized())
-    return failureHelper("unsupported non-normalized loops");
-  if (forallOp.getNumResults() > 0)
-    return failureHelper("only bufferized scf.forall lowers to gpu.thread_id");
-  if (forallOp.getRank() > 3)
-    return failureHelper(
-        "scf.forall with rank > 3 does not lower to gpu.thread_id");
-  if (llvm::any_of(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return failureHelper("unsupported dynamic blockdim size");
-  }
-  if (!forallOp.getMapping().has_value())
-    return failureHelper("mapping must be present");
-  SmallVector<Attribute> threadMapping =
+
+  SmallVector<Attribute> mapping =
       llvm::to_vector(forallOp.getMapping()->getValue());
 
-  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
+  // Step 1. Complete the mapping to a full mapping (with 1s) if
   // necessary.
-  SmallVector<Value> numThreads = forallOp.getUpperBound(rewriter);
-  // Ensure we have 3 block sizes, one for each id.
-  Value one;
-  for (auto attr : threadMappingAttributes) {
-    if (!llvm::is_contained(threadMapping, attr)) {
-      threadMapping.push_back(attr);
-      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  SmallVector<OpFoldResult> numThreads = forallOp.getMixedUpperBound();
+  Attribute one = rewriter.getIndexAttr(1);
+  for (auto attr : mappingAttributes) {
+    if (std::find(mapping.begin(), mapping.end(), attr) == mapping.end()) {
+      mapping.push_back(attr);
       numThreads.push_back(one);
     }
   }
@@ -413,27 +467,28 @@ static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<Value> blockDimValues = scf::ForallOp::getValuesSortedByKey(
-      threadMapping, numThreads, comparator);
+  SmallVector<OpFoldResult> blockDimValues =
+      getValuesSortedByKey(mapping, numThreads, comparator);
   SmallVector<int64_t> blockDims =
-      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
-        return v.getDefiningOp<arith::ConstantIndexOp>().value();
+      llvm::to_vector(llvm::map_range(blockDimValues, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
       }));
 
   // Step 3. Create the gpu.thread ops and map the induction variables to the
   // newly created ops.
   // Replace ids of dimension size 1 by zero to simplify the IR.
+  // TODO
   SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
-  assert(threadOps.size() == globalBlockDims.size());
+  assert(threadOps.size() == kernelBlockDims.size());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
-    if (globalBlockDims[i] == 1)
+  for (size_t i : llvm::seq(size_t(0), kernelBlockDims.size())) {
+    if (kernelBlockDims[i] == 1)
       threadOpsUpdated[i] = zero;
   }
   IRMapping bvm;
-  for (auto [blockIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), threadMapping)) {
-    bvm.map(blockIdx,
+  for (auto [threadIdx, blockDim] :
+       llvm::zip(forallOp.getInductionVars(), mapping)) {
+    bvm.map(threadIdx,
             threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
                                  .getMappingId()]);
   }
@@ -441,18 +496,20 @@ static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
   // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
+       llvm::zip(threadOpsUpdated, blockDims, kernelBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
-          "The requested GPU threads are fewer than the number of loop trip "
-          "counts. Try to tile scf.forall before mapping or set "
-          "small blockDim.");
+          transformOp, forallOp,
+          "Trying to map to fewer GPU threads than loop iterations but "
+          "overprovisioning is not yet supported. "
+          "Try additional tiling before mapping or map to more "
+          "threads.");
     }
     if (blockDim == globalBlockDim)
       continue;
-    Value blockIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
+    Value threadIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
     Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, threadId, blockIdx);
+        loc, arith::CmpIPredicate::ult, threadId, threadIdx);
     predicate =
         predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
                   : tmpPredicate;
@@ -497,28 +554,25 @@ static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        threadIdGenerator,
+    const SmallVectorImpl<int64_t> &blockDim, IdGeneratorFnType idGenerator,
     bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForallOp forallOp) {
     // Ignore cases with different attributes.
     for (Attribute map : forallOp.getMapping()->getValue()) {
-      if (!llvm::is_contained(threadMappingAttributes, map)) {
+      if (!llvm::is_contained(mappingAttributes, map)) {
         return WalkResult::skip();
       }
     }
-    diag = checkAttributeType(threadMappingAttributes, forallOp.getMapping(),
-                              transformOp);
+    diag = verifyGpuMapping(transformOp, forallOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(forallOp);
       SmallVector<Value> threadOps;
-      threadIdGenerator(rewriter, forallOp, threadOps);
+      idGenerator(rewriter, forallOp, threadOps);
       diag = rewriteOneForallToGpuThreads(rewriter, forallOp, blockDim,
                                           threadOps, syncAfterDistribute,
-                                          transformOp, threadMappingAttributes);
+                                          transformOp, mappingAttributes);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
@@ -530,48 +584,36 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
   auto transformOp = cast<TransformOpInterface>(getOperation());
 
-  if (!gpuLaunch) {
-    return emitSilenceableError() << "Given target is not gpu.launch";
-  }
+  // Basic high-level verifications.
+  if (!gpuLaunch)
+    return emitSilenceableError() << "Given target is not a gpu.launch";
 
   SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
-  blockDim.resize(/*size=*/3, /*value=*/1);
+  if (blockDim.size() != 3)
+    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                      blockDim[0], blockDim[1], blockDim[2]);
   if (diag.isSilenceableFailure()) {
-    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
+    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
     return diag;
   }
 
   MLIRContext *ctx = getContext();
   IRRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
-
-  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
-      GPUThreadMappingAttr::get(ctx, Threads::DimX),
-      GPUThreadMappingAttr::get(ctx, Threads::DimY),
-      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-  auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
-                              SmallVectorImpl<Value> &threadIds) {
-    IndexType indexType = rewriter.getIndexType();
-    threadIds.assign({rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::x),
-                      rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::y),
-                      rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::z)});
-  };
+  MappingToGpuThreadsHelper helper(ctx);
   diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
-      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
-      transformOp, threadMappingAttributes);
+      rewriter, target, blockDim, helper.idGenerator, getSyncAfterDistribute(),
+      transformOp, helper.mappingAttributes);
 
-  if (diag.succeeded()) {
-    diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
-                          std::nullopt, std::nullopt, blockDim[0], blockDim[1],
-                          blockDim[2]);
-  }
+  if (!diag.succeeded())
+    return diag;
+
+  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
+                        std::nullopt, std::nullopt, blockDim[0], blockDim[1],
+                        blockDim[2]);
 
   results.push_back(gpuLaunch.getOperation());
   return diag;

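Step 4 in rewriteOneForallToGpuThreads above only materializes a predicate when a mapped loop dimension is smaller than the corresponding kernel block dimension. As a hedged sketch (value names hypothetical, not taken from this patch), for a trip count of 32 mapped to #gpu.thread<x> in a kernel launched with blockDim = [64, 1, 1], the guarded IR looks roughly like:

  %tidx = gpu.thread_id x
  %c32 = arith.constant 32 : index
  // Threads 32..63 must not execute the body.
  %pred = arith.cmpi ult, %tidx, %c32 : index
  scf.if %pred {
    // Original scf.forall body, with its induction variable replaced by %tidx.
  }
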
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ed5e1bf04c5dc..485b60d6c6699 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1391,23 +1391,6 @@ InParallelOp ForallOp::getTerminator() {
   return cast<InParallelOp>(getBody()->getTerminator());
 }
 
-/// Helper to sort `values` according to matching `keys`.
-SmallVector<Value> ForallOp::getValuesSortedByKey(
-    ArrayRef<Attribute> keys, ValueRange values,
-    llvm::function_ref<bool(Attribute, Attribute)> compare) {
-  if (keys.empty())
-    return values;
-  assert(keys.size() == values.size() && "unexpected mismatching sizes");
-  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
-  std::sort(indices.begin(), indices.end(),
-            [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
-  SmallVector<Value> res;
-  res.reserve(values.size());
-  for (int64_t i = 0, e = indices.size(); i < e; ++i)
-    res.push_back(values[indices[i]]);
-  return res;
-}
-
 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
   auto tidxArg = val.dyn_cast<BlockArgument>();
   if (!tidxArg)

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index bf80acb754dd7..907a8c1e4914d 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -89,7 +89,7 @@ OpFoldResult getAsOpFoldResult(Value val) {
 /// Given an array of values, try to extract a constant Attribute from each
 /// value. If this fails, return the original value.
 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
-  return llvm::to_vector<4>(
+  return llvm::to_vector(
       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
 }
 
@@ -108,7 +108,7 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
 
 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
                                                  ArrayRef<int64_t> values) {
-  return llvm::to_vector<4>(llvm::map_range(
+  return llvm::to_vector(llvm::map_range(
       values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
 }
 
@@ -192,4 +192,34 @@ decomposeMixedValues(Builder &b,
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
+/// Helper to sort `values` according to matching `keys`.
+template <typename K, typename V>
+static SmallVector<V>
+getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
+                         llvm::function_ref<bool(K, K)> compare) {
+  if (keys.empty())
+    return SmallVector<V>{values};
+  assert(keys.size() == values.size() && "unexpected mismatching sizes");
+  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
+  std::sort(indices.begin(), indices.end(),
+            [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
+  SmallVector<V> res;
+  res.reserve(values.size());
+  for (int64_t i = 0, e = indices.size(); i < e; ++i)
+    res.push_back(values[indices[i]]);
+  return res;
+}
+
+SmallVector<Value>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
+SmallVector<OpFoldResult>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
 } // namespace mlir

diff  --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index 45fbf7695a446..50f49727d3e68 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -7,7 +7,7 @@ func.func @map_nested_forall_to_threads_not_gpu_launch() -> () {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{Given target is not gpu.launch}}
+  // expected-error @below {{Given target is not a gpu.launch}}
   %1 = transform.gpu.map_nested_forall_to_threads %funcop
 }
 
@@ -48,7 +48,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Trying to launch a GPU kernel with gridDim = (1, 1, 1) blockDim = (1200, 9, 1). It is larger than the limits.}}
-  // expected-note @below {{"blockDim" is very large}}
+  // expected-note @below {{"blockDim" is too large}}
   transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [1200, 9, 1] }
 }
 
@@ -89,7 +89,7 @@ func.func @map_nested_forall_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{The requested GPU threads are fewer than the number of loop trip counts. Try to tile scf.forall before mapping or set small blockDim.}}
+  // expected-error @below {{Trying to map to fewer GPU threads than loop iterations but overprovisioning is not yet supported. Try additional tiling before mapping or map to more threads.}}
   transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
@@ -115,7 +115,7 @@ func.func @map_nested_forall_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{unsupported dynamic blockdim size}}
+  // expected-error @below {{unsupported dynamic sizes}}
   transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
@@ -137,7 +137,7 @@ transform.sequence failures(propagate) {
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{only bufferized scf.forall lowers to gpu.thread_id}}
+  // expected-error @below {{only bufferized scf.forall can be mapped}}
   transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
@@ -270,8 +270,8 @@ func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token)
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{#gpu.thread<x> is duplicated, cannot map different loops to the same processor}}
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 32]}
+  // expected-error @below {{duplicated attribute, cannot map different loops to the same processor}}
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 32, 1]}
 }
 
 // -----

diff  --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 107d1fe1ff7d3..447ff1597657d 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -33,7 +33,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_forall_to_blocks %funcop { gridDim = [12, 9]}
+  transform.gpu.map_forall_to_blocks %funcop { gridDim = [12, 9, 1]}
 }
 
 // -----
@@ -87,7 +87,7 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1] }
 }
 
 // -----
@@ -192,7 +192,7 @@ func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token)
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32]}
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 1, 1]}
 }
 
 // -----
@@ -267,5 +267,5 @@ func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %str
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1] }
 }

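The refactored verifier rejects a single scf.forall that mixes block and thread mappings ("cannot mix different mapping types, use nesting"). A typical script therefore nests the two transforms, in the same style as the tests above; this is a hedged sketch, and the matched op name and the dimensions are illustrative rather than taken from this commit:

transform.sequence failures(propagate) {
^bb1(%arg0: !pdl.operation):
  %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
  // Map the top-level scf.forall (block mapping) and materialize a gpu.launch.
  %gpu_launch = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
  // Then map the scf.forall ops nested inside the launch (thread mapping),
  // with the now-mandatory size-3 blockDim.
  transform.gpu.map_nested_forall_to_threads %gpu_launch { blockDim = [32, 4, 1] }
}
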

        

